Module refinery.lib.scripts.vba.deobfuscation.helpers
Shared AST utilities for VBA deobfuscation transforms.
Expand source code Browse git
"""
Shared AST utilities for VBA deobfuscation transforms.
"""
from __future__ import annotations
from operator import itemgetter
from typing import Generator
from refinery.lib.scripts import Expression, Kind, Statement, _classify_fields
from refinery.lib.scripts.vba.deobfuscation.names import CHR_NAMES, CompareMode, Value
from refinery.lib.scripts.vba.model import (
VbaBinaryExpression,
VbaBooleanLiteral,
VbaCallExpression,
VbaConstDeclaration,
VbaConstDeclarator,
VbaEmptyLiteral,
VbaExpressionStatement,
VbaFloatLiteral,
VbaForEachStatement,
VbaForStatement,
VbaIdentifier,
VbaIntegerLiteral,
VbaLetStatement,
VbaModule,
VbaOptionStatement,
VbaStringLiteral,
VbaUnaryExpression,
)
# Node classes recognized as literals by `is_literal` and `literal_value`; instances of each carry
# their constant in a `value` attribute.
LITERAL_TYPES = (VbaStringLiteral, VbaIntegerLiteral, VbaFloatLiteral, VbaBooleanLiteral, VbaEmptyLiteral)
def make_string_literal(value: str) -> VbaStringLiteral:
    """
    Build a string literal node for `value`. The `raw` source text surrounds the value with double
    quotes and doubles every embedded double quote, which is how VBA escapes quotes in strings.
    """
    quoted = '"' + '""'.join(value.split('"')) + '"'
    return VbaStringLiteral(value=value, raw=quoted)
def is_nan_or_inf(value) -> bool:
    """
    Return whether `value` is a Python float that is NaN or positive/negative infinity. Values of
    any other type, including `int`, always yield `False`.
    """
    if not isinstance(value, float):
        return False
    if value != value:  # NaN is the only float that compares unequal to itself
        return True
    return value in (float('inf'), float('-inf'))
def module_compare_mode(module: VbaModule) -> CompareMode:
    """
    Determine the string comparison mode selected by the module's `Option Compare` directive.

    Only the first `Option Compare` statement found in the module's top-level body is consulted.
    An argument of `Text` (case-insensitive comparison) maps to `CompareMode.TEXT` and `Database`
    (Access) to `CompareMode.DATABASE`. Any other argument, such as `Binary`, or the absence of
    the directive, yields `CompareMode.BINARY`. `Database` mode follows the database's
    locale-dependent sort order, which cannot be reproduced statically, so the folding
    transforms refuse to fold comparisons under it.
    """
    directives = (
        stmt for stmt in module.body
        if isinstance(stmt, VbaOptionStatement) and stmt.keyword.lower() == 'compare'
    )
    first = next(directives, None)
    if first is None:
        return CompareMode.BINARY
    known_modes = {'text': CompareMode.TEXT, 'database': CompareMode.DATABASE}
    return known_modes.get(first.value.lower(), CompareMode.BINARY)
def vba_int_div(a: int | float, b: int | float) -> int:
    """
    Emulate VBA integer division (the `\\` operator).

    Both operands are first rounded to integers with Python's `round` (round-half-to-even). The
    quotient is then truncated toward zero, whereas Python's `//` floors toward negative infinity.
    Raises `ZeroDivisionError` when the divisor rounds to zero.
    """
    dividend = round(a)
    divisor = round(b)
    if divisor == 0:
        raise ZeroDivisionError
    magnitude = abs(dividend) // abs(divisor)
    # the result is negative exactly when the operand signs differ
    negative = (dividend < 0) != (divisor < 0)
    return -magnitude if negative else magnitude
