Module refinery.lib.scripts.vba.deobfuscation.helpers

Shared AST utilities for VBA deobfuscation transforms.

Expand source code Browse git
"""
Shared AST utilities for VBA deobfuscation transforms.
"""
from __future__ import annotations

from operator import itemgetter
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Generator

from refinery.lib.scripts import Expression, Kind, Statement, _classify_fields
from refinery.lib.scripts.vba.deobfuscation.names import CHR_NAMES, Value
from refinery.lib.scripts.vba.model import (
    VbaBinaryExpression,
    VbaBooleanLiteral,
    VbaCallExpression,
    VbaConstDeclaration,
    VbaConstDeclarator,
    VbaExpressionStatement,
    VbaFloatLiteral,
    VbaForEachStatement,
    VbaForStatement,
    VbaIdentifier,
    VbaIntegerLiteral,
    VbaLetStatement,
    VbaModule,
    VbaStringLiteral,
    VbaUnaryExpression,
)

# Node classes that `is_literal` and `literal_value` treat as literals (used with `isinstance`).
LITERAL_TYPES = (VbaStringLiteral, VbaIntegerLiteral, VbaFloatLiteral, VbaBooleanLiteral)


def make_string_literal(value: str) -> VbaStringLiteral:
    """
    Build a `VbaStringLiteral` for `value`. The raw text is the value wrapped in double quotes,
    with every embedded double quote doubled.
    """
    quoted = '"' + value.replace('"', '""') + '"'
    return VbaStringLiteral(raw=quoted, value=value)


def is_nan_or_inf(value) -> bool:
    """
    Return whether `value` is a float that is NaN or positive/negative infinity. Anything that is
    not a float yields `False`.
    """
    if not isinstance(value, float):
        return False
    # NaN is the only float that compares unequal to itself.
    if value != value:
        return True
    return abs(value) == float('inf')


def make_integer_literal(value: int) -> VbaIntegerLiteral:
    """
    Build a `VbaIntegerLiteral` whose raw text is the `str` rendering of `value`.
    """
    text = str(value)
    return VbaIntegerLiteral(raw=text, value=value)


def make_float_literal(value: float) -> VbaFloatLiteral:
    """
    Build a `VbaFloatLiteral` whose raw text is the `str` rendering of `value`.
    """
    text = str(value)
    return VbaFloatLiteral(raw=text, value=value)


def make_numeric_literal(value: int | float) -> VbaIntegerLiteral | VbaFloatLiteral:
    """
    Build a literal node for a number. A float holding a whole number with magnitude below
    `2 ** 53` is emitted as an integer literal; every other float becomes a float literal. Values
    that are not floats are handed to `make_integer_literal` unchanged.

    Non-finite floats (NaN and the infinities) are never integral, so they produce a float literal
    rather than raising from the `int` conversion. Callers that must not emit such literals should
    filter them out before calling this function.
    """
    if not isinstance(value, float):
        return make_integer_literal(value)
    # `float.is_integer` is False for NaN and infinities, whereas `int(value)` raises ValueError
    # or OverflowError for them.
    # NOTE(review): the 2 ** 53 bound presumably reflects the exact-integer range of a double.
    if value.is_integer() and abs(value) < 2 ** 53:
        return make_integer_literal(int(value))
    return make_float_literal(value)


def is_literal(node: Expression) -> bool:
    """
    Return whether `node` is an instance of one of the classes listed in `LITERAL_TYPES`.
    """
    matches = isinstance(node, LITERAL_TYPES)
    return matches


