Module refinery.lib.scripts.vba.deobfuscation.emulator
Evaluate user-defined VBA functions called with constant arguments.
Expand source code Browse git
"""
Evaluate user-defined VBA functions called with constant arguments.
"""
from __future__ import annotations
import operator as _op
from refinery.lib.scripts import Transformer
from refinery.lib.scripts.vba.deobfuscation.helpers import (
apply_removals,
is_identifier_read,
is_literal,
is_nan_or_inf,
literal_value,
value_to_node,
)
from refinery.lib.scripts.vba.deobfuscation.names import (
Value,
dispatch_builtin,
)
from refinery.lib.scripts.vba.model import (
VbaBinaryExpression,
VbaCallExpression,
VbaConstDeclaration,
VbaDebugPrintStatement,
VbaDoLoopStatement,
VbaExitKind,
VbaExitStatement,
VbaExpressionStatement,
VbaForStatement,
VbaFunctionDeclaration,
VbaIdentifier,
VbaIfStatement,
VbaLetStatement,
VbaLoopConditionPosition,
VbaLoopConditionType,
VbaModule,
VbaOnErrorAction,
VbaOnErrorStatement,
VbaParenExpression,
VbaUnaryExpression,
VbaVariableDeclaration,
)
_NUMERIC_DIVMOD_OPS = {
'\\' : _op.floordiv,
'mod': _op.mod,
}
_NUMERIC_BINARY_OPS = {
'-' : _op.sub,
'*' : _op.mul,
'/' : _op.truediv,
'^' : _op.pow,
}
_BITWISE_BINARY_OPS = {
'xor': _op.xor,
'and': _op.and_,
'or' : _op.or_,
}
_COMPARISON_OPS = {
'=' : _op.eq,
'<>' : _op.ne,
'<' : _op.lt,
'>' : _op.gt,
'<=' : _op.le,
'>=' : _op.ge,
}
class _VbaInterpreterError(Exception):
pass
class _UnevaluableError(Exception):
"""
Raised for statements the interpreter cannot model, such as implicit calls with potential side
effects. Unlike `_VbaInterpreterError`, this is not suppressed by On Error Resume Next, because
skipping a side-effecting statement would silently lose behavior.
"""
pass
class _ExitFunctionSignal(Exception):
pass
class _VbaInterpreter:
    """
    Tree-walking interpreter for the body of a single VBA function. It models assignments to plain
    identifiers, Const declarations, If/ElseIf/Else, For and Do loops, Exit Function, On Error
    Resume Next, and calls to builtins handled by `dispatch_builtin`. The result of a run is the
    final value of the variable named after the function.

    Two error classes are used: `_VbaInterpreterError` models a VBA run-time error and is skipped
    under On Error Resume Next, while `_UnevaluableError` marks situations the emulator cannot
    reproduce faithfully (unmodeled statements, emulator resource limits) and always aborts.
    """

    # No VBA numeric type can hold an integer beyond this magnitude (the Double range ends just
    # below 2**1024). Treating larger results as an overflow also stops code like `x = x * x` in a
    # loop from growing Python integers without bound.
    _MAX_MAGNITUDE = 2 ** 1024

    def __init__(
        self,
        function_name: str,
        max_iterations: int = 100_000,
        max_string_len: int = 1_000_000,
        compare_text: bool = False,
    ):
        self.function_name = function_name.lower()
        self.max_iterations = max_iterations
        self.max_string_len = max_string_len
        # String comparison mode. VBA defaults to Option Compare Binary (case-sensitive); setting
        # this emulates Option Compare Text.
        self.compare_text = compare_text
        self._env: dict[str, Value] = {}
        self._iterations = 0
        self._on_error_resume_next = False

    def execute(self, body: list, bindings: dict[str, Value]) -> Value:
        """
        Run the statements in `body` with the given parameter `bindings` and return the value
        assigned to the function name, or None if it was never assigned.
        """
        self._env = dict(bindings)
        self._iterations = 0
        self._on_error_resume_next = False
        try:
            self._exec_statements(body)
        except _ExitFunctionSignal:
            pass
        return self._env.get(self.function_name)

    def _exec_statements(self, stmts: list):
        for stmt in stmts:
            if not self._on_error_resume_next:
                self._exec_statement(stmt)
                continue
            try:
                self._exec_statement(stmt)
            except _VbaInterpreterError:
                # On Error Resume Next: carry on with the statement after the one that failed.
                continue

    def _exec_statement(self, stmt):
        if isinstance(stmt, VbaOnErrorStatement):
            self._on_error_resume_next = stmt.action is VbaOnErrorAction.RESUME_NEXT
            return
        if isinstance(stmt, VbaLetStatement):
            return self._exec_let(stmt)
        if isinstance(stmt, VbaConstDeclaration):
            return self._exec_const(stmt)
        if isinstance(stmt, VbaIfStatement):
            return self._exec_if(stmt)
        if isinstance(stmt, VbaForStatement):
            return self._exec_for(stmt)
        if isinstance(stmt, VbaDoLoopStatement):
            return self._exec_do_loop(stmt)
        if isinstance(stmt, VbaExitStatement):
            if stmt.kind is VbaExitKind.FUNCTION:
                raise _ExitFunctionSignal
            # Other exits (such as Exit For or Exit Do) are not modeled. Skipping them under On
            # Error Resume Next would keep the loop running and yield a wrong result.
            raise _UnevaluableError
        if isinstance(stmt, VbaExpressionStatement):
            raise _UnevaluableError
        if isinstance(stmt, VbaDebugPrintStatement):
            return
        if isinstance(stmt, VbaVariableDeclaration):
            return
        raise _UnevaluableError

    def _exec_let(self, stmt: VbaLetStatement):
        if not isinstance(stmt.target, VbaIdentifier):
            raise _VbaInterpreterError
        key = stmt.target.name.lower()
        value = self._eval(stmt.value)
        self._env[key] = value

    def _exec_const(self, stmt: VbaConstDeclaration):
        for d in stmt.declarators:
            key = d.name.lower()
            value = self._eval(d.value)
            self._env[key] = value

    def _exec_if(self, stmt: VbaIfStatement):
        if self._eval(stmt.condition):
            self._exec_statements(stmt.body)
            return
        for clause in stmt.elseif_clauses:
            if self._eval(clause.condition):
                self._exec_statements(clause.body)
                return
        if stmt.else_body:
            self._exec_statements(stmt.else_body)

    def _exec_for(self, stmt: VbaForStatement):
        if not isinstance(stmt.variable, VbaIdentifier):
            raise _VbaInterpreterError
        key = stmt.variable.name.lower()
        start = self._to_number(self._eval(stmt.start))
        end = self._to_number(self._eval(stmt.end))
        step = self._to_number(self._eval(stmt.step)) if stmt.step else 1
        # The bounds are evaluated once, but the counter lives in the loop variable itself: it is
        # assigned before the first test, re-read after every pass (the body may modify it), and
        # is left holding the first value that failed the test. A step of 0 counts as a forward
        # step; if the loop then never terminates, the iteration limit ends the evaluation.
        self._env[key] = start
        while True:
            self._tick()
            counter = self._to_number(self._env.get(key))
            if (counter < end) if step < 0 else (counter > end):
                break
            self._exec_statements(stmt.body)
            self._env[key] = self._to_number(self._env.get(key)) + step

    def _exec_do_loop(self, stmt: VbaDoLoopStatement):
        check_before = stmt.condition_position is VbaLoopConditionPosition.PRE
        is_until = stmt.condition_type is VbaLoopConditionType.UNTIL
        while True:
            self._tick()
            if check_before and self._should_exit_loop(stmt.condition, is_until):
                break
            self._exec_statements(stmt.body)
            if not check_before and self._should_exit_loop(stmt.condition, is_until):
                break

    def _should_exit_loop(self, condition, is_until: bool) -> bool:
        # A loop without a condition only ends via an exit statement or the iteration limit.
        if condition is None:
            return False
        # While-loops stop once the condition is false, Until-loops once it is true.
        return is_until is bool(self._eval(condition))

    def _tick(self):
        self._iterations += 1
        if self._iterations > self.max_iterations:
            # Emulator limit, not a VBA error: On Error Resume Next must not turn it into a
            # silently truncated loop.
            raise _UnevaluableError

    def _eval(self, expr) -> Value:
        if expr is None:
            return None
        value = literal_value(expr)
        if value is not None:
            return value
        if isinstance(expr, VbaIdentifier):
            return self._env.get(expr.name.lower())
        if isinstance(expr, VbaBinaryExpression):
            return self._eval_binary(expr)
        if isinstance(expr, VbaUnaryExpression):
            return self._eval_unary(expr)
        if isinstance(expr, VbaParenExpression):
            return self._eval(expr.expression)
        if isinstance(expr, VbaCallExpression):
            return self._eval_call(expr)
        raise _VbaInterpreterError

    def _eval_binary(self, node: VbaBinaryExpression) -> Value:
        left = self._eval(node.left)
        right = self._eval(node.right)
        op = node.operator.lower()
        if op == '&':
            return self._concat(left, right)
        if op == '+':
            if isinstance(left, str) and isinstance(right, str):
                return self._concat(left, right)
            return self._numeric_op(left, right, _op.add)
        if fn := _NUMERIC_DIVMOD_OPS.get(op):
            a, b = self._to_int(left), self._to_int(right)
            if b == 0:
                raise _VbaInterpreterError
            return fn(a, b)
        if fn := _NUMERIC_BINARY_OPS.get(op):
            return self._numeric_op(left, right, fn)
        if fn := _BITWISE_BINARY_OPS.get(op):
            # Logical operators on two Booleans produce a Boolean, not the integer -1 or 0.
            if isinstance(left, bool) and isinstance(right, bool):
                return fn(left, right)
            return fn(self._to_int(left), self._to_int(right))
        if fn := _COMPARISON_OPS.get(op):
            return self._compare(left, right, fn, self.compare_text)
        raise _VbaInterpreterError

    def _eval_unary(self, node: VbaUnaryExpression) -> Value:
        val = self._eval(node.operand)
        op = node.operator
        if op == '-':
            n = self._to_number(val)
            return -n
        if op.lower() == 'not':
            if isinstance(val, bool):
                return not val
            return ~self._to_int(val)
        raise _VbaInterpreterError

    def _eval_call(self, node: VbaCallExpression) -> Value:
        if not isinstance(node.callee, VbaIdentifier):
            raise _VbaInterpreterError
        name = node.callee.name.lower()
        args = [self._eval(a) for a in node.arguments if a is not None]
        try:
            matched, result = dispatch_builtin(name, args)
        except (ValueError, ArithmeticError, TypeError, LookupError):
            # ArithmeticError covers ZeroDivisionError/OverflowError, LookupError covers
            # IndexError/KeyError; any of these from a builtin is a VBA run-time error.
            raise _VbaInterpreterError
        if not matched:
            raise _VbaInterpreterError
        if isinstance(result, str) and len(result) > self.max_string_len:
            raise _UnevaluableError
        return result

    def _concat(self, lhs: Value, rhs: Value) -> str:
        result = self._to_str(lhs) + self._to_str(rhs)
        if len(result) > self.max_string_len:
            # Emulator limit, not a VBA error: On Error Resume Next must not hide it.
            raise _UnevaluableError
        return result

    @staticmethod
    def _to_str(v: Value) -> str:
        """
        Convert a value to text for concatenation. Empty becomes the empty string, and doubles with
        an integral value are rendered without a fractional part, as VBA does (`10 / 2` → "5").
        """
        if v is None:
            return ''
        if isinstance(v, float) and v.is_integer() and abs(v) < 1e15:
            return str(int(v))
        return str(v)

    @staticmethod
    def _to_number(v: Value) -> int | float:
        if v is None:
            return 0
        if isinstance(v, bool):
            return -1 if v else 0
        if isinstance(v, (int, float)):
            return v
        if isinstance(v, str):
            # Python also parses "1_000", "nan" and "inf", none of which VBA accepts as numbers.
            if '_' not in v:
                try:
                    return int(v)
                except ValueError:
                    pass
                try:
                    number = float(v)
                except ValueError:
                    pass
                else:
                    if not is_nan_or_inf(number):
                        return number
            raise _VbaInterpreterError
        raise _VbaInterpreterError

    @staticmethod
    def _to_int(v: Value) -> int:
        """
        Coerce a value to an integer like VBA's CLng: fractional values are rounded half-to-even,
        which is also what Python's `round` does.
        """
        number = _VbaInterpreter._to_number(v)
        if isinstance(number, int):
            return number
        try:
            return round(number)
        except (OverflowError, ValueError):
            raise _VbaInterpreterError

    def _numeric_op(self, left: Value, right: Value, op) -> int | float:
        a = self._to_number(left)
        b = self._to_number(right)
        try:
            result = op(a, b)
        except (ArithmeticError, ValueError):
            raise _VbaInterpreterError
        if isinstance(result, int):
            if abs(result) > self._MAX_MAGNITUDE:
                raise _VbaInterpreterError
        elif not isinstance(result, float) or is_nan_or_inf(result):
            # Rejects NaN/infinity as well as complex results that Python's `**` can produce.
            raise _VbaInterpreterError
        return result

    @staticmethod
    def _compare(left: Value, right: Value, op, text: bool = False) -> bool:
        """
        Compare two values. Strings compare case-sensitively unless `text` is set; Empty compares
        as "" against a string and as 0 otherwise; Booleans compare by their numeric value (-1/0).
        """
        if isinstance(left, str) and right is None:
            right = ''
        elif isinstance(right, str) and left is None:
            left = ''
        if isinstance(left, str) and isinstance(right, str):
            if text:
                return op(left.lower(), right.lower())
            return op(left, right)
        if isinstance(left, str) or isinstance(right, str):
            raise _VbaInterpreterError
        return op(_VbaInterpreter._to_number(left), _VbaInterpreter._to_number(right))
