Module refinery.lib.scripts.vba.deobfuscation.helpers

Shared AST utilities for VBA deobfuscation transforms.

Expand source code Browse git
"""
Shared AST utilities for VBA deobfuscation transforms.
"""
from __future__ import annotations

from operator import itemgetter
from typing import Generator

from refinery.lib.scripts import Expression, Kind, Statement, _classify_fields
from refinery.lib.scripts.vba.deobfuscation.names import CHR_NAMES, CompareMode, Value
from refinery.lib.scripts.vba.model import (
    VbaBinaryExpression,
    VbaBooleanLiteral,
    VbaCallExpression,
    VbaConstDeclaration,
    VbaConstDeclarator,
    VbaEmptyLiteral,
    VbaExpressionStatement,
    VbaFloatLiteral,
    VbaForEachStatement,
    VbaForStatement,
    VbaIdentifier,
    VbaIntegerLiteral,
    VbaLetStatement,
    VbaModule,
    VbaOptionStatement,
    VbaStringLiteral,
    VbaUnaryExpression,
)

# Node classes treated as literal constants by `is_literal` and `literal_value`; each of them
# exposes its constant through a `value` attribute.
LITERAL_TYPES = (
    VbaStringLiteral,
    VbaIntegerLiteral,
    VbaFloatLiteral,
    VbaBooleanLiteral,
    VbaEmptyLiteral,
)


def make_string_literal(value: str) -> VbaStringLiteral:
    """
    Build a string literal node for `value`. The raw source text wraps the value in double quotes
    and doubles every embedded double quote, which is how VBA escapes quotes inside a literal.
    """
    quoted = '"' + value.replace('"', '""') + '"'
    return VbaStringLiteral(value=value, raw=quoted)


def is_nan_or_inf(value) -> bool:
    """
    Tell whether `value` is a Python float that is NaN or infinite (of either sign). Any value that
    is not a float, including ints, yields `False`.
    """
    if not isinstance(value, float):
        return False
    # NaN is the only float that compares unequal to itself
    is_nan = value != value
    return is_nan or value in (float('inf'), float('-inf'))


def module_compare_mode(module: VbaModule) -> CompareMode:
    """
    Determine the string comparison mode selected by the module's `Option Compare` directive.

    Only the top-level statements in `module.body` are inspected, and the first `Option Compare`
    statement found decides the result. `Text` maps to `CompareMode.TEXT`, `Database` maps to
    `CompareMode.DATABASE`, and any other value (e.g. `Binary`) maps to `CompareMode.BINARY`.
    `CompareMode.BINARY` is also the result when no such directive exists. Both the keyword and
    the value are matched case-insensitively. `Database` mode depends on a locale-specific sort
    order that cannot be reproduced statically, so comparison folding is expected to be skipped
    under it.
    """
    named_modes = {'text': CompareMode.TEXT, 'database': CompareMode.DATABASE}
    for statement in module.body:
        if not isinstance(statement, VbaOptionStatement):
            continue
        if statement.keyword.lower() != 'compare':
            continue
        return named_modes.get(statement.value.lower(), CompareMode.BINARY)
    return CompareMode.BINARY


def vba_int_div(a: int | float, b: int | float) -> int:
    """
    Emulate VBA's integer division operator (`\\`).

    Each operand is first rounded to an integer with Python's `round` (round-half-to-even). The
    quotient of the rounded values is then truncated toward zero. This differs from Python's `//`,
    which rounds toward negative infinity. A `ZeroDivisionError` is raised when the rounded
    divisor is zero.
    """
    dividend = round(a)
    divisor = round(b)
    if not divisor:
        raise ZeroDivisionError
    magnitude = abs(dividend) // abs(divisor)
    # the result is negative exactly when the operand signs differ
    differing_signs = (dividend < 0) ^ (divisor < 0)
    return -magnitude if differing_signs else magnitude