def vba_mod(a: int | float, b: int | float) -> int:
    """
    Emulate the VBA `Mod` operator.

    Both operands are first rounded to integers with Python's `round` (round-half-to-even). The
    remainder carries the sign of the dividend, whereas the result of Python's `%` carries the
    sign of the divisor. Raises `ZeroDivisionError` when the divisor rounds to zero.
    """
    dividend, divisor = round(a), round(b)
    if divisor == 0:
        raise ZeroDivisionError
    magnitude = abs(dividend) % abs(divisor)
    return magnitude if dividend >= 0 else -magnitude
def make_integer_literal(value: int) -> VbaIntegerLiteral:
    """
    Build an integer literal node whose source text is the decimal rendering of `value`.
    """
    text = str(value)
    return VbaIntegerLiteral(raw=text, value=value)
def make_float_literal(value: float) -> VbaFloatLiteral:
    """
    Build a float literal node whose source text is Python's `str()` rendering of `value`
    (for example `0.5` or `1e+20`).
    """
    # NOTE(review): str() renders infinities and NaN as 'inf'/'nan', which are not VBA literals;
    # callers can screen such values with is_nan_or_inf before building a literal.
    text = str(value)
    return VbaFloatLiteral(raw=text, value=value)
def make_numeric_literal(value: int | float) -> VbaIntegerLiteral | VbaFloatLiteral:
    """
    Build the most natural numeric literal for `value`.

    Non-float values become integer literals. A float with an integral value whose magnitude is
    below 2**53 (the bound below which doubles represent every integer exactly) also becomes an
    integer literal. Every other float becomes a float literal.

    The integrality test calls `int(value)`, so NaN raises `ValueError` and infinities raise
    `OverflowError`.
    """
    if not isinstance(value, float):
        return make_integer_literal(value)
    truncated = int(value)
    if truncated == value and abs(value) < 2 ** 53:
        return make_integer_literal(truncated)
    return make_float_literal(value)
def is_literal(node: Expression) -> bool:
    """
    Return whether `node` is an instance of one of the literal node classes in `LITERAL_TYPES`
    (string, integer, float, boolean or Empty literal).
    """
    return any(isinstance(node, literal_type) for literal_type in LITERAL_TYPES)
def is_constant_expr(node: Expression) -> bool:
    """
    Return `True` for expressions that can be safely propagated as constants. These are:

    - literals (see `is_literal`);
    - calls whose callee is an identifier named in `CHR_NAMES` (compared case-insensitively),
      with exactly one argument that is an integer literal;
    - binary `&` or `+` expressions whose two operands are both constant expressions;
    - unary `-` or `Not` expressions whose operand is a constant expression.

    Everything else, including binary expressions with any other operator, yields `False`.
    """
    if is_literal(node):
        return True
    if isinstance(node, VbaCallExpression):
        callee = node.callee
        if not isinstance(callee, VbaIdentifier) or callee.name.lower() not in CHR_NAMES:
            return False
        args = node.arguments
        return len(args) == 1 and isinstance(args[0], VbaIntegerLiteral)
    if isinstance(node, VbaBinaryExpression):
        if node.operator not in ('&', '+'):
            return False
        lhs, rhs = node.left, node.right
        return (
            lhs is not None
            and rhs is not None
            and is_constant_expr(lhs)
            and is_constant_expr(rhs)
        )
    if isinstance(node, VbaUnaryExpression):
        operand = node.operand
        if node.operator in ('-', 'Not') and operand is not None:
            return is_constant_expr(operand)
    return False
def is_identifier_read(node: VbaIdentifier) -> bool:
    """
    Return whether an identifier node sits in a read position, judged by its immediate parent
    only. The identifier is NOT a read when it is:

    - the target of a `Let` assignment;
    - a direct child of a constant declaration or constant declarator;
    - the callee of a call expression;
    - the whole expression of an expression statement (in VBA, a bare procedure call);
    - the loop variable of a `For` or `For Each` statement.

    Every other position counts as a read.
    """
    # NOTE(review): the constant-declaration test excludes *every* identifier whose parent is a
    # VbaConstDeclaration/VbaConstDeclarator, not only the declared name. If an initializer such
    # as `Const A = B` stores `B` as a direct child, `B` is reported as not read -- confirm
    # against the model.
    parent = node.parent
    not_a_read = (
        (isinstance(parent, VbaLetStatement) and parent.target is node)
        or isinstance(parent, (VbaConstDeclaration, VbaConstDeclarator))
        or (isinstance(parent, VbaCallExpression) and parent.callee is node)
        or (isinstance(parent, VbaExpressionStatement) and parent.expression is node)
        or (
            isinstance(parent, (VbaForStatement, VbaForEachStatement))
            and parent.variable is node
        )
    )
    return not not_a_read