def is_constant_expr(node: Expression) -> bool:
    """
    Return `True` for expressions that are accepted as constants:

    - literals,
    - calls to an identifier whose lower-cased name is in `CHR_NAMES`, with exactly one argument
      that is an integer literal,
    - `&` / `+` binary expressions whose operands are both constant expressions,
    - `-` / `Not` unary expressions whose operand is a constant expression.

    Everything else yields `False`.
    """
    if is_literal(node):
        return True

    if isinstance(node, VbaCallExpression):
        callee = node.callee
        if not isinstance(callee, VbaIdentifier):
            return False
        if callee.name.lower() not in CHR_NAMES:
            return False
        args = node.arguments
        # A missing (None) argument fails the isinstance test as well.
        return len(args) == 1 and isinstance(args[0], VbaIntegerLiteral)

    if isinstance(node, VbaBinaryExpression) and node.operator in ('&', '+'):
        lhs = node.left
        rhs = node.right
        if lhs is None or rhs is None:
            return False
        return is_constant_expr(lhs) and is_constant_expr(rhs)

    if isinstance(node, VbaUnaryExpression) and node.operator in ('-', 'Not'):
        inner = node.operand
        if inner is not None:
            return is_constant_expr(inner)

    return False


def is_identifier_read(node: VbaIdentifier) -> bool:
    """
    Return whether an identifier node is in a read position, judged from its parent node. The
    result is `False` when the identifier is the target of a `Let` statement, a child of a const
    declaration or declarator, the callee of a call expression, the expression of an expression
    statement, or the loop variable of a `For` / `For Each` statement; otherwise it is `True`.
    """
    owner = node.parent
    not_a_read = (
        (isinstance(owner, VbaLetStatement) and owner.target is node)
        or isinstance(owner, (VbaConstDeclaration, VbaConstDeclarator))
        or (isinstance(owner, VbaCallExpression) and owner.callee is node)
        or (isinstance(owner, VbaExpressionStatement) and owner.expression is node)
        or (
            isinstance(owner, (VbaForStatement, VbaForEachStatement))
            and owner.variable is node
        )
    )
    return not not_a_read


def literal_value(node: Expression) -> Value:
    """
    Return the `value` attribute of a literal node, or `None` when `node` is not an instance of
    one of the `LITERAL_TYPES`.
    """
    return node.value if isinstance(node, LITERAL_TYPES) else None


def string_value(node: Expression | None) -> str | None:
    """
    Return the value of a string literal node, or `None` for any other input (including `None`).
    """
    return node.value if isinstance(node, VbaStringLiteral) else None


def numeric_value(node: Expression | None) -> int | float | None:
    """
    Return the value of an integer or float literal node, or `None` for any other input
    (including `None`).
    """
    if not isinstance(node, (VbaIntegerLiteral, VbaFloatLiteral)):
        return None
    return node.value


def make_chr_call(code_point: int) -> VbaCallExpression:
    """
    Build a call expression that produces the character with the given code point.

    Code points up to 255 use `Chr`. Larger code points use `ChrW`, because VBA documents `Chr`
    for character codes 0-255 (on non-DBCS systems) while `ChrW` takes a Unicode code point.
    Code points above 65535 cannot be expressed by a single `ChrW` call; they are still emitted
    as `ChrW(code_point)` and are the caller's responsibility.
    """
    function_name = 'Chr' if code_point <= 0xFF else 'ChrW'
    return VbaCallExpression(
        callee=VbaIdentifier(name=function_name),
        arguments=[make_integer_literal(code_point)],
    )


def string_to_expr(value: str) -> Expression:
    """
    Convert a Python string to a VBA AST expression. An empty or fully printable string becomes a
    single string literal. Otherwise each maximal run of printable characters becomes a string
    literal, each non-printable character becomes a call built by `make_chr_call`, and the pieces
    are joined left to right with the `&` operator.
    """
    if not value:
        return make_string_literal('')
    if all(ch.isprintable() for ch in value):
        return make_string_literal(value)

    pieces: list[Expression] = []
    run_start = 0
    for index, ch in enumerate(value):
        if ch.isprintable():
            continue
        # Flush the printable run that ends just before this character.
        if index > run_start:
            pieces.append(make_string_literal(value[run_start:index]))
        pieces.append(make_chr_call(ord(ch)))
        run_start = index + 1
    if run_start < len(value):
        pieces.append(make_string_literal(value[run_start:]))

    expr, *remaining = pieces
    for piece in remaining:
        expr = VbaBinaryExpression(left=expr, operator='&', right=piece)
    return expr