def vba_mod(a: int | float, b: int | float) -> int:
    """
    Emulate VBA's `Mod` operator.

    Each operand is first rounded to an integer with Python's `round` (round-half-to-even). The
    remainder then carries the sign of the dividend, unlike Python's `%`, whose result carries the
    sign of the divisor. A `ZeroDivisionError` is raised when the rounded divisor is zero.
    """
    dividend, divisor = round(a), round(b)
    if divisor == 0:
        raise ZeroDivisionError
    remainder = abs(dividend) % abs(divisor)
    return remainder if dividend >= 0 else -remainder


def make_integer_literal(value: int) -> VbaIntegerLiteral:
    """
    Wrap an integer in a `VbaIntegerLiteral` whose raw text is the decimal rendering from `str`.
    """
    raw = str(value)
    return VbaIntegerLiteral(value=value, raw=raw)


def make_float_literal(value: float) -> VbaFloatLiteral:
    """
    Wrap a float in a `VbaFloatLiteral` whose raw text is Python's `str` rendering of the value.
    """
    # NOTE(review): str() yields forms such as 'nan', 'inf' or '1e+20' for such values; the first
    # two are not VBA syntax, so callers presumably screen them with is_nan_or_inf — confirm.
    raw = str(value)
    return VbaFloatLiteral(value=value, raw=raw)


def make_numeric_literal(value: int | float) -> VbaIntegerLiteral | VbaFloatLiteral:
    """
    Build the tightest numeric literal for `value`.

    A float with an integral value whose magnitude is below 2**53 becomes an integer literal. In
    that range the conversion to `int` is exact. Every other float becomes a float literal,
    including NaN and the infinities. Non-float values are passed to `make_integer_literal` as-is.
    """
    if isinstance(value, float):
        # float.is_integer() is False for NaN and +/-inf, whereas int() on them would raise
        # ValueError / OverflowError.
        if value.is_integer() and abs(value) < 2 ** 53:
            return make_integer_literal(int(value))
        return make_float_literal(value)
    return make_integer_literal(value)


def is_literal(node: Expression) -> bool:
    """
    Tell whether `node` is an instance of one of the literal node classes in `LITERAL_TYPES`.
    """
    matches = isinstance(node, LITERAL_TYPES)
    return matches


def is_constant_expr(node: Expression) -> bool:
    """
    Decide whether `node` is safe to propagate as a constant. The accepted forms are:

    - any literal (see `is_literal`);
    - a call whose callee is an identifier named in `CHR_NAMES` (case-insensitive) and which has
      exactly one argument that is an integer literal;
    - a binary `&` or `+` whose two operands are both present and both constant;
    - a unary `-` or `Not` whose operand is present and constant.

    Everything else yields `False`.
    """
    if is_literal(node):
        return True
    if isinstance(node, VbaCallExpression):
        callee = node.callee
        args = node.arguments
        return (
            isinstance(callee, VbaIdentifier)
            and callee.name.lower() in CHR_NAMES
            and len(args) == 1
            and isinstance(args[0], VbaIntegerLiteral)
        )
    if isinstance(node, VbaBinaryExpression) and node.operator in ('&', '+'):
        lhs, rhs = node.left, node.right
        return lhs is not None and rhs is not None and is_constant_expr(lhs) and is_constant_expr(rhs)
    if isinstance(node, VbaUnaryExpression) and node.operator in ('-', 'Not'):
        operand = node.operand
        return operand is not None and is_constant_expr(operand)
    return False


def is_identifier_read(node: VbaIdentifier) -> bool:
    """
    Decide whether `node` occurs in a position where its value is read. The answer is `False` when
    the identifier's direct parent is one of the following:

    - a `VbaLetStatement` that uses it as the assignment target;
    - a `VbaConstDeclaration` or `VbaConstDeclarator`, in any role;
    - a `VbaCallExpression` that uses it as the callee;
    - a `VbaExpressionStatement` whose entire expression is the identifier;
    - a `VbaForStatement` or `VbaForEachStatement` that uses it as the loop variable.

    Every other position counts as a read. Only the immediate parent is inspected.
    """
    # NOTE(review): because any identifier directly under a Const declaration/declarator counts as
    # a non-read, a bare identifier initializer such as `Const A = B` may also be reported as not
    # read when `B` is a direct child of the declarator — confirm this is intended.
    owner = node.parent
    non_read_positions = (
        isinstance(owner, VbaLetStatement) and owner.target is node,
        isinstance(owner, (VbaConstDeclaration, VbaConstDeclarator)),
        isinstance(owner, VbaCallExpression) and owner.callee is node,
        isinstance(owner, VbaExpressionStatement) and owner.expression is node,
        isinstance(owner, (VbaForStatement, VbaForEachStatement)) and owner.variable is node,
    )
    return not any(non_read_positions)