def literal_value(node: Expression) -> Value:
    """
    Return the `value` attribute of a literal node, or `None` when `node` is not one of
    `LITERAL_TYPES`.
    """
    if not isinstance(node, LITERAL_TYPES):
        return None
    return node.value
def constant_args(arguments: list[Expression | None]) -> list[Value] | None:
    """
    Fold a builtin call's argument list into plain Python values.

    An omitted argument (`None` in `arguments`) stays `None` in the result. Every literal argument
    is replaced by its `literal_value`. As soon as a non-literal expression is encountered, the
    call cannot be folded and `None` is returned instead of a list.
    """
    folded: list[Value] = []
    for argument in arguments:
        if argument is not None and not is_literal(argument):
            return None
        folded.append(None if argument is None else literal_value(argument))
    return folded
def string_value(node: Expression | None) -> str | None:
    """
    Return the text of a string literal node, or `None` for any other node (or for `None`).
    """
    return node.value if isinstance(node, VbaStringLiteral) else None
def numeric_value(node: Expression | None) -> int | float | None:
    """
    Return the value of an integer or float literal node, or `None` for any other node (boolean
    literals included) or for `None`.
    """
    if isinstance(node, (VbaIntegerLiteral, VbaFloatLiteral)):
        return node.value
    return None
def make_chr_call(code_point: int) -> VbaCallExpression:
    """
    Build the call expression `Chr(<code_point>)` with an integer literal argument. The callee is
    always `Chr`, whatever the value of `code_point`.
    """
    callee = VbaIdentifier(name='Chr')
    argument = make_integer_literal(code_point)
    return VbaCallExpression(callee=callee, arguments=[argument])
def string_to_expr(value: str) -> Expression:
    """
    Convert a Python string to a VBA AST expression.

    An empty or fully printable string (per `str.isprintable`) becomes a single string literal.
    Otherwise the string is split into runs of printable characters, each emitted as a string
    literal, and individual non-printable characters, each emitted as a character-function call.
    The pieces are chained left to right with `&`.

    Non-printable characters are encoded so that VBA evaluates the expression to the same Unicode
    text:

    - below 128 they are ASCII control codes (0-31, 127), emitted as `Chr(n)`; these codes are
      identical in every ANSI code page;
    - other code points up to U+FFFF use `ChrW(n)`. `Chr` maps 128-255 through the system ANSI
      code page (under Windows-1252, `Chr(128)` is the euro sign, not U+0080), and its normal
      argument range ends at 255 outside DBCS locales. `ChrW` takes a Unicode code unit directly;
    - code points above U+FFFF become two `ChrW` calls forming a UTF-16 surrogate pair, because
      VBA strings are UTF-16 and `ChrW` accepts a single 16-bit code unit.
    """
    if not value or value.isprintable():
        return make_string_literal(value)
    parts: list[Expression] = []
    run: list[str] = []
    for c in value:
        if c.isprintable():
            run.append(c)
            continue
        if run:
            parts.append(make_string_literal(''.join(run)))
            run.clear()
        code_point = ord(c)
        if code_point < 0x80:
            parts.append(make_chr_call(code_point))
        elif code_point <= 0xFFFF:
            parts.append(_make_chrw_call(code_point))
        else:
            # split a supplementary-plane character into its UTF-16 high/low surrogates
            offset = code_point - 0x10000
            parts.append(_make_chrw_call(0xD800 | (offset >> 10)))
            parts.append(_make_chrw_call(0xDC00 | (offset & 0x3FF)))
    if run:
        parts.append(make_string_literal(''.join(run)))
    result = parts[0]
    for part in parts[1:]:
        result = VbaBinaryExpression(left=result, operator='&', right=part)
    return result