class VbaFunctionEvaluator(Transformer):
    """
    Evaluate calls to user-defined VBA functions when all arguments are constants.
    Replaces the call expression with the computed string or integer literal.
    Removes function definitions once all their calls have been resolved.

    Every mention of a function outside its own body counts as a use. A definition is removed only
    if each of those uses was replaced, so references that cannot be evaluated keep it alive.
    """

    # Stands in for an argument slot left empty at the call site, such as the middle one in
    # `F(1, , 3)`, so that it can be told apart from a literal argument whose value is None.
    _OMITTED = object()

    def __init__(
        self,
        max_iterations: int = 100_000,
        max_string_len: int = 1_000_000,
    ):
        super().__init__()
        self.max_iterations = max_iterations
        self.max_string_len = max_string_len
        self._functions: dict[str, VbaFunctionDeclaration] = {}
        # Number of uses of each function outside its own body, and how many were replaced.
        self._call_counts: dict[str, int] = {}
        self._replaced_counts: dict[str, int] = {}
        self._visiting = False
        self._inside_function: str | None = None
        # Callees of the call expressions currently being visited. Their identifiers are counted
        # by visit_VbaCallExpression and must not be counted a second time as plain references.
        self._active_callees: list = []

    def visit(self, node):
        if self._visiting:
            return super().visit(node)
        return self._evaluate_module(node)

    def _evaluate_module(self, node):
        self._visiting = True
        self._inside_function = None
        self._active_callees.clear()
        try:
            self._functions.clear()
            self._call_counts.clear()
            self._replaced_counts.clear()
            self._collect_functions(node)
            if not self._functions:
                return None
            super().visit(node)
            self._remove_resolved_definitions()
            return None
        finally:
            self._visiting = False
            self._inside_function = None
            self._active_callees.clear()

    def _collect_functions(self, root):
        for node in root.walk():
            if isinstance(node, VbaFunctionDeclaration):
                if not node.name:
                    continue
                self._functions[node.name.lower()] = node

    def visit_VbaFunctionDeclaration(self, node: VbaFunctionDeclaration):
        key = node.name.lower() if node.name else None
        old = self._inside_function
        self._inside_function = key
        try:
            self.generic_visit(node)
        finally:
            self._inside_function = old
        return None

    def visit_VbaCallExpression(self, node: VbaCallExpression):
        self._active_callees.append(node.callee)
        try:
            self.generic_visit(node)
        finally:
            self._active_callees.pop()
        if not isinstance(node.callee, VbaIdentifier):
            return None
        key = node.callee.name.lower()
        funcdef = self._functions.get(key)
        if funcdef is None:
            return None
        self._call_counts[key] = self._call_counts.get(key, 0) + 1
        args = self._extract_constant_args(node)
        if args is None:
            return None
        return self._try_replace(key, funcdef, args)

    def visit_VbaIdentifier(self, node: VbaIdentifier):
        key = node.name.lower()
        if key == self._inside_function:
            # Inside its own body, the function name is the return variable.
            return None
        funcdef = self._functions.get(key)
        if funcdef is None:
            return None
        if any(callee is node for callee in self._active_callees):
            return None
        # Any other mention is a use; counting it keeps a definition whose use cannot be evaluated
        # here (required parameters, non-read positions) from being removed.
        self._call_counts[key] = self._call_counts.get(key, 0) + 1
        if any(not p.is_optional and p.default is None for p in funcdef.params):
            return None
        if not is_identifier_read(node):
            return None
        return self._try_replace(key, funcdef, [])

    def _try_replace(
        self,
        key: str,
        funcdef: VbaFunctionDeclaration,
        args: list[Value],
    ):
        bindings = self._bind_parameters(funcdef, args)
        if bindings is None:
            return None
        result = self._try_evaluate(funcdef, bindings)
        if result is None:
            return None
        replacement = value_to_node(result)
        if replacement is None:
            return None
        self._replaced_counts[key] = self._replaced_counts.get(key, 0) + 1
        return replacement

    def _try_evaluate(
        self,
        funcdef: VbaFunctionDeclaration,
        bindings: dict[str, Value],
    ) -> Value:
        interpreter = _VbaInterpreter(
            function_name=funcdef.name,
            max_iterations=self.max_iterations,
            max_string_len=self.max_string_len,
        )
        try:
            return interpreter.execute(funcdef.body, bindings)
        except (_VbaInterpreterError, _UnevaluableError):
            return None

    @staticmethod
    def _extract_constant_args(node: VbaCallExpression) -> list[Value] | None:
        """
        Return the literal values of the call's arguments, with `_OMITTED` for empty argument
        slots, or None if any argument is not a literal.
        """
        args: list[Value] = []
        for arg in node.arguments:
            if arg is None:
                args.append(VbaFunctionEvaluator._OMITTED)
            elif is_literal(arg):
                args.append(literal_value(arg))
            else:
                return None
        return args

    @staticmethod
    def _bind_parameters(
        funcdef: VbaFunctionDeclaration,
        args: list[Value],
    ) -> dict[str, Value] | None:
        """
        Map parameter names to argument values. Omitted or missing arguments take the literal
        default of an Optional parameter, or Empty (None) if it has none. Returns None when the
        call cannot be bound: a required parameter is missing, a default is not a literal, or there
        are more arguments than parameters (a ParamArray or an invalid call; neither is modeled).
        """
        params = funcdef.params
        if len(args) > len(params):
            return None
        bindings: dict[str, Value] = {}
        for i, param in enumerate(params):
            key = param.name.lower()
            arg = args[i] if i < len(args) else VbaFunctionEvaluator._OMITTED
            if arg is not VbaFunctionEvaluator._OMITTED:
                bindings[key] = arg
            elif not param.is_optional:
                return None
            elif param.default is None:
                bindings[key] = None
            elif is_literal(param.default):
                bindings[key] = literal_value(param.default)
            else:
                return None
        return bindings

    def _remove_resolved_definitions(self):
        removals: list[tuple[int, list]] = []
        for key, funcdef in self._functions.items():
            call_count = self._call_counts.get(key, 0)
            replaced_count = self._replaced_counts.get(key, 0)
            if call_count == 0 or replaced_count < call_count:
                continue
            parent = funcdef.parent
            if parent is None or not isinstance(parent, VbaModule):
                continue
            for k, stmt in enumerate(parent.body):
                if stmt is funcdef:
                    removals.append((k, parent.body))
                    break
        if apply_removals(removals):
            self.mark_changed()