def literal_value(node: Expression) -> Value:
    """
    Return the `value` attribute of a literal node, or `None` when `node` is not a literal.
    """
    return node.value if isinstance(node, LITERAL_TYPES) else None


def constant_args(arguments: list[Expression | None]) -> list[Value] | None:
    """
    Translate a builtin call's argument list into constant values.

    An omitted argument (`None`) stays `None` in the result, and a literal argument contributes its
    literal value. As soon as an argument is any other kind of expression, the whole call is
    considered non-foldable and `None` is returned instead of a list.
    """
    collected: list[Value] = []
    for argument in arguments:
        if argument is not None and not is_literal(argument):
            return None
        collected.append(None if argument is None else literal_value(argument))
    return collected


def string_value(node: Expression | None) -> str | None:
    """
    Return the text of a string literal node, or `None` for anything else (including `None`).
    """
    return node.value if isinstance(node, VbaStringLiteral) else None


def numeric_value(node: Expression | None) -> int | float | None:
    """
    Return the number carried by an integer or float literal node. Anything else, including
    `None`, yields `None`.
    """
    if isinstance(node, (VbaIntegerLiteral, VbaFloatLiteral)):
        return node.value
    return None


def make_chr_call(code_point: int) -> VbaCallExpression:
    """
    Build a call expression that produces the character with the given code.

    VBA's `Chr` only accepts character codes 0-255 on non-DBCS systems and raises a runtime error
    otherwise. Codes above 255 are therefore emitted as `ChrW`, which takes a Unicode (UTF-16)
    code unit.
    """
    # NOTE(review): ChrW is limited to 16-bit code units, so code points above 0xFFFF would need a
    # surrogate pair; callers are assumed not to pass such values — confirm.
    name = 'Chr' if code_point <= 0xFF else 'ChrW'
    return VbaCallExpression(
        callee=VbaIdentifier(name=name),
        arguments=[make_integer_literal(code_point)],
    )


def string_to_expr(value: str) -> Expression:
    """
    Turn a Python string into a VBA expression that evaluates to it.

    The empty string, or a string made only of printable characters (per `str.isprintable`), comes
    back as a single string literal. Otherwise each non-printable character becomes a
    `make_chr_call` node, and each maximal printable run between them becomes a string literal. The
    pieces are chained left-to-right with binary `&` expressions.
    """
    if value.isprintable():
        return make_string_literal(value)
    pieces: list[Expression] = []
    run_start = 0
    for index, char in enumerate(value):
        if char.isprintable():
            continue
        if index > run_start:
            pieces.append(make_string_literal(value[run_start:index]))
        pieces.append(make_chr_call(ord(char)))
        run_start = index + 1
    if run_start < len(value):
        pieces.append(make_string_literal(value[run_start:]))
    remaining = iter(pieces)
    expression = next(remaining)
    for piece in remaining:
        expression = VbaBinaryExpression(left=expression, operator='&', right=piece)
    return expression


def value_to_node(value: Value) -> Expression:
    """
    Convert a folded constant back into an AST node. The mapping is:

    - `None` becomes an Empty literal;
    - booleans become Boolean literals;
    - strings go through `string_to_expr`;
    - ints become integer literals;
    - anything else goes through `make_numeric_literal`.
    """
    if value is None:
        return VbaEmptyLiteral()
    # bool has to be checked before int because bool is a subclass of int
    if isinstance(value, bool):
        return VbaBooleanLiteral(value=value)
    if isinstance(value, str):
        return string_to_expr(value)
    return make_integer_literal(value) if isinstance(value, int) else make_numeric_literal(value)