def value_to_node(value: Value) -> Expression | None:
    """
    Convert a Python value to an AST node: strings go through `string_to_expr`, non-boolean
    integers through `make_integer_literal`, and floats through `make_numeric_literal`. Any other
    value, including booleans, yields `None`.
    """
    if isinstance(value, str):
        return string_to_expr(value)
    # `bool` is a subclass of `int`, so it has to be excluded before the integer check.
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return make_integer_literal(value)
    if isinstance(value, float):
        return make_numeric_literal(value)
    return None


def body_lists(module: VbaModule) -> Generator[list[Statement]]:
    """
    Walk the module and yield every non-empty child-list field whose first element is a
    `Statement`.
    """
    for node in module.walk():
        list_fields = (
            name
            for name, kind in _classify_fields(type(node))
            if kind == Kind.ChildList
        )
        for name in list_fields:
            statements = getattr(node, name)
            if statements and isinstance(statements[0], Statement):
                yield statements


def apply_removals(removals: list[tuple[int, list[Statement]]]) -> bool:
    """
    Delete statements at the given `(index, body)` positions. The entries are processed in
    descending index order so that earlier deletions do not invalidate later indices within the
    same body. A position that is listed more than once is deleted only once.

    `removals` is sorted in place. Returns whether `removals` was non-empty.
    """
    removals.sort(key=itemgetter(0), reverse=True)
    done: set[tuple[int, int]] = set()
    for pos, body in removals:
        # Bodies are lists and therefore unhashable; identify them by object identity.
        marker = (pos, id(body))
        if marker in done:
            # Deleting the same position again would remove an unrelated statement.
            continue
        done.add(marker)
        del body[pos]
    return bool(removals)

Functions

def make_string_literal(value)
Expand source code Browse git
def make_string_literal(value: str) -> VbaStringLiteral:
    """
    Build a `VbaStringLiteral` for `value`. The raw text is the value wrapped in double quotes,
    with every embedded double quote doubled.
    """
    quoted = '"' + value.replace('"', '""') + '"'
    return VbaStringLiteral(raw=quoted, value=value)
def is_nan_or_inf(value)
Expand source code Browse git
def is_nan_or_inf(value) -> bool:
    """
    Return whether `value` is a float that is NaN or positive/negative infinity. Anything that is
    not a float yields `False`.
    """
    if not isinstance(value, float):
        return False
    # NaN is the only float that compares unequal to itself.
    if value != value:
        return True
    return abs(value) == float('inf')
def make_integer_literal(value)
Expand source code Browse git
def make_integer_literal(value: int) -> VbaIntegerLiteral:
    """
    Build a `VbaIntegerLiteral` whose raw text is the `str` rendering of `value`.
    """
    text = str(value)
    return VbaIntegerLiteral(raw=text, value=value)
def make_float_literal(value)
Expand source code Browse git
def make_float_literal(value: float) -> VbaFloatLiteral:
    """
    Build a `VbaFloatLiteral` whose raw text is the `str` rendering of `value`.
    """
    text = str(value)
    return VbaFloatLiteral(raw=text, value=value)
def make_numeric_literal(value)
Expand source code Browse git
def make_numeric_literal(value: int | float) -> VbaIntegerLiteral | VbaFloatLiteral:
    """
    Build a literal node for a number. A float holding a whole number with magnitude below
    `2 ** 53` is emitted as an integer literal; every other float becomes a float literal. Values
    that are not floats are handed to `make_integer_literal` unchanged.

    Non-finite floats (NaN and the infinities) are never integral, so they produce a float literal
    rather than raising from the `int` conversion. Callers that must not emit such literals should
    filter them out before calling this function.
    """
    if not isinstance(value, float):
        return make_integer_literal(value)
    # `float.is_integer` is False for NaN and infinities, whereas `int(value)` raises ValueError
    # or OverflowError for them.
    # NOTE(review): the 2 ** 53 bound presumably reflects the exact-integer range of a double.
    if value.is_integer() and abs(value) < 2 ** 53:
        return make_integer_literal(int(value))
    return make_float_literal(value)