Classes
class VbaFunctionEvaluator (max_iterations=100000, max_string_len=1000000)-
Evaluate calls to user-defined VBA functions when all arguments are constants. Replaces the call expression with the computed string or integer literal. Removes function definitions once all their calls have been resolved.
Expand source code Browse git
class VbaFunctionEvaluator(Transformer): """ Evaluate calls to user-defined VBA functions when all arguments are constants. Replaces the call expression with the computed string or integer literal. Removes function definitions once all their calls have been resolved. """ def __init__( self, max_iterations: int = 100_000, max_string_len: int = 1_000_000, ): super().__init__() self.max_iterations = max_iterations self.max_string_len = max_string_len self._functions: dict[str, VbaFunctionDeclaration] = {} self._call_counts: dict[str, int] = {} self._replaced_counts: dict[str, int] = {} self._visiting = False self._inside_function: str | None = None def visit(self, node): if self._visiting: return super().visit(node) return self._evaluate_module(node) def _evaluate_module(self, node): self._visiting = True try: self._functions.clear() self._call_counts.clear() self._replaced_counts.clear() self._collect_functions(node) if not self._functions: return None super().visit(node) self._remove_resolved_definitions() return None finally: self._visiting = False def _collect_functions(self, root): for node in root.walk(): if isinstance(node, VbaFunctionDeclaration): if not node.name: continue self._functions[node.name.lower()] = node def visit_VbaFunctionDeclaration(self, node: VbaFunctionDeclaration): key = node.name.lower() if node.name else None old = self._inside_function self._inside_function = key self.generic_visit(node) self._inside_function = old return None def visit_VbaCallExpression(self, node: VbaCallExpression): self.generic_visit(node) if not isinstance(node.callee, VbaIdentifier): return None key = node.callee.name.lower() funcdef = self._functions.get(key) if funcdef is None: return None self._call_counts[key] = self._call_counts.get(key, 0) + 1 args = self._extract_constant_args(node) if args is None: return None return self._try_replace(key, funcdef, args) def visit_VbaIdentifier(self, node: VbaIdentifier): key = node.name.lower() if key == 
self._inside_function: return None funcdef = self._functions.get(key) if funcdef is None: return None if funcdef.params: required = [p for p in funcdef.params if not p.is_optional and p.default is None] if required: return None if not is_identifier_read(node): return None self._call_counts[key] = self._call_counts.get(key, 0) + 1 return self._try_replace(key, funcdef, []) def _try_replace( self, key: str, funcdef: VbaFunctionDeclaration, args: list[Value], ): bindings = self._bind_parameters(funcdef, args) if bindings is None: return None result = self._try_evaluate(funcdef, bindings) if result is None: return None replacement = value_to_node(result) if replacement is None: return None self._replaced_counts[key] = self._replaced_counts.get(key, 0) + 1 return replacement def _try_evaluate( self, funcdef: VbaFunctionDeclaration, bindings: dict[str, Value], ) -> Value: interpreter = _VbaInterpreter( function_name=funcdef.name, max_iterations=self.max_iterations, max_string_len=self.max_string_len, ) try: return interpreter.execute(funcdef.body, bindings) except (_VbaInterpreterError, _UnevaluableError): return None @staticmethod def _extract_constant_args(node: VbaCallExpression) -> list[Value] | None: args: list[Value] = [] for arg in node.arguments: if arg is None: args.append(None) continue if not is_literal(arg): return None args.append(literal_value(arg)) return args @staticmethod def _bind_parameters( funcdef: VbaFunctionDeclaration, args: list[Value], ) -> dict[str, Value] | None: bindings: dict[str, Value] = {} for i, param in enumerate(funcdef.params): key = param.name.lower() if i < len(args): bindings[key] = args[i] elif param.is_optional and param.default is not None: if is_literal(param.default): bindings[key] = literal_value(param.default) else: return None elif param.is_optional: bindings[key] = None else: return None return bindings def _remove_resolved_definitions(self): removals: list[tuple[int, list]] = [] for key, funcdef in 
self._functions.items(): call_count = self._call_counts.get(key, 0) replaced_count = self._replaced_counts.get(key, 0) if call_count == 0 or replaced_count < call_count: continue parent = funcdef.parent if parent is None or not isinstance(parent, VbaModule): continue for k, stmt in enumerate(parent.body): if stmt is funcdef: removals.append((k, parent.body)) break if apply_removals(removals): self.mark_changed()Ancestors
Methods
def visit(self, node)-
Expand source code Browse git
def visit(self, node): if self._visiting: return super().visit(node) return self._evaluate_module(node) def visit_VbaFunctionDeclaration(self, node)-
Expand source code Browse git
def visit_VbaFunctionDeclaration(self, node: VbaFunctionDeclaration): key = node.name.lower() if node.name else None old = self._inside_function self._inside_function = key self.generic_visit(node) self._inside_function = old return None def visit_VbaCallExpression(self, node)-
Expand source code Browse git
def visit_VbaCallExpression(self, node: VbaCallExpression): self.generic_visit(node) if not isinstance(node.callee, VbaIdentifier): return None key = node.callee.name.lower() funcdef = self._functions.get(key) if funcdef is None: return None self._call_counts[key] = self._call_counts.get(key, 0) + 1 args = self._extract_constant_args(node) if args is None: return None return self._try_replace(key, funcdef, args) def visit_VbaIdentifier(self, node)-
Expand source code Browse git
def visit_VbaIdentifier(self, node: VbaIdentifier): key = node.name.lower() if key == self._inside_function: return None funcdef = self._functions.get(key) if funcdef is None: return None if funcdef.params: required = [p for p in funcdef.params if not p.is_optional and p.default is None] if required: return None if not is_identifier_read(node): return None self._call_counts[key] = self._call_counts.get(key, 0) + 1 return self._try_replace(key, funcdef, [])