def body_lists(module: VbaModule) -> Generator[list[Statement]]:
    """
    Lazily yield statement lists found anywhere in the module. For every node produced by
    `module.walk()`, each field classified as `Kind.ChildList` is inspected. The list is yielded
    when it is non-empty and its first element is a `Statement`. The lists themselves are yielded,
    not copies, so callers may edit them in place.
    """
    for node in module.walk():
        child_lists = (
            getattr(node, field_name)
            for field_name, kind in _classify_fields(type(node))
            if kind == Kind.ChildList
        )
        for body in child_lists:
            if body and isinstance(body[0], Statement):
                yield body


def apply_removals(removals: list[tuple[int, list[Statement]]]) -> bool:
    """
    Delete statements from their bodies.

    Each entry of `removals` is an `(index, body)` pair naming `body[index]` as it was before any
    deletion. Entries are processed in descending index order, so deleting one statement never
    shifts the position of another statement still pending removal from the same body. A duplicate
    entry (same index, same body object) is applied only once; deleting twice would otherwise
    remove an unrelated neighbouring statement. The `removals` list itself is left unmodified.

    Returns whether any removals were requested.
    """
    if not removals:
        return False
    applied: set[tuple[int, int]] = set()
    for pos, body in sorted(removals, key=itemgetter(0), reverse=True):
        # bodies are identified by object identity: distinct lists may compare equal
        key = (pos, id(body))
        if key in applied:
            continue
        applied.add(key)
        del body[pos]
    return True

Functions

def make_string_literal(value)
Expand source code Browse git
def make_string_literal(value: str) -> VbaStringLiteral:
    """
    Build a string literal node for `value`. The raw source text wraps the value in double quotes
    and doubles every embedded double quote, which is how VBA escapes quotes inside a literal.
    """
    quoted = '"' + value.replace('"', '""') + '"'
    return VbaStringLiteral(value=value, raw=quoted)
def is_nan_or_inf(value)
Expand source code Browse git
def is_nan_or_inf(value) -> bool:
    """
    Tell whether `value` is a Python float that is NaN or infinite (of either sign). Any value that
    is not a float, including ints, yields `False`.
    """
    if not isinstance(value, float):
        return False
    # NaN is the only float that compares unequal to itself
    is_nan = value != value
    return is_nan or value in (float('inf'), float('-inf'))
def module_compare_mode(module)

Return the module's Option Compare mode. VBA Option Compare is a module-level directive; the default (no directive, or Binary) is CompareMode.BINARY, Text is case-insensitive, and Database (Access) uses the database's locale-dependent sort order. The latter cannot be reproduced statically, so the folding transforms refuse to fold comparisons under it.

Expand source code Browse git
def module_compare_mode(module: VbaModule) -> CompareMode:
    """
    Determine the string comparison mode selected by the module's `Option Compare` directive.

    Only the top-level statements in `module.body` are inspected, and the first `Option Compare`
    statement found decides the result. `Text` maps to `CompareMode.TEXT`, `Database` maps to
    `CompareMode.DATABASE`, and any other value (e.g. `Binary`) maps to `CompareMode.BINARY`.
    `CompareMode.BINARY` is also the result when no such directive exists. Both the keyword and
    the value are matched case-insensitively. `Database` mode depends on a locale-specific sort
    order that cannot be reproduced statically, so comparison folding is expected to be skipped
    under it.
    """
    named_modes = {'text': CompareMode.TEXT, 'database': CompareMode.DATABASE}
    for statement in module.body:
        if not isinstance(statement, VbaOptionStatement):
            continue
        if statement.keyword.lower() != 'compare':
            continue
        return named_modes.get(statement.value.lower(), CompareMode.BINARY)
    return CompareMode.BINARY
def vba_int_div(a, b)

VBA integer division (the \ operator): both operands are rounded to integers and the quotient is truncated toward zero, unlike Python // which floors. Raises ZeroDivisionError when the divisor rounds to zero.

