Module refinery.lib.scripts.vba.deobfuscation.simplify
VBA expression simplification and constant folding transforms.
Expand source code Browse git
"""
VBA expression simplification and constant folding transforms.
"""
from __future__ import annotations
import operator
from typing import Callable
from refinery.lib.scripts import Transformer
from refinery.lib.scripts.vba.deobfuscation._helpers import (
_is_literal,
_make_integer_literal,
_make_numeric_literal,
_make_string_literal,
_numeric_value,
_string_value,
)
from refinery.lib.scripts.vba.deobfuscation.builtins import VBA_BUILTIN_CONSTANTS
from refinery.lib.scripts.vba.model import (
VbaBinaryExpression,
VbaBooleanLiteral,
VbaCallExpression,
VbaConstDeclaration,
VbaForEachStatement,
VbaForStatement,
VbaFunctionDeclaration,
VbaIdentifier,
VbaLetStatement,
VbaModule,
VbaOnErrorAction,
VbaOnErrorStatement,
VbaParenExpression,
VbaSubDeclaration,
VbaUnaryExpression,
)
_BINARY_OPS: dict[str, Callable] = {
'+' : operator.add,
'-' : operator.sub,
'*' : operator.mul,
'/' : operator.truediv,
}
_INTEGER_OPS: dict[str, Callable] = {
'\\' : lambda a, b: int(a) // int(b),
'Mod': lambda a, b: int(a) % int(b),
}
def _is_chr_call(node: VbaCallExpression) -> int | None:
    """
    Return the character code of a foldable Chr-style call, or None.

    The call qualifies when its callee is a plain identifier named Chr, ChrW,
    Chr$ or ChrW$ (case-insensitive), it has exactly one present argument, and
    that argument evaluates to an integer in the range 0..0xFFFF.
    """
    callee = node.callee
    if not isinstance(callee, VbaIdentifier):
        return None
    if callee.name.lower() not in ('chr', 'chrw', 'chr$', 'chrw$'):
        return None
    arguments = node.arguments
    if len(arguments) != 1 or arguments[0] is None:
        return None
    code = _numeric_value(arguments[0])
    # Only genuine integers inside the 16-bit code unit range are accepted.
    if isinstance(code, int) and 0 <= code <= 0xFFFF:
        return code
    return None
def _is_asc_call(node: VbaCallExpression) -> str | None:
    """
    Return the first character of the argument of a foldable Asc-style call.

    The call qualifies when its callee is a plain identifier named Asc or AscW
    (case-insensitive), it has exactly one present argument, and that argument
    evaluates to a non-empty string. None is returned otherwise.
    """
    callee = node.callee
    if not isinstance(callee, VbaIdentifier):
        return None
    if callee.name.lower() not in ('asc', 'ascw'):
        return None
    arguments = node.arguments
    if len(arguments) != 1 or arguments[0] is None:
        return None
    text = _string_value(arguments[0])
    if text is None or len(text) < 1:
        return None
    return text[0]