def is_literal(node)
Expand source code Browse git
def is_literal(node: Expression) -> bool:
    """
    Return whether `node` is an instance of one of the classes listed in `LITERAL_TYPES`.
    """
    matches = isinstance(node, LITERAL_TYPES)
    return matches
def is_constant_expr(node)

Returns True for expressions that can be safely propagated as constants: literals, Chr/ChrW calls with a single literal integer argument, concatenations (& or +) of such expressions, and the unary operators - and Not applied to such expressions.

Expand source code Browse git
def is_constant_expr(node: Expression) -> bool:
    """
    Return `True` for expressions that are accepted as constants:

    - literals,
    - calls to an identifier whose lower-cased name is in `CHR_NAMES`, with exactly one argument
      that is an integer literal,
    - `&` / `+` binary expressions whose operands are both constant expressions,
    - `-` / `Not` unary expressions whose operand is a constant expression.

    Everything else yields `False`.
    """
    if is_literal(node):
        return True

    if isinstance(node, VbaCallExpression):
        callee = node.callee
        if not isinstance(callee, VbaIdentifier):
            return False
        if callee.name.lower() not in CHR_NAMES:
            return False
        args = node.arguments
        # A missing (None) argument fails the isinstance test as well.
        return len(args) == 1 and isinstance(args[0], VbaIntegerLiteral)

    if isinstance(node, VbaBinaryExpression) and node.operator in ('&', '+'):
        lhs = node.left
        rhs = node.right
        if lhs is None or rhs is None:
            return False
        return is_constant_expr(lhs) and is_constant_expr(rhs)

    if isinstance(node, VbaUnaryExpression) and node.operator in ('-', 'Not'):
        inner = node.operand
        if inner is not None:
            return is_constant_expr(inner)

    return False
def is_identifier_read(node)

Return whether an identifier node is in a read position. Returns False for identifiers that appear as assignment targets, declaration names, call targets, or loop variables.

Expand source code Browse git
def is_identifier_read(node: VbaIdentifier) -> bool:
    """
    Return whether an identifier node is in a read position, judged from its parent node. The
    result is `False` when the identifier is the target of a `Let` statement, a child of a const
    declaration or declarator, the callee of a call expression, the expression of an expression
    statement, or the loop variable of a `For` / `For Each` statement; otherwise it is `True`.
    """
    owner = node.parent
    not_a_read = (
        (isinstance(owner, VbaLetStatement) and owner.target is node)
        or isinstance(owner, (VbaConstDeclaration, VbaConstDeclarator))
        or (isinstance(owner, VbaCallExpression) and owner.callee is node)
        or (isinstance(owner, VbaExpressionStatement) and owner.expression is node)
        or (
            isinstance(owner, (VbaForStatement, VbaForEachStatement))
            and owner.variable is node
        )
    )
    return not not_a_read
def literal_value(node)
Expand source code Browse git
def literal_value(node: Expression) -> Value:
    """
    Return the `value` attribute of a literal node, or `None` when `node` is not an instance of
    one of the `LITERAL_TYPES`.
    """
    return node.value if isinstance(node, LITERAL_TYPES) else None
def string_value(node)
Expand source code Browse git
def string_value(node: Expression | None) -> str | None:
    """
    Return the value of a string literal node, or `None` for any other input (including `None`).
    """
    return node.value if isinstance(node, VbaStringLiteral) else None
def numeric_value(node)
Expand source code Browse git
def numeric_value(node: Expression | None) -> int | float | None:
    """
    Return the value of an integer or float literal node, or `None` for any other input
    (including `None`).
    """
    if not isinstance(node, (VbaIntegerLiteral, VbaFloatLiteral)):
        return None
    return node.value