Expand source code Browse git
def vba_int_div(a: int | float, b: int | float) -> int:
    """
    Emulate VBA's integer division operator (`\\`).

    Each operand is first rounded to an integer with Python's `round` (round-half-to-even). The
    quotient of the rounded values is then truncated toward zero. This differs from Python's `//`,
    which rounds toward negative infinity. A `ZeroDivisionError` is raised when the rounded
    divisor is zero.
    """
    dividend = round(a)
    divisor = round(b)
    if not divisor:
        raise ZeroDivisionError
    magnitude = abs(dividend) // abs(divisor)
    # the result is negative exactly when the operand signs differ
    differing_signs = (dividend < 0) ^ (divisor < 0)
    return -magnitude if differing_signs else magnitude
def vba_mod(a, b)

VBA modulo (the Mod operator): both operands are rounded to integers and the remainder takes the sign of the dividend, unlike Python % whose result takes the sign of the divisor. Raises ZeroDivisionError when the divisor rounds to zero.

Expand source code Browse git
def vba_mod(a: int | float, b: int | float) -> int:
    """
    Emulate VBA's `Mod` operator.

    Each operand is first rounded to an integer with Python's `round` (round-half-to-even). The
    remainder then carries the sign of the dividend, unlike Python's `%`, whose result carries the
    sign of the divisor. A `ZeroDivisionError` is raised when the rounded divisor is zero.
    """
    dividend, divisor = round(a), round(b)
    if divisor == 0:
        raise ZeroDivisionError
    remainder = abs(dividend) % abs(divisor)
    return remainder if dividend >= 0 else -remainder
def make_integer_literal(value)
Expand source code Browse git
def make_integer_literal(value: int) -> VbaIntegerLiteral:
    """
    Wrap an integer in a `VbaIntegerLiteral` whose raw text is the decimal rendering from `str`.
    """
    raw = str(value)
    return VbaIntegerLiteral(value=value, raw=raw)
def make_float_literal(value)
Expand source code Browse git
def make_float_literal(value: float) -> VbaFloatLiteral:
    """
    Wrap a float in a `VbaFloatLiteral` whose raw text is Python's `str` rendering of the value.
    """
    # NOTE(review): str() yields forms such as 'nan', 'inf' or '1e+20' for such values; the first
    # two are not VBA syntax, so callers presumably screen them with is_nan_or_inf — confirm.
    raw = str(value)
    return VbaFloatLiteral(value=value, raw=raw)
def make_numeric_literal(value)
Expand source code Browse git
def make_numeric_literal(value: int | float) -> VbaIntegerLiteral | VbaFloatLiteral:
    """
    Build the tightest numeric literal for `value`.

    A float with an integral value whose magnitude is below 2**53 becomes an integer literal. In
    that range the conversion to `int` is exact. Every other float becomes a float literal,
    including NaN and the infinities. Non-float values are passed to `make_integer_literal` as-is.
    """
    if isinstance(value, float):
        # float.is_integer() is False for NaN and +/-inf, whereas int() on them would raise
        # ValueError / OverflowError.
        if value.is_integer() and abs(value) < 2 ** 53:
            return make_integer_literal(int(value))
        return make_float_literal(value)
    return make_integer_literal(value)
def is_literal(node)
Expand source code Browse git
def is_literal(node: Expression) -> bool:
    """
    Tell whether `node` is an instance of one of the literal node classes in `LITERAL_TYPES`.
    """
    matches = isinstance(node, LITERAL_TYPES)
    return matches
def is_constant_expr(node)

Returns True for expressions that can be safely propagated as constants. These are literals, Chr/ChrW calls with a single integer-literal argument, & and + concatenations of such expressions, and their negations with unary - or Not.