def _make_chrw_call(code_unit: int) -> VbaCallExpression:
    """
    Build the call expression `ChrW(<code_unit>)` for a single UTF-16 code unit. `string_to_expr`
    uses it for non-printable characters outside the ASCII range.
    """
    return VbaCallExpression(
        callee=VbaIdentifier(name='ChrW'),
        arguments=[make_integer_literal(code_unit)],
    )
def value_to_node(value: Value) -> Expression:
    """
    Convert a folded Python value back into an AST expression.

    `None` becomes an Empty literal, booleans become boolean literals, strings go through
    `string_to_expr`, integers become integer literals, and anything else is handed to
    `make_numeric_literal`.
    """
    if value is None:
        return VbaEmptyLiteral()
    # bool is a subclass of int, so it has to be recognized before the integer case
    if isinstance(value, bool):
        return VbaBooleanLiteral(value=value)
    if isinstance(value, str):
        return string_to_expr(value)
    builder = make_integer_literal if isinstance(value, int) else make_numeric_literal
    return builder(value)
def body_lists(module: VbaModule) -> Generator[list[Statement]]:
    """
    Yield every statement list reachable from the module.

    For each node produced by `module.walk()`, every field classified as `Kind.ChildList` is
    inspected. The list is yielded when it is non-empty and its first element is a `Statement`;
    only the first element's type is checked.
    """
    for node in module.walk():
        list_fields = [
            name for name, kind in _classify_fields(type(node)) if kind == Kind.ChildList
        ]
        for name in list_fields:
            candidate = getattr(node, name)
            if candidate and isinstance(candidate[0], Statement):
                yield candidate
def apply_removals(removals: list[tuple[int, list[Statement]]]) -> bool:
    """
    Delete statements described by `(index, body)` pairs, where `body` is a statement list and
    `index` the position to delete from it.

    Deletions run in descending index order, so removing one statement never shifts the position
    of another pending removal in the same list. Repeated entries for the same position in the
    same list object are collapsed into a single deletion; applying them twice would delete an
    unrelated neighbouring statement. The `removals` list itself is left unmodified.

    Returns `True` if any removals were given, `False` for an empty list.
    """
    if not removals:
        return False
    # Lists are unhashable, so duplicates are keyed on the identity of the body list.
    unique = {(id(body), pos): (pos, body) for pos, body in removals}
    for pos, body in sorted(unique.values(), key=itemgetter(0), reverse=True):
        del body[pos]
    return True