def _try_string_function(node: VbaCallExpression) -> str | None:
    """
    Try to evaluate a call to a string-valued VBA built-in at analysis time.

    Handles Mid, Left, Right, StrReverse, LCase, UCase, String, Space, CStr and
    Replace (with or without a trailing `$`) when every relevant argument is a
    literal. Returns the resulting string, or None when the call is not one of
    these functions, an argument is not constant, or the call would raise a
    run-time error in VBA (negative counts, out-of-range start positions), in
    which case the expression is deliberately left untouched.

    Len is not handled here because its result is numeric, not a string.
    """
    if not isinstance(node.callee, VbaIdentifier):
        return None
    name = node.callee.name.lower().rstrip('$')
    args = [a for a in node.arguments if a is not None]
    # Upper bound on synthesized filler strings so that a hostile literal such
    # as String(2000000000, "A") cannot exhaust memory during folding.
    max_fill = 10000
    if name == 'mid' and len(args) in (2, 3):
        s = _string_value(args[0])
        start_val = _numeric_value(args[1])
        if s is None or start_val is None or not isinstance(start_val, int):
            return None
        # VBA positions are 1-based; a start below 1 is a run-time error.
        start_idx = start_val - 1
        if start_idx < 0:
            return None
        if len(args) == 3:
            length_val = _numeric_value(args[2])
            if length_val is None or not isinstance(length_val, int):
                return None
            # A negative length is an error in VBA; in Python it would turn
            # the slice end into an index counted from the end of the string.
            if length_val < 0:
                return None
            return s[start_idx:start_idx + length_val]
        return s[start_idx:]
    if name == 'left' and len(args) == 2:
        s = _string_value(args[0])
        n = _numeric_value(args[1])
        # Negative counts are rejected: s[:n] would silently drop characters
        # from the end instead of reproducing the VBA error.
        if s is not None and isinstance(n, int) and n >= 0:
            return s[:n]
    if name == 'right' and len(args) == 2:
        s = _string_value(args[0])
        n = _numeric_value(args[1])
        if s is not None and isinstance(n, int) and n >= 0:
            # s[-0:] would be the whole string, so zero is special-cased.
            return s[-n:] if n > 0 else ''
    if name == 'strreverse' and len(args) == 1:
        s = _string_value(args[0])
        if s is not None:
            return s[::-1]
    if name == 'lcase' and len(args) == 1:
        s = _string_value(args[0])
        if s is not None:
            return s.lower()
    if name == 'ucase' and len(args) == 1:
        s = _string_value(args[0])
        if s is not None:
            return s.upper()
    if name == 'string' and len(args) == 2:
        n = _numeric_value(args[0])
        c = _string_value(args[1])
        if (
            isinstance(n, int)
            and 0 <= n <= max_fill
            and c is not None
            and len(c) >= 1
        ):
            # Only the first character of the fill string is repeated.
            return c[0] * n
    if name == 'space' and len(args) == 1:
        n = _numeric_value(args[0])
        if isinstance(n, int) and 0 <= n <= max_fill:
            return ' ' * n
    if name == 'cstr' and len(args) == 1:
        s = _string_value(args[0])
        if s is not None:
            return s
    if name == 'replace' and len(args) >= 3:
        # Work on the positional argument list so that omitted optional
        # arguments (None entries) do not shift the later ones.
        positional = list(node.arguments)
        if len(positional) > 6 or any(a is None for a in positional[:3]):
            return None
        haystack = _string_value(positional[0])
        needle = _string_value(positional[1])
        insert = _string_value(positional[2])
        if haystack is None or needle is None or insert is None or not needle:
            return None
        # Optional start / count / compare: only fold when each one is either
        # omitted or a literal equal to its default (1, -1, binary compare),
        # because any other value changes the result of the replacement.
        for extra, default in zip(positional[3:], (1, -1, 0)):
            if extra is not None and _numeric_value(extra) != default:
                return None
        return haystack.replace(needle, insert)
    return None