Expand source code Browse git
def is_constant_expr(node: Expression) -> bool:
    """
    Decide whether `node` is safe to propagate as a constant. The accepted forms are:

    - any literal (see `is_literal`);
    - a call whose callee is an identifier named in `CHR_NAMES` (case-insensitive) and which has
      exactly one argument that is an integer literal;
    - a binary `&` or `+` whose two operands are both present and both constant;
    - a unary `-` or `Not` whose operand is present and constant.

    Everything else yields `False`.
    """
    if is_literal(node):
        return True
    if isinstance(node, VbaCallExpression):
        callee = node.callee
        args = node.arguments
        return (
            isinstance(callee, VbaIdentifier)
            and callee.name.lower() in CHR_NAMES
            and len(args) == 1
            and isinstance(args[0], VbaIntegerLiteral)
        )
    if isinstance(node, VbaBinaryExpression) and node.operator in ('&', '+'):
        lhs, rhs = node.left, node.right
        return lhs is not None and rhs is not None and is_constant_expr(lhs) and is_constant_expr(rhs)
    if isinstance(node, VbaUnaryExpression) and node.operator in ('-', 'Not'):
        operand = node.operand
        return operand is not None and is_constant_expr(operand)
    return False
def is_identifier_read(node)

Return whether an identifier node is in a read position. Returns False for identifiers that appear as assignment targets, declaration names, call targets, or loop variables.

Expand source code Browse git
def is_identifier_read(node: VbaIdentifier) -> bool:
    """
    Decide whether `node` occurs in a position where its value is read. The answer is `False` when
    the identifier's direct parent is one of the following:

    - a `VbaLetStatement` that uses it as the assignment target;
    - a `VbaConstDeclaration` or `VbaConstDeclarator`, in any role;
    - a `VbaCallExpression` that uses it as the callee;
    - a `VbaExpressionStatement` whose entire expression is the identifier;
    - a `VbaForStatement` or `VbaForEachStatement` that uses it as the loop variable.

    Every other position counts as a read. Only the immediate parent is inspected.
    """
    # NOTE(review): because any identifier directly under a Const declaration/declarator counts as
    # a non-read, a bare identifier initializer such as `Const A = B` may also be reported as not
    # read when `B` is a direct child of the declarator — confirm this is intended.
    owner = node.parent
    non_read_positions = (
        isinstance(owner, VbaLetStatement) and owner.target is node,
        isinstance(owner, (VbaConstDeclaration, VbaConstDeclarator)),
        isinstance(owner, VbaCallExpression) and owner.callee is node,
        isinstance(owner, VbaExpressionStatement) and owner.expression is node,
        isinstance(owner, (VbaForStatement, VbaForEachStatement)) and owner.variable is node,
    )
    return not any(non_read_positions)
def literal_value(node)
Expand source code Browse git
def literal_value(node: Expression) -> Value:
    """
    Return the `value` attribute of a literal node, or `None` when `node` is not a literal.
    """
    return node.value if isinstance(node, LITERAL_TYPES) else None
def constant_args(arguments)

Collect the constant values of a builtin call's arguments, preserving an omitted argument as None. Returns None if any argument is a non-literal expression that cannot be folded.

Expand source code Browse git
def constant_args(arguments: list[Expression | None]) -> list[Value] | None:
    """
    Translate a builtin call's argument list into constant values.

    An omitted argument (`None`) stays `None` in the result, and a literal argument contributes its
    literal value. As soon as an argument is any other kind of expression, the whole call is
    considered non-foldable and `None` is returned instead of a list.
    """
    collected: list[Value] = []
    for argument in arguments:
        if argument is not None and not is_literal(argument):
            return None
        collected.append(None if argument is None else literal_value(argument))
    return collected
def string_value(node)
Expand source code Browse git
def string_value(node: Expression | None) -> str | None:
    """
    Return the text of a string literal node, or `None` for anything else (including `None`).
    """
    return node.value if isinstance(node, VbaStringLiteral) else None
def numeric_value(node)
Expand source code Browse git
def numeric_value(node: Expression | None) -> int | float | None:
    """
    Return the number carried by an integer or float literal node. Anything else, including
    `None`, yields `None`.
    """
    if isinstance(node, (VbaIntegerLiteral, VbaFloatLiteral)):
        return node.value
    return None