def make_chr_call(code_point)
Expand source code Browse git
def make_chr_call(code_point: int) -> VbaCallExpression:
    """
    Build a call expression that produces the character with the given code point.

    Code points up to 255 use `Chr`. Larger code points use `ChrW`, because VBA documents `Chr`
    for character codes 0-255 (on non-DBCS systems) while `ChrW` takes a Unicode code point.
    Code points above 65535 cannot be expressed by a single `ChrW` call; they are still emitted
    as `ChrW(code_point)` and are the caller's responsibility.
    """
    function_name = 'Chr' if code_point <= 0xFF else 'ChrW'
    return VbaCallExpression(
        callee=VbaIdentifier(name=function_name),
        arguments=[make_integer_literal(code_point)],
    )
def string_to_expr(value)

Convert a Python string to a VBA AST expression. Printable-only strings become a single string literal; strings with non-printable characters become concatenated expressions using Chr calls.

Expand source code Browse git
def string_to_expr(value: str) -> Expression:
    """
    Convert a Python string to a VBA AST expression. An empty or fully printable string becomes a
    single string literal. Otherwise each maximal run of printable characters becomes a string
    literal, each non-printable character becomes a call built by `make_chr_call`, and the pieces
    are joined left to right with the `&` operator.
    """
    if not value:
        return make_string_literal('')
    if all(ch.isprintable() for ch in value):
        return make_string_literal(value)

    pieces: list[Expression] = []
    run_start = 0
    for index, ch in enumerate(value):
        if ch.isprintable():
            continue
        # Flush the printable run that ends just before this character.
        if index > run_start:
            pieces.append(make_string_literal(value[run_start:index]))
        pieces.append(make_chr_call(ord(ch)))
        run_start = index + 1
    if run_start < len(value):
        pieces.append(make_string_literal(value[run_start:]))

    expr, *remaining = pieces
    for piece in remaining:
        expr = VbaBinaryExpression(left=expr, operator='&', right=piece)
    return expr
def value_to_node(value)
Expand source code Browse git
def value_to_node(value: Value) -> Expression | None:
    """
    Convert a Python value to an AST node: strings go through `string_to_expr`, non-boolean
    integers through `make_integer_literal`, and floats through `make_numeric_literal`. Any other
    value, including booleans, yields `None`.
    """
    if isinstance(value, str):
        return string_to_expr(value)
    # `bool` is a subclass of `int`, so it has to be excluded before the integer check.
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return make_integer_literal(value)
    if isinstance(value, float):
        return make_numeric_literal(value)
    return None
def body_lists(module)

Yield every statement-list body reachable from the module.

Expand source code Browse git
def body_lists(module: VbaModule) -> Generator[list[Statement]]:
    """
    Walk the module and yield every non-empty child-list field whose first element is a
    `Statement`.
    """
    for node in module.walk():
        list_fields = (
            name
            for name, kind in _classify_fields(type(node))
            if kind == Kind.ChildList
        )
        for name in list_fields:
            statements = getattr(node, name)
            if statements and isinstance(statements[0], Statement):
                yield statements
def apply_removals(removals)

Delete statements at the given (index, body) positions in reverse index order so that earlier deletions do not invalidate later indices. Returns whether any removals occurred.

Expand source code Browse git
def apply_removals(removals: list[tuple[int, list[Statement]]]) -> bool:
    """
    Delete statements at the given `(index, body)` positions. The entries are processed in
    descending index order so that earlier deletions do not invalidate later indices within the
    same body. A position that is listed more than once is deleted only once.

    `removals` is sorted in place. Returns whether `removals` was non-empty.
    """
    removals.sort(key=itemgetter(0), reverse=True)
    done: set[tuple[int, int]] = set()
    for pos, body in removals:
        # Bodies are lists and therefore unhashable; identify them by object identity.
        marker = (pos, id(body))
        if marker in done:
            # Deleting the same position again would remove an unrelated statement.
            continue
        done.add(marker)
        del body[pos]
    return bool(removals)