class VbaSimplifications(Transformer):
    """
    Tree transformer that folds constant VBA expressions.

    Each visit method returns a replacement node, or None to keep the node it
    was given. The pass folds string concatenation, arithmetic on numeric
    literals, calls to a set of pure string built-ins, built-in constants,
    parenthesized literals and unary minus / Not on literals.
    """

    def __init__(self):
        super().__init__()
        # Lower-cased names that receive a value somewhere in the module
        # (assignments, constants, loop variables, parameters, procedure
        # names) plus the built-in constant names.
        self._assigned_names: set[str] = set()
        # id() of every statement list that directly contains an
        # "On Error Resume Next" statement.
        self._oern_bodies: set[int] = set()

    def visit(self, node):
        """Visit *node*, rebuilding the name context first when it is a module."""
        if isinstance(node, VbaModule):
            self._collect_context(node)
        return super().visit(node)

    def _collect_context(self, module: VbaModule):
        """
        Populate the assigned-name set and the set of bodies that run under
        On Error Resume Next by walking every node of *module*.
        """
        self._assigned_names = set(VBA_BUILTIN_CONSTANTS)
        self._oern_bodies = set()
        for n in module.walk():
            if isinstance(n, VbaLetStatement) and isinstance(n.target, VbaIdentifier):
                self._assigned_names.add(n.target.name.lower())
            elif isinstance(n, VbaConstDeclaration):
                for d in n.declarators:
                    self._assigned_names.add(d.name.lower())
            elif isinstance(n, (VbaForStatement, VbaForEachStatement)):
                if isinstance(n.variable, VbaIdentifier):
                    self._assigned_names.add(n.variable.name.lower())
            if isinstance(n, (VbaFunctionDeclaration, VbaSubDeclaration)):
                if n.params:
                    for p in n.params:
                        self._assigned_names.add(p.name.lower())
                if n.name:
                    # The procedure name counts as assigned so that a call
                    # without parentheses is never mistaken for an undefined
                    # variable.
                    self._assigned_names.add(n.name.lower())
                if n.body and any(
                    isinstance(s, VbaOnErrorStatement)
                    and s.action is VbaOnErrorAction.RESUME_NEXT
                    for s in n.body
                ):
                    self._oern_bodies.add(id(n.body))
        if module.body and any(
            isinstance(s, VbaOnErrorStatement)
            and s.action is VbaOnErrorAction.RESUME_NEXT
            for s in module.body
        ):
            self._oern_bodies.add(id(module.body))

    def _is_oern_undefined(self, node) -> bool:
        """
        Return True when *node* is an identifier that is never assigned in the
        module and whose nearest enclosing procedure (or the module itself)
        has an On Error Resume Next statement directly in its body.
        """
        if not isinstance(node, VbaIdentifier):
            return False
        if node.name.lower() in self._assigned_names:
            return False
        parent = node.parent
        while parent is not None:
            if isinstance(parent, (VbaFunctionDeclaration, VbaSubDeclaration)):
                return id(parent.body) in self._oern_bodies
            if isinstance(parent, VbaModule):
                return id(parent.body) in self._oern_bodies
            parent = parent.parent
        return False

    @staticmethod
    def _is_finite_real(value) -> bool:
        """Return True unless *value* is complex, NaN or an infinity."""
        if isinstance(value, complex):
            return False
        if isinstance(value, float):
            # value != value is the NaN test.
            return value == value and value not in (float('inf'), float('-inf'))
        return True

    def visit_VbaBinaryExpression(self, node: VbaBinaryExpression):
        """
        Fold a binary expression whose operands are literals.

        Handles, in order: dropping an undefined identifier concatenated with
        a string under On Error Resume Next, concatenation of two string
        literals, merging a string literal into an adjacent concatenation
        chain, and arithmetic on two numeric literals. Returns the replacement
        node or None when nothing can be folded.
        """
        self.generic_visit(node)
        if node.left is None or node.right is None:
            return None
        op = node.operator
        if op in ('&', '+'):
            if self._is_oern_undefined(node.left) and _string_value(node.right) is not None:
                return node.right
            if self._is_oern_undefined(node.right) and _string_value(node.left) is not None:
                return node.left
        left_str = _string_value(node.left)
        right_str = _string_value(node.right)
        if op in ('&', '+') and left_str is not None and right_str is not None:
            return _make_string_literal(left_str + right_str)
        if op in ('&', '+') and right_str is not None:
            # (x & "a") & "b"  ->  x & "ab"
            if (
                isinstance(node.left, VbaBinaryExpression)
                and node.left.operator in ('&', '+')
            ):
                inner_right_str = _string_value(node.left.right)
                if inner_right_str is not None:
                    node.left.right = _make_string_literal(inner_right_str + right_str)
                    node.left.right.parent = node.left
                    return node.left
        if op in ('&', '+') and left_str is not None:
            # "a" & (("b" & x) & y)  ->  ("ab" & x) & y: descend to the
            # leftmost concatenation of the right operand.
            inner = node.right
            while (
                isinstance(inner, VbaBinaryExpression)
                and inner.operator in ('&', '+')
                and isinstance(inner.left, VbaBinaryExpression)
                and inner.left.operator in ('&', '+')
            ):
                inner = inner.left
            if (
                isinstance(inner, VbaBinaryExpression)
                and inner.operator in ('&', '+')
            ):
                inner_left_str = _string_value(inner.left)
                if inner_left_str is not None:
                    inner.left = _make_string_literal(left_str + inner_left_str)
                    inner.left.parent = inner
                    return node.right
        left_num = _numeric_value(node.left)
        right_num = _numeric_value(node.right)
        if left_num is not None and right_num is not None:
            fn = _BINARY_OPS.get(op)
            if fn is not None:
                try:
                    result = fn(left_num, right_num)
                except (ZeroDivisionError, ValueError, OverflowError):
                    return None
                if isinstance(result, float) and (
                    result != result
                    or result == float('inf')
                    or result == float('-inf')
                ):
                    return None
                return _make_numeric_literal(result)
            fn = _INTEGER_OPS.get(op)
            if fn is not None:
                try:
                    result = fn(left_num, right_num)
                except (ZeroDivisionError, ValueError, OverflowError):
                    return None
                return _make_integer_literal(int(result))
            if op == '^':
                try:
                    # Probe in floating point first. This rejects results that
                    # do not fit a double before Python builds an arbitrarily
                    # large exact integer, and exposes a negative base with a
                    # fractional exponent, which Python evaluates to a complex
                    # number instead of raising.
                    probe = float(left_num) ** float(right_num)
                    if not self._is_finite_real(probe):
                        return None
                    result = left_num ** right_num
                except (ZeroDivisionError, ValueError, OverflowError):
                    return None
                if not self._is_finite_real(result):
                    return None
                return _make_numeric_literal(result)
        return None

    def visit_VbaCallExpression(self, node: VbaCallExpression):
        """
        Fold calls to Chr/ChrW, Asc/AscW, the supported string built-ins and
        Len when their arguments are literals.
        """
        self.generic_visit(node)
        code_point = _is_chr_call(node)
        if code_point is not None:
            c = chr(code_point)
            # Non-printable characters are left as Chr() calls.
            if c.isprintable():
                return _make_string_literal(c)
            return None
        char = _is_asc_call(node)
        if char is not None:
            return _make_integer_literal(ord(char))
        result = _try_string_function(node)
        if result is not None:
            return _make_string_literal(result)
        if (
            isinstance(node.callee, VbaIdentifier)
            and node.callee.name.lower() == 'len'
            and len(node.arguments) == 1
            and node.arguments[0] is not None
        ):
            s = _string_value(node.arguments[0])
            if s is not None:
                return _make_integer_literal(len(s))
        return None

    def visit_VbaIdentifier(self, node: VbaIdentifier):
        """Replace a built-in constant name by its integer literal."""
        value = VBA_BUILTIN_CONSTANTS.get(node.name.lower())
        if value is None:
            return None
        return _make_integer_literal(value)

    def visit_VbaParenExpression(self, node: VbaParenExpression):
        """Drop the parentheses around an expression that is a literal."""
        self.generic_visit(node)
        inner = node.expression
        if inner is None:
            return None
        if _is_literal(inner):
            return inner
        return None

    def visit_VbaUnaryExpression(self, node: VbaUnaryExpression):
        """
        Fold unary minus on a numeric literal and Not on a boolean or integer
        literal (bitwise complement for integers).
        """
        self.generic_visit(node)
        if node.operand is None:
            return None
        op = node.operator
        if op == '-':
            val = _numeric_value(node.operand)
            if val is not None:
                return _make_numeric_literal(-val)
        if op == 'Not':
            if isinstance(node.operand, VbaBooleanLiteral):
                return VbaBooleanLiteral(value=not node.operand.value)
            val = _numeric_value(node.operand)
            if isinstance(val, int):
                return _make_integer_literal(~val)
        return None