Functions
def make_string_literal(value)-
Expand source code Browse git
def make_string_literal(value: str) -> VbaStringLiteral: escaped = value.replace('"', '""') raw = F'"{escaped}"' return VbaStringLiteral(value=value, raw=raw) def is_nan_or_inf(value)-
Expand source code Browse git
def is_nan_or_inf(value) -> bool: return isinstance(value, float) and (value != value or abs(value) == float('inf')) def module_compare_mode(module)-
Return the module's `Option Compare` mode. VBA `Option Compare` is a module-level directive; the default (no directive, or `Binary`) is `CompareMode.BINARY`, `Text` is case-insensitive, and `Database` (Access) uses the database's locale-dependent sort order. The latter cannot be reproduced statically, so the folding transforms refuse to fold comparisons under it.
Expand source code Browse git
def module_compare_mode(module: VbaModule) -> CompareMode: """ Return the module's `Option Compare` mode. VBA `Option Compare` is a module-level directive; the default (no directive, or `Binary`) is `CompareMode.BINARY`, `Text` is case-insensitive, and `Database` (Access) uses the database's locale-dependent sort order. The latter cannot be reproduced statically, so the folding transforms refuse to fold comparisons under it. """ for stmt in module.body: if isinstance(stmt, VbaOptionStatement) and stmt.keyword.lower() == 'compare': value = stmt.value.lower() if value == 'text': return CompareMode.TEXT if value == 'database': return CompareMode.DATABASE return CompareMode.BINARY return CompareMode.BINARY def vba_int_div(a, b)-
VBA integer division (the `\` operator): both operands are rounded to integers and the quotient is truncated toward zero, unlike Python `//` which floors. Raises `ZeroDivisionError` when the divisor rounds to zero.
Expand source code Browse git
def vba_int_div(a: int | float, b: int | float) -> int: """ VBA integer division (the `\\` operator): both operands are rounded to integers and the quotient is truncated toward zero, unlike Python `//` which floors. Raises `ZeroDivisionError` when the divisor rounds to zero. """ a, b = round(a), round(b) if b == 0: raise ZeroDivisionError quotient = abs(a) // abs(b) return -quotient if (a < 0) != (b < 0) else quotient def vba_mod(a, b)-
VBA modulo (the `Mod` operator): both operands are rounded to integers and the remainder takes the sign of the dividend, unlike Python `%` whose result takes the sign of the divisor. Raises `ZeroDivisionError` when the divisor rounds to zero.
Expand source code Browse git
def vba_mod(a: int | float, b: int | float) -> int: """ VBA modulo (the `Mod` operator): both operands are rounded to integers and the remainder takes the sign of the dividend, unlike Python `%` whose result takes the sign of the divisor. Raises `ZeroDivisionError` when the divisor rounds to zero. """ a, b = round(a), round(b) if b == 0: raise ZeroDivisionError remainder = abs(a) % abs(b) return -remainder if a < 0 else remainder def make_integer_literal(value)-
Expand source code Browse git
def make_integer_literal(value: int) -> VbaIntegerLiteral: return VbaIntegerLiteral(value=value, raw=str(value)) def make_float_literal(value)-
Expand source code Browse git
def make_float_literal(value: float) -> VbaFloatLiteral: return VbaFloatLiteral(value=value, raw=str(value)) def make_numeric_literal(value)-
Expand source code Browse git
def make_numeric_literal(value: int | float) -> VbaIntegerLiteral | VbaFloatLiteral: if isinstance(value, float): if value == int(value) and abs(value) < 2 ** 53: return make_integer_literal(int(value)) return make_float_literal(value) return make_integer_literal(value) def is_literal(node)-
Expand source code Browse git
def is_literal(node: Expression) -> bool: return isinstance(node, LITERAL_TYPES) def is_constant_expr(node)-
Returns `True` for expressions that can be safely propagated as constants: literals, Chr/ChrW calls with literal integer arguments, `&`/`+` combinations of such expressions, and unary `-` or `Not` applied to them.
Expand source code Browse git
def is_constant_expr(node: Expression) -> bool: """ Returns `True` for expressions that can be safely propagated as constants: literals, Chr/ChrW calls with literal integer arguments, and concatenations of such expressions. """ if is_literal(node): return True if isinstance(node, VbaCallExpression): if ( isinstance(node.callee, VbaIdentifier) and node.callee.name.lower() in CHR_NAMES and len(node.arguments) == 1 and node.arguments[0] is not None and isinstance(node.arguments[0], VbaIntegerLiteral) ): return True return False if isinstance(node, VbaBinaryExpression): if node.operator in ('&', '+'): return ( node.left is not None and node.right is not None and is_constant_expr(node.left) and is_constant_expr(node.right) ) if isinstance(node, VbaUnaryExpression): if node.operator in ('-', 'Not') and node.operand is not None: return is_constant_expr(node.operand) return False def is_identifier_read(node)-
Return whether an identifier node is in a read position. Returns `False` for identifiers that appear as assignment targets, declaration names, call targets, or loop variables.
Expand source code Browse git
def is_identifier_read(node: VbaIdentifier) -> bool: """ Return whether an identifier node is in a read position. Returns `False` for identifiers that appear as assignment targets, declaration names, call targets, or loop variables. """ parent = node.parent if isinstance(parent, VbaLetStatement) and parent.target is node: return False if isinstance(parent, (VbaConstDeclaration, VbaConstDeclarator)): return False if isinstance(parent, VbaCallExpression) and parent.callee is node: return False if isinstance(parent, VbaExpressionStatement) and parent.expression is node: return False if ( isinstance(parent, (VbaForStatement, VbaForEachStatement)) and parent.variable is node ): return False return True def literal_value(node)-
Expand source code Browse git
def literal_value(node: Expression) -> Value: if isinstance(node, LITERAL_TYPES): return node.value return None def constant_args(arguments)-
Collect the constant values of a builtin call's arguments, preserving an omitted argument as `None`. Returns `None` if any argument is a non-literal expression that cannot be folded.
Expand source code Browse git
def constant_args(arguments: list[Expression | None]) -> list[Value] | None: """ Collect the constant values of a builtin call's arguments, preserving an omitted argument as `None`. Returns `None` if any argument is a non-literal expression that cannot be folded. """ values: list[Value] = [] for arg in arguments: if arg is None: values.append(None) elif is_literal(arg): values.append(literal_value(arg)) else: return None return values def string_value(node)-
Expand source code Browse git
def string_value(node: Expression | None) -> str | None: if isinstance(node, VbaStringLiteral): return node.value return None def numeric_value(node)-
Expand source code Browse git
def numeric_value(node: Expression | None) -> int | float | None: if isinstance(node, VbaIntegerLiteral): return node.value if isinstance(node, VbaFloatLiteral): return node.value return None def make_chr_call(code_point)-
Expand source code Browse git
def make_chr_call(code_point: int) -> VbaCallExpression: return VbaCallExpression( callee=VbaIdentifier(name='Chr'), arguments=[make_integer_literal(code_point)], ) def string_to_expr(value)-
Convert a Python string to a VBA AST expression. Printable-only strings become a single string literal; strings with non-printable characters become concatenated expressions using Chr calls.
Expand source code Browse git
def string_to_expr(value: str) -> Expression: """ Convert a Python string to a VBA AST expression. Printable-only strings become a single string literal; strings with non-printable characters become concatenated expressions using Chr calls. """ if not value: return make_string_literal('') if all(c.isprintable() for c in value): return make_string_literal(value) parts: list[Expression] = [] run: list[str] = [] for c in value: if c.isprintable(): run.append(c) else: if run: parts.append(make_string_literal(''.join(run))) run.clear() parts.append(make_chr_call(ord(c))) if run: parts.append(make_string_literal(''.join(run))) result = parts[0] for part in parts[1:]: result = VbaBinaryExpression(left=result, operator='&', right=part) return result def value_to_node(value)-
Expand source code Browse git
def value_to_node(value: Value) -> Expression: if value is None: return VbaEmptyLiteral() if isinstance(value, bool): return VbaBooleanLiteral(value=value) if isinstance(value, str): return string_to_expr(value) if isinstance(value, int): return make_integer_literal(value) return make_numeric_literal(value) def body_lists(module)-
Yield every statement-list body reachable from the module.
Expand source code Browse git
def body_lists(module: VbaModule) -> Generator[list[Statement]]: """ Yield every statement-list body reachable from the module. """ for node in module.walk(): for field_name, kind in _classify_fields(type(node)): if kind != Kind.ChildList: continue body = getattr(node, field_name) if body and isinstance(body[0], Statement): yield body def apply_removals(removals)-
Delete statements at the given (index, body) positions in reverse index order so that earlier deletions do not invalidate later indices. Returns whether any removals occurred.
Expand source code Browse git
def apply_removals(removals: list[tuple[int, list[Statement]]]) -> bool: """ Delete statements at the given (body, index) positions in reverse index order so that earlier deletions do not invalidate later indices. Returns whether any removals occurred. """ if not removals: return False removals.sort(key=itemgetter(0), reverse=True) for pos, body in removals: del body[pos] return True