def make_chr_call(code_point)
Expand source code Browse git
def make_chr_call(code_point: int) -> VbaCallExpression:
    """
    Build a call expression that produces the character with the given code.

    VBA's `Chr` only accepts character codes 0-255 on non-DBCS systems and raises a runtime error
    otherwise. Codes above 255 are therefore emitted as `ChrW`, which takes a Unicode (UTF-16)
    code unit.
    """
    # NOTE(review): ChrW is limited to 16-bit code units, so code points above 0xFFFF would need a
    # surrogate pair; callers are assumed not to pass such values — confirm.
    name = 'Chr' if code_point <= 0xFF else 'ChrW'
    return VbaCallExpression(
        callee=VbaIdentifier(name=name),
        arguments=[make_integer_literal(code_point)],
    )
def string_to_expr(value)

Convert a Python string to a VBA AST expression. Printable-only strings become a single string literal; strings with non-printable characters become concatenated expressions using Chr calls.

Expand source code Browse git
def string_to_expr(value: str) -> Expression:
    """
    Turn a Python string into a VBA expression that evaluates to it.

    The empty string, or a string made only of printable characters (per `str.isprintable`), comes
    back as a single string literal. Otherwise each non-printable character becomes a
    `make_chr_call` node, and each maximal printable run between them becomes a string literal. The
    pieces are chained left-to-right with binary `&` expressions.
    """
    if value.isprintable():
        return make_string_literal(value)
    pieces: list[Expression] = []
    run_start = 0
    for index, char in enumerate(value):
        if char.isprintable():
            continue
        if index > run_start:
            pieces.append(make_string_literal(value[run_start:index]))
        pieces.append(make_chr_call(ord(char)))
        run_start = index + 1
    if run_start < len(value):
        pieces.append(make_string_literal(value[run_start:]))
    remaining = iter(pieces)
    expression = next(remaining)
    for piece in remaining:
        expression = VbaBinaryExpression(left=expression, operator='&', right=piece)
    return expression
def value_to_node(value)
Expand source code Browse git
def value_to_node(value: Value) -> Expression:
    """
    Convert a folded constant back into an AST node. The mapping is:

    - `None` becomes an Empty literal;
    - booleans become Boolean literals;
    - strings go through `string_to_expr`;
    - ints become integer literals;
    - anything else goes through `make_numeric_literal`.
    """
    if value is None:
        return VbaEmptyLiteral()
    # bool has to be checked before int because bool is a subclass of int
    if isinstance(value, bool):
        return VbaBooleanLiteral(value=value)
    if isinstance(value, str):
        return string_to_expr(value)
    return make_integer_literal(value) if isinstance(value, int) else make_numeric_literal(value)
def body_lists(module)

Yield every statement-list body reachable from the module.

Expand source code Browse git
def body_lists(module: VbaModule) -> Generator[list[Statement]]:
    """
    Lazily yield statement lists found anywhere in the module. For every node produced by
    `module.walk()`, each field classified as `Kind.ChildList` is inspected. The list is yielded
    when it is non-empty and its first element is a `Statement`. The lists themselves are yielded,
    not copies, so callers may edit them in place.
    """
    for node in module.walk():
        child_lists = (
            getattr(node, field_name)
            for field_name, kind in _classify_fields(type(node))
            if kind == Kind.ChildList
        )
        for body in child_lists:
            if body and isinstance(body[0], Statement):
                yield body
def apply_removals(removals)

Delete statements at the given (index, body) positions in descending index order so that earlier deletions do not invalidate later indices. Returns whether any removals occurred.

Expand source code Browse git
def apply_removals(removals: list[tuple[int, list[Statement]]]) -> bool:
    """
    Delete statements from their bodies.

    Each entry of `removals` is an `(index, body)` pair naming `body[index]` as it was before any
    deletion. Entries are processed in descending index order, so deleting one statement never
    shifts the position of another statement still pending removal from the same body. A duplicate
    entry (same index, same body object) is applied only once; deleting twice would otherwise
    remove an unrelated neighbouring statement. The `removals` list itself is left unmodified.

    Returns whether any removals were requested.
    """
    if not removals:
        return False
    applied: set[tuple[int, int]] = set()
    for pos, body in sorted(removals, key=itemgetter(0), reverse=True):
        # bodies are identified by object identity: distinct lists may compare equal
        key = (pos, id(body))
        if key in applied:
            continue
        applied.add(key)
        del body[pos]
    return True