Classes
class VbaSimplifications-
In-place tree rewriter. Each visit method may return a replacement node, or None to keep the original. Tracks whether any transformation was applied via the
changed flag. Expand source code Browse git
class VbaSimplifications(Transformer): def __init__(self): super().__init__() self._assigned_names: set[str] = set() self._oern_bodies: set[int] = set() def visit(self, node): if isinstance(node, VbaModule): self._collect_context(node) return super().visit(node) def _collect_context(self, module: VbaModule): self._assigned_names = set(VBA_BUILTIN_CONSTANTS) self._oern_bodies = set() for n in module.walk(): if isinstance(n, VbaLetStatement) and isinstance(n.target, VbaIdentifier): self._assigned_names.add(n.target.name.lower()) elif isinstance(n, VbaConstDeclaration): for d in n.declarators: self._assigned_names.add(d.name.lower()) elif isinstance(n, (VbaForStatement, VbaForEachStatement)): if isinstance(n.variable, VbaIdentifier): self._assigned_names.add(n.variable.name.lower()) if isinstance(n, (VbaFunctionDeclaration, VbaSubDeclaration)): if n.params: for p in n.params: self._assigned_names.add(p.name.lower()) if n.name: self._assigned_names.add(n.name.lower()) if n.body and any( isinstance(s, VbaOnErrorStatement) and s.action is VbaOnErrorAction.RESUME_NEXT for s in n.body ): self._oern_bodies.add(id(n.body)) if module.body and any( isinstance(s, VbaOnErrorStatement) and s.action is VbaOnErrorAction.RESUME_NEXT for s in module.body ): self._oern_bodies.add(id(module.body)) def _is_oern_undefined(self, node) -> bool: if not isinstance(node, VbaIdentifier): return False if node.name.lower() in self._assigned_names: return False parent = node.parent while parent is not None: if isinstance(parent, (VbaFunctionDeclaration, VbaSubDeclaration)): return id(parent.body) in self._oern_bodies if isinstance(parent, VbaModule): return id(parent.body) in self._oern_bodies parent = parent.parent return False def visit_VbaBinaryExpression(self, node: VbaBinaryExpression): self.generic_visit(node) if node.left is None or node.right is None: return None op = node.operator if op in ('&', '+'): if self._is_oern_undefined(node.left) and _string_value(node.right) is not None: return 
node.right if self._is_oern_undefined(node.right) and _string_value(node.left) is not None: return node.left left_str = _string_value(node.left) right_str = _string_value(node.right) if op in ('&', '+') and left_str is not None and right_str is not None: return _make_string_literal(left_str + right_str) if op in ('&', '+') and right_str is not None: if ( isinstance(node.left, VbaBinaryExpression) and node.left.operator in ('&', '+') ): inner_right_str = _string_value(node.left.right) if inner_right_str is not None: node.left.right = _make_string_literal(inner_right_str + right_str) node.left.right.parent = node.left return node.left if op in ('&', '+') and left_str is not None: inner = node.right while ( isinstance(inner, VbaBinaryExpression) and inner.operator in ('&', '+') and isinstance(inner.left, VbaBinaryExpression) and inner.left.operator in ('&', '+') ): inner = inner.left if ( isinstance(inner, VbaBinaryExpression) and inner.operator in ('&', '+') ): inner_left_str = _string_value(inner.left) if inner_left_str is not None: inner.left = _make_string_literal(left_str + inner_left_str) inner.left.parent = inner return node.right left_num = _numeric_value(node.left) right_num = _numeric_value(node.right) if left_num is not None and right_num is not None: fn = _BINARY_OPS.get(op) if fn is not None: try: result = fn(left_num, right_num) except (ZeroDivisionError, ValueError, OverflowError): return None if isinstance(result, float) and ( result != result or result == float('inf') or result == float('-inf') ): return None return _make_numeric_literal(result) fn = _INTEGER_OPS.get(op) if fn is not None: try: result = fn(left_num, right_num) except (ZeroDivisionError, ValueError, OverflowError): return None return _make_integer_literal(int(result)) if op == '^': try: result = left_num ** right_num except (ZeroDivisionError, ValueError, OverflowError): return None return _make_numeric_literal(result) return None def visit_VbaCallExpression(self, node: 
VbaCallExpression): self.generic_visit(node) code_point = _is_chr_call(node) if code_point is not None: c = chr(code_point) if c.isprintable(): return _make_string_literal(c) return None char = _is_asc_call(node) if char is not None: return _make_integer_literal(ord(char)) result = _try_string_function(node) if result is not None: return _make_string_literal(result) if ( isinstance(node.callee, VbaIdentifier) and node.callee.name.lower() == 'len' and len(node.arguments) == 1 and node.arguments[0] is not None ): s = _string_value(node.arguments[0]) if s is not None: return _make_integer_literal(len(s)) return None def visit_VbaIdentifier(self, node: VbaIdentifier): value = VBA_BUILTIN_CONSTANTS.get(node.name.lower()) if value is None: return None return _make_integer_literal(value) def visit_VbaParenExpression(self, node: VbaParenExpression): self.generic_visit(node) inner = node.expression if inner is None: return None if _is_literal(inner): return inner return None def visit_VbaUnaryExpression(self, node: VbaUnaryExpression): self.generic_visit(node) if node.operand is None: return None op = node.operator if op == '-': val = _numeric_value(node.operand) if val is not None: return _make_numeric_literal(-val) if op == 'Not': if isinstance(node.operand, VbaBooleanLiteral): return VbaBooleanLiteral(value=not node.operand.value) val = _numeric_value(node.operand) if isinstance(val, int): return _make_integer_literal(~val) return NoneAncestors
Methods
def visit(self, node)-
Expand source code Browse git
def visit(self, node): if isinstance(node, VbaModule): self._collect_context(node) return super().visit(node) def visit_VbaBinaryExpression(self, node)-
Expand source code Browse git
def visit_VbaBinaryExpression(self, node: VbaBinaryExpression): self.generic_visit(node) if node.left is None or node.right is None: return None op = node.operator if op in ('&', '+'): if self._is_oern_undefined(node.left) and _string_value(node.right) is not None: return node.right if self._is_oern_undefined(node.right) and _string_value(node.left) is not None: return node.left left_str = _string_value(node.left) right_str = _string_value(node.right) if op in ('&', '+') and left_str is not None and right_str is not None: return _make_string_literal(left_str + right_str) if op in ('&', '+') and right_str is not None: if ( isinstance(node.left, VbaBinaryExpression) and node.left.operator in ('&', '+') ): inner_right_str = _string_value(node.left.right) if inner_right_str is not None: node.left.right = _make_string_literal(inner_right_str + right_str) node.left.right.parent = node.left return node.left if op in ('&', '+') and left_str is not None: inner = node.right while ( isinstance(inner, VbaBinaryExpression) and inner.operator in ('&', '+') and isinstance(inner.left, VbaBinaryExpression) and inner.left.operator in ('&', '+') ): inner = inner.left if ( isinstance(inner, VbaBinaryExpression) and inner.operator in ('&', '+') ): inner_left_str = _string_value(inner.left) if inner_left_str is not None: inner.left = _make_string_literal(left_str + inner_left_str) inner.left.parent = inner return node.right left_num = _numeric_value(node.left) right_num = _numeric_value(node.right) if left_num is not None and right_num is not None: fn = _BINARY_OPS.get(op) if fn is not None: try: result = fn(left_num, right_num) except (ZeroDivisionError, ValueError, OverflowError): return None if isinstance(result, float) and ( result != result or result == float('inf') or result == float('-inf') ): return None return _make_numeric_literal(result) fn = _INTEGER_OPS.get(op) if fn is not None: try: result = fn(left_num, right_num) except (ZeroDivisionError, ValueError, OverflowError): 
return None return _make_integer_literal(int(result)) if op == '^': try: result = left_num ** right_num except (ZeroDivisionError, ValueError, OverflowError): return None return _make_numeric_literal(result) return None def visit_VbaCallExpression(self, node)-
Expand source code Browse git
def visit_VbaCallExpression(self, node: VbaCallExpression): self.generic_visit(node) code_point = _is_chr_call(node) if code_point is not None: c = chr(code_point) if c.isprintable(): return _make_string_literal(c) return None char = _is_asc_call(node) if char is not None: return _make_integer_literal(ord(char)) result = _try_string_function(node) if result is not None: return _make_string_literal(result) if ( isinstance(node.callee, VbaIdentifier) and node.callee.name.lower() == 'len' and len(node.arguments) == 1 and node.arguments[0] is not None ): s = _string_value(node.arguments[0]) if s is not None: return _make_integer_literal(len(s)) return None def visit_VbaIdentifier(self, node)-
Expand source code Browse git
def visit_VbaIdentifier(self, node: VbaIdentifier): value = VBA_BUILTIN_CONSTANTS.get(node.name.lower()) if value is None: return None return _make_integer_literal(value) def visit_VbaParenExpression(self, node)-
Expand source code Browse git
def visit_VbaParenExpression(self, node: VbaParenExpression): self.generic_visit(node) inner = node.expression if inner is None: return None if _is_literal(inner): return inner return None def visit_VbaUnaryExpression(self, node)-
Expand source code Browse git
def visit_VbaUnaryExpression(self, node: VbaUnaryExpression): self.generic_visit(node) if node.operand is None: return None op = node.operator if op == '-': val = _numeric_value(node.operand) if val is not None: return _make_numeric_literal(-val) if op == 'Not': if isinstance(node.operand, VbaBooleanLiteral): return VbaBooleanLiteral(value=not node.operand.value) val = _numeric_value(node.operand) if isinstance(val, int): return _make_integer_literal(~val) return None