Module refinery.lib.scripts.vba.deobfuscation.emulator

Evaluate user-defined VBA functions called with constant arguments.

Expand source code Browse git
"""
Evaluate user-defined VBA functions called with constant arguments.
"""
from __future__ import annotations

from typing import Any, Callable

from refinery.lib.scripts import Transformer
from refinery.lib.scripts.vba.deobfuscation._helpers import (
    _is_literal,
    _literal_value,
    _Value,
    _value_to_node,
)
from refinery.lib.scripts.vba.model import (
    VbaBinaryExpression,
    VbaBooleanLiteral,
    VbaCallExpression,
    VbaConstDeclaration,
    VbaDebugPrintStatement,
    VbaDoLoopStatement,
    VbaExitKind,
    VbaExitStatement,
    VbaExpressionStatement,
    VbaFloatLiteral,
    VbaForStatement,
    VbaFunctionDeclaration,
    VbaIdentifier,
    VbaIfStatement,
    VbaIntegerLiteral,
    VbaLetStatement,
    VbaLoopConditionPosition,
    VbaLoopConditionType,
    VbaModule,
    VbaOnErrorAction,
    VbaOnErrorStatement,
    VbaParenExpression,
    VbaStringLiteral,
    VbaUnaryExpression,
    VbaVariableDeclaration,
)


class _VbaInterpreterError(Exception):
    """
    Raised when a statement or expression fails while being interpreted. The statement runner
    skips the failing statement instead of aborting while On Error Resume Next is active.
    """


class _UnevaluableError(Exception):
    """
    Signals a statement that the interpreter has no model for, for example an implicit call that
    might have side effects. In contrast to `_VbaInterpreterError`, On Error Resume Next does not
    suppress this exception: silently skipping a side-effecting statement would lose behavior.
    """


class _ExitFunctionSignal(Exception):
    """
    Control-flow signal raised for an Exit Function statement; the interpreter catches it at the
    top of the function body and then returns the current function value.
    """


def _cast_to_int(value):
    """
    Convert a value to an integer by rounding toward negative infinity. The value is first
    converted with `float`, so numeric strings are accepted as well.
    """
    number = float(value)
    truncated = int(number)
    # int() truncates toward zero, so only negative non-integers end up one too high.
    if number < truncated:
        return truncated - 1
    return truncated


class _VbaInterpreter:
    """
    A small tree-walking interpreter for the body of a single VBA function. Variables live in one
    flat environment keyed by lower-cased name; the function result is whatever value was last
    assigned to the function's own name.

    Two failure classes are used. `_VbaInterpreterError` models a runtime error of the interpreted
    code and is skipped per statement while On Error Resume Next is active. `_UnevaluableError`
    means the interpreter itself cannot continue (unsupported statement, exhausted iteration or
    string budget) and always aborts the evaluation, because carrying on would yield a result
    that the real program would never produce.
    """

    def __init__(
        self,
        function_name: str,
        max_iterations: int = 100_000,
        max_string_length: int = 1_000_000,
    ):
        """
        Create an interpreter for the function called `function_name`. Loops may run at most
        `max_iterations` iterations in total and no computed string may be longer than
        `max_string_length` characters.
        """
        self.function_name = function_name.lower()
        self.max_iterations = max_iterations
        self.max_string_length = max_string_length
        self._env: dict[str, _Value] = {}
        self._iterations = 0
        self._on_error_resume_next = False

    def execute(self, body: list, bindings: dict[str, _Value]) -> _Value:
        """
        Run the statements in `body` with the parameter values in `bindings` and return the value
        assigned to the function name, or None when no such assignment was executed.
        """
        self._env = dict(bindings)
        self._iterations = 0
        self._on_error_resume_next = False
        try:
            self._exec_statements(body)
        except _ExitFunctionSignal:
            pass
        return self._env.get(self.function_name)

    def _exec_statements(self, stmts: list):
        """
        Execute statements in order. While On Error Resume Next is active, a statement that fails
        with `_VbaInterpreterError` is skipped; every other exception propagates.
        """
        for stmt in stmts:
            if self._on_error_resume_next:
                try:
                    self._exec_statement(stmt)
                except _VbaInterpreterError:
                    continue
            else:
                self._exec_statement(stmt)

    def _exec_statement(self, stmt):
        """
        Dispatch a single statement by node type. Statement types without a model raise
        `_UnevaluableError`.
        """
        if isinstance(stmt, VbaOnErrorStatement):
            self._on_error_resume_next = stmt.action is VbaOnErrorAction.RESUME_NEXT
            return
        if isinstance(stmt, VbaLetStatement):
            return self._exec_let(stmt)
        if isinstance(stmt, VbaConstDeclaration):
            return self._exec_const(stmt)
        if isinstance(stmt, VbaIfStatement):
            return self._exec_if(stmt)
        if isinstance(stmt, VbaForStatement):
            return self._exec_for(stmt)
        if isinstance(stmt, VbaDoLoopStatement):
            return self._exec_do_loop(stmt)
        if isinstance(stmt, VbaExitStatement):
            if stmt.kind is VbaExitKind.FUNCTION:
                raise _ExitFunctionSignal
            # Other Exit kinds are not modelled. This must not be a suppressible error: skipping
            # an Exit For or Exit Do under On Error Resume Next would keep the loop running.
            raise _UnevaluableError
        if isinstance(stmt, VbaExpressionStatement):
            raise _UnevaluableError
        if isinstance(stmt, VbaDebugPrintStatement):
            return
        if isinstance(stmt, VbaVariableDeclaration):
            return
        raise _UnevaluableError

    def _exec_let(self, stmt: VbaLetStatement):
        """
        Assign the evaluated right-hand side to a plain identifier target.
        """
        if not isinstance(stmt.target, VbaIdentifier):
            raise _VbaInterpreterError
        key = stmt.target.name.lower()
        value = self._eval(stmt.value)
        self._env[key] = value

    def _exec_const(self, stmt: VbaConstDeclaration):
        """
        Evaluate every declarator of a Const statement and store it in the environment.
        """
        for d in stmt.declarators:
            key = d.name.lower()
            value = self._eval(d.value)
            self._env[key] = value

    def _exec_if(self, stmt: VbaIfStatement):
        """
        Execute the first branch whose condition is truthy, or the Else body when there is one.
        """
        cond = self._eval(stmt.condition)
        if self._truthy(cond):
            self._exec_statements(stmt.body)
            return
        for clause in stmt.elseif_clauses:
            cond = self._eval(clause.condition)
            if self._truthy(cond):
                self._exec_statements(clause.body)
                return
        if stmt.else_body:
            self._exec_statements(stmt.else_body)

    def _exec_for(self, stmt: VbaForStatement):
        """
        Execute a For ... To ... Step loop. Bounds and step are evaluated once. The loop variable
        is stored before each termination test so that it holds the first out-of-range value
        after the loop (or the start value for a loop that never runs), and it is read back after
        each pass so that assignments to it inside the body influence the iteration.
        """
        if not isinstance(stmt.variable, VbaIdentifier):
            raise _VbaInterpreterError
        key = stmt.variable.name.lower()
        start = self._to_number(self._eval(stmt.start))
        end = self._to_number(self._eval(stmt.end))
        step = self._to_number(self._eval(stmt.step)) if stmt.step else 1
        if step == 0:
            raise _VbaInterpreterError
        counter = start
        while True:
            self._tick()
            self._env[key] = counter
            if step > 0 and counter > end:
                break
            if step < 0 and counter < end:
                break
            self._exec_statements(stmt.body)
            counter = self._to_number(self._env.get(key)) + step

    def _exec_do_loop(self, stmt: VbaDoLoopStatement):
        """
        Execute a Do loop with an optional While/Until condition that is tested either before or
        after the body, depending on the condition position recorded on the node.
        """
        check_before = stmt.condition_position is VbaLoopConditionPosition.PRE
        is_until = stmt.condition_type is VbaLoopConditionType.UNTIL
        while True:
            self._tick()
            if check_before and stmt.condition is not None:
                cond = self._truthy(self._eval(stmt.condition))
                if is_until and cond:
                    break
                if not is_until and not cond:
                    break
            self._exec_statements(stmt.body)
            if not check_before and stmt.condition is not None:
                cond = self._truthy(self._eval(stmt.condition))
                if is_until and cond:
                    break
                if not is_until and not cond:
                    break

    def _tick(self):
        """
        Count one loop iteration and abort the evaluation once the budget is exhausted. Running
        out of budget is a limit of the interpreter, not an error of the interpreted code, so it
        is reported as `_UnevaluableError`, which On Error Resume Next cannot suppress.
        """
        self._iterations += 1
        if self._iterations > self.max_iterations:
            raise _UnevaluableError

    def _eval(self, expr) -> _Value:
        """
        Evaluate an expression node. Unknown identifiers evaluate to None; unsupported node types
        raise `_VbaInterpreterError`.
        """
        if expr is None:
            return None
        if isinstance(expr, VbaStringLiteral):
            return expr.value
        if isinstance(expr, VbaIntegerLiteral):
            return expr.value
        if isinstance(expr, VbaFloatLiteral):
            return expr.value
        if isinstance(expr, VbaBooleanLiteral):
            return expr.value
        if isinstance(expr, VbaIdentifier):
            return self._env.get(expr.name.lower())
        if isinstance(expr, VbaBinaryExpression):
            return self._eval_binary(expr)
        if isinstance(expr, VbaUnaryExpression):
            return self._eval_unary(expr)
        if isinstance(expr, VbaParenExpression):
            return self._eval(expr.expression)
        if isinstance(expr, VbaCallExpression):
            return self._eval_call(expr)
        raise _VbaInterpreterError

    def _eval_binary(self, node: VbaBinaryExpression) -> _Value:
        """
        Evaluate a binary expression. Both operands are always evaluated before the operator is
        applied.
        """
        left = self._eval(node.left)
        right = self._eval(node.right)
        op = node.operator
        keyword = op.lower()
        if op == '&':
            return self._concat(left, right)
        if op == '+':
            if isinstance(left, str) or isinstance(right, str):
                return self._concat(left, right)
            return self._numeric_op(left, right, lambda a, b: a + b)
        if op == '-':
            return self._numeric_op(left, right, lambda a, b: a - b)
        if op == '*':
            return self._numeric_op(left, right, lambda a, b: a * b)
        if op == '/':
            return self._numeric_op(left, right, lambda a, b: a / b)
        if op == '\\':
            a = self._to_int(left)
            b = self._to_int(right)
            if b == 0:
                raise _VbaInterpreterError
            # VBA integer division truncates toward zero, whereas Python's // floors.
            quotient = abs(a) // abs(b)
            return quotient if (a < 0) == (b < 0) else -quotient
        if keyword == 'mod':
            a = self._to_int(left)
            b = self._to_int(right)
            if b == 0:
                raise _VbaInterpreterError
            # The VBA remainder takes the sign of the dividend; Python's % follows the divisor.
            remainder = abs(a) % abs(b)
            return -remainder if a < 0 else remainder
        if op == '^':
            return self._numeric_op(left, right, self._power)
        if keyword == 'xor':
            return self._to_int(left) ^ self._to_int(right)
        if keyword == 'and':
            return self._to_int(left) & self._to_int(right)
        if keyword == 'or':
            return self._to_int(left) | self._to_int(right)
        if op == '=':
            return left == right
        if op == '<>':
            return left != right
        if op == '<':
            return self._compare(left, right, lambda a, b: a < b)
        if op == '>':
            return self._compare(left, right, lambda a, b: a > b)
        if op == '<=':
            return self._compare(left, right, lambda a, b: a <= b)
        if op == '>=':
            return self._compare(left, right, lambda a, b: a >= b)
        raise _VbaInterpreterError

    @staticmethod
    def _power(base, exponent):
        """
        Exponentiation that cannot build arbitrarily large integers. The result is computed in
        floating point first, which raises OverflowError for huge results; the exact integer
        power is only returned when it is small enough to be represented exactly as a float.
        """
        approx = float(base) ** float(exponent)
        if (
            isinstance(base, int)
            and isinstance(exponent, int)
            and exponent >= 0
            and abs(approx) < 2 ** 53
        ):
            return base ** exponent
        return approx

    def _eval_unary(self, node: VbaUnaryExpression) -> _Value:
        """
        Evaluate unary minus and Not. Not is logical for booleans and bitwise for everything else.
        """
        val = self._eval(node.operand)
        op = node.operator
        if op == '-':
            n = self._to_number(val)
            return -n
        if op.lower() == 'not':
            if isinstance(val, bool):
                return not val
            return ~self._to_int(val)
        raise _VbaInterpreterError

    # Builtins that take exactly one argument, keyed by lower-cased function name.
    _BUILTINS: dict[str, Callable[[Any], _Value]] = {
        'chr'       : lambda v: chr(int(v)),
        'chrw'      : lambda v: chr(int(v)),
        'chr$'      : lambda v: chr(int(v)),
        'chrw$'     : lambda v: chr(int(v)),
        'asc'       : lambda v: ord(str(v)[0]),
        'ascw'      : lambda v: ord(str(v)[0]),
        'len'       : lambda v: len(str(v)),
        'lcase'     : lambda v: str(v).lower(),
        'lcase$'    : lambda v: str(v).lower(),
        'ucase'     : lambda v: str(v).upper(),
        'ucase$'    : lambda v: str(v).upper(),
        'trim'      : lambda v: str(v).strip(),
        'trim$'     : lambda v: str(v).strip(),
        'ltrim'     : lambda v: str(v).lstrip(),
        'ltrim$'    : lambda v: str(v).lstrip(),
        'rtrim'     : lambda v: str(v).rstrip(),
        'rtrim$'    : lambda v: str(v).rstrip(),
        'strreverse': lambda v: str(v)[::-1],
        'cstr'      : lambda v: str(v),
        'cint'      : lambda v: int(round(float(v))),
        'clng'      : lambda v: int(round(float(v))),
        'cdbl'      : lambda v: float(v),
        'csng'      : lambda v: float(v),
        'cbool'     : lambda v: bool(v),
        'abs'       : lambda v: abs(v),
        'sgn'       : lambda v: (1 if v > 0 else (-1 if v < 0 else 0)),
        'int'       : _cast_to_int,
        'fix'       : lambda v: int(float(v)),
        'hex'       : lambda v: format(int(v), 'X'),
        'hex$'      : lambda v: format(int(v), 'X'),
        'oct'       : lambda v: format(int(v), 'o'),
        'oct$'      : lambda v: format(int(v), 'o'),
        'cbyte'     : lambda v: int(v) & 0xFF,
        'space'     : lambda v: ' ' * int(v),
        'space$'    : lambda v: ' ' * int(v),
    }

    # Builtins that receive the whole argument list, mapped to the implementing method's name.
    _MULTI_ARG_BUILTINS: dict[str, str] = {
        'mid'    : '_builtin_mid',
        'mid$'   : '_builtin_mid',
        'left'   : '_builtin_left',
        'left$'  : '_builtin_left',
        'right'  : '_builtin_right',
        'right$' : '_builtin_right',
        'string' : '_builtin_string',
        'string$': '_builtin_string',
        'replace': '_builtin_replace',
        'instr'  : '_builtin_instr',
    }

    # Builtins whose first argument is a repeat count that directly sets the output length.
    _REPEAT_BUILTINS = frozenset({'space', 'space$', 'string', 'string$'})

    def _eval_call(self, node: VbaCallExpression) -> _Value:
        """
        Evaluate a call to one of the supported builtin functions. Omitted arguments are dropped
        from the argument list before dispatch. Conversion failures inside any builtin are
        reported as `_VbaInterpreterError`; results that exceed the string budget abort the
        evaluation with `_UnevaluableError`.
        """
        if not isinstance(node.callee, VbaIdentifier):
            raise _VbaInterpreterError
        name = node.callee.name.lower()
        args = [self._eval(a) for a in node.arguments if a is not None]
        handler = self._BUILTINS.get(name)
        if handler is not None and len(args) == 1:
            target, payload = handler, args[0]
        else:
            method_name = self._MULTI_ARG_BUILTINS.get(name)
            if method_name is None:
                raise _VbaInterpreterError
            target, payload = getattr(self, method_name), args
        if name in self._REPEAT_BUILTINS and args:
            # Reject oversized repeat counts before the string is allocated.
            self._check_repeat_count(args[0])
        try:
            result = target(payload)
        except (ValueError, OverflowError, TypeError, IndexError):
            raise _VbaInterpreterError from None
        if isinstance(result, str) and len(result) > self.max_string_length:
            raise _UnevaluableError
        return result

    def _check_repeat_count(self, count: _Value):
        """
        Raise `_UnevaluableError` when a repeat count would produce a string longer than the
        configured maximum, and `_VbaInterpreterError` when it is not convertible to an integer.
        """
        try:
            n = int(count)  # type: ignore
        except (TypeError, ValueError, OverflowError):
            raise _VbaInterpreterError from None
        if n > self.max_string_length:
            raise _UnevaluableError

    @staticmethod
    def _builtin_mid(args: list[_Value]) -> str:
        """
        Mid(string, start[, length]) with a one-based start. A start below one or a negative
        length is an error rather than a Python-style slice from the end.
        """
        if len(args) not in (2, 3):
            raise _VbaInterpreterError
        s = str(args[0]) if args[0] is not None else ''
        start = int(args[1]) - 1  # type: ignore
        if start < 0:
            raise _VbaInterpreterError
        if len(args) == 3:
            length = int(args[2])  # type: ignore
            if length < 0:
                raise _VbaInterpreterError
            return s[start:start + length]
        return s[start:]

    @staticmethod
    def _builtin_left(args: list[_Value]) -> str:
        """
        Left(string, length); a negative length is an error.
        """
        if len(args) != 2:
            raise _VbaInterpreterError
        s = str(args[0]) if args[0] is not None else ''
        n = int(args[1])  # type: ignore
        if n < 0:
            raise _VbaInterpreterError
        return s[:n]

    @staticmethod
    def _builtin_right(args: list[_Value]) -> str:
        """
        Right(string, length); a negative length is an error.
        """
        if len(args) != 2:
            raise _VbaInterpreterError
        s = str(args[0]) if args[0] is not None else ''
        n = int(args[1])  # type: ignore
        if n < 0:
            raise _VbaInterpreterError
        return s[-n:] if n > 0 else ''

    @staticmethod
    def _builtin_string(args: list[_Value]) -> str:
        """
        String(number, character). The character is either a numeric character code, reduced
        modulo 256, or a string whose first character is repeated.
        """
        if len(args) != 2:
            raise _VbaInterpreterError
        n = int(args[0])  # type: ignore
        if n < 0:
            raise _VbaInterpreterError
        char = args[1]
        if isinstance(char, (int, float)) and not isinstance(char, bool):
            fill = chr(int(char) % 256)
        else:
            text = str(char) if char is not None else ''
            if not text:
                raise _VbaInterpreterError
            fill = text[0]
        return fill * n

    @staticmethod
    def _builtin_replace(args: list[_Value]) -> str:
        """
        Replace(expression, find, replace). The optional start, count and compare arguments are
        not modelled, so calls that pass them are rejected instead of being evaluated as an
        unrestricted replacement.
        """
        if len(args) != 3:
            raise _VbaInterpreterError
        haystack = str(args[0]) if args[0] is not None else ''
        needle = str(args[1]) if args[1] is not None else ''
        insert = str(args[2]) if args[2] is not None else ''
        if not needle:
            raise _VbaInterpreterError
        return haystack.replace(needle, insert)

    @staticmethod
    def _builtin_instr(args: list[_Value]) -> int:
        """
        InStr([start, ]string1, string2) returning a one-based position or zero. A start below
        one is an error; it must not be handed to `str.find` as a negative offset.
        """
        if len(args) == 2:
            haystack = str(args[0]) if args[0] is not None else ''
            needle = str(args[1]) if args[1] is not None else ''
            idx = haystack.find(needle)
            return idx + 1 if idx >= 0 else 0
        if len(args) == 3:
            start = int(args[0])  # type: ignore
            if start < 1:
                raise _VbaInterpreterError
            haystack = str(args[1]) if args[1] is not None else ''
            needle = str(args[2]) if args[2] is not None else ''
            idx = haystack.find(needle, start - 1)
            return idx + 1 if idx >= 0 else 0
        raise _VbaInterpreterError

    def _concat(self, lhs: _Value, rhs: _Value) -> str:
        """
        Concatenate two values as strings, treating None as the empty string. Exceeding the
        string budget aborts the evaluation with `_UnevaluableError`, so that On Error Resume
        Next cannot turn the oversized assignment into a silently skipped one.
        """
        a = str(lhs) if lhs is not None else ''
        b = str(rhs) if rhs is not None else ''
        result = a + b
        if len(result) > self.max_string_length:
            raise _UnevaluableError
        return result

    @staticmethod
    def _to_number(v: _Value) -> int | float:
        """
        Convert a value to a number: None is 0, True is -1, numeric strings are parsed as int
        first and float second. Anything else raises `_VbaInterpreterError`.
        """
        if v is None:
            return 0
        if isinstance(v, bool):
            return -1 if v else 0
        if isinstance(v, (int, float)):
            return v
        if isinstance(v, str):
            try:
                return int(v)
            except ValueError:
                try:
                    return float(v)
                except ValueError:
                    raise _VbaInterpreterError from None
        # Without this, unsupported value types would fall through and yield None.
        raise _VbaInterpreterError

    @staticmethod
    def _to_int(v: _Value) -> int:
        """
        Convert a value to an integer for the integer-only operators. Floats are rounded to the
        nearest integer with ties going to the even neighbor; non-finite floats are an error.
        """
        if v is None:
            return 0
        if isinstance(v, bool):
            return -1 if v else 0
        if isinstance(v, int):
            return v
        if isinstance(v, float):
            try:
                return round(v)
            except (ValueError, OverflowError):
                raise _VbaInterpreterError from None
        if isinstance(v, str):
            try:
                return int(v)
            except ValueError:
                raise _VbaInterpreterError from None
        raise _VbaInterpreterError

    def _numeric_op(self, left: _Value, right: _Value, op) -> int | float:
        """
        Apply the arithmetic callable `op` to both operands after numeric conversion. Arithmetic
        failures and results that are complex, infinite or NaN raise `_VbaInterpreterError`.
        """
        a = self._to_number(left)
        b = self._to_number(right)
        try:
            result = op(a, b)
        except (ZeroDivisionError, ValueError, OverflowError, ArithmeticError):
            raise _VbaInterpreterError from None
        if isinstance(result, complex):
            # A negative base with a fractional exponent produces a complex number in Python.
            raise _VbaInterpreterError
        if isinstance(result, float) and (result != result or abs(result) == float('inf')):
            raise _VbaInterpreterError
        return result

    @staticmethod
    def _compare(left: _Value, right: _Value, op) -> bool:
        """
        Apply an ordering comparison. Two strings are compared case-insensitively, two numbers
        numerically; mixed operand types raise `_VbaInterpreterError`.
        """
        if isinstance(left, str) and isinstance(right, str):
            return op(left.lower(), right.lower())
        if isinstance(left, (int, float)) and isinstance(right, (int, float)):
            return op(left, right)
        raise _VbaInterpreterError

    @staticmethod
    def _truthy(value: _Value) -> bool:
        """
        Decide whether a value counts as true in a condition: None, zero and the empty string are
        false, everything else is true.
        """
        if value is None:
            return False
        if isinstance(value, bool):
            return value
        if isinstance(value, int):
            return value != 0
        if isinstance(value, float):
            return value != 0.0
        if isinstance(value, str):
            return len(value) > 0
        return True


class VbaFunctionEvaluator(Transformer):
    """
    Evaluate calls to user-defined VBA functions when all arguments are constants.
    Replaces the call expression with the computed string or integer literal.
    Removes function definitions once all their calls have been resolved.
    """

    def __init__(
        self,
        max_iterations: int = 100_000,
        max_string_length: int = 1_000_000,
    ):
        """
        Both limits are handed to the interpreter that evaluates each call: `max_iterations`
        bounds the total number of loop iterations, `max_string_length` the size of any string
        that is computed.
        """
        super().__init__()
        self.max_iterations = max_iterations
        self.max_string_length = max_string_length
        # Function declarations of the tree being processed, keyed by lower-cased name.
        self._functions: dict[str, VbaFunctionDeclaration] = {}
        # Number of references seen per function, and how many of those were replaced.
        self._call_counts: dict[str, int] = {}
        self._replaced_counts: dict[str, int] = {}
        # True while the top-level visit is running; nested visits then just recurse.
        self._entry = False
        # Lower-cased name of the function declaration currently being visited, if any.
        self._inside_function: str | None = None

    def visit(self, node):
        """
        Entry point. The outermost call collects all function declarations, transforms the tree
        and removes fully resolved declarations; it always returns None. Calls made while that
        is in progress are forwarded to the base class.
        """
        if self._entry:
            return super().visit(node)
        self._entry = True
        try:
            self._functions.clear()
            self._call_counts.clear()
            self._replaced_counts.clear()
            self._collect_functions(node)
            if not self._functions:
                return None
            super().visit(node)
            self._remove_resolved_definitions(node)
            return None
        finally:
            self._entry = False

    def _collect_functions(self, root):
        """
        Record every named function declaration below `root`; a later declaration with the same
        case-insensitive name replaces an earlier one.
        """
        for node in root.walk():
            if isinstance(node, VbaFunctionDeclaration):
                if not node.name:
                    continue
                self._functions[node.name.lower()] = node

    def visit_VbaFunctionDeclaration(self, node: VbaFunctionDeclaration):
        """
        Visit a function body while remembering its name, so that references to the function's
        own name inside the body are left alone.
        """
        key = node.name.lower() if node.name else None
        old = self._inside_function
        self._inside_function = key
        try:
            self.generic_visit(node)
        finally:
            # Restore even when visiting fails; a stale name would otherwise make later passes
            # skip identifiers that happen to share it.
            self._inside_function = old
        return None

    def visit_VbaCallExpression(self, node: VbaCallExpression):
        """
        Replace a call to a known function by a literal when all arguments are constants and the
        function can be evaluated. Returns the replacement node, or None to keep the call.
        """
        self.generic_visit(node)
        if not isinstance(node.callee, VbaIdentifier):
            return None
        key = node.callee.name.lower()
        funcdef = self._functions.get(key)
        if funcdef is None:
            return None
        self._call_counts[key] = self._call_counts.get(key, 0) + 1
        args = self._extract_constant_args(node)
        if args is None:
            return None
        return self._evaluate_to_node(key, funcdef, self._bind_parameters(funcdef, args))

    def visit_VbaIdentifier(self, node: VbaIdentifier):
        """
        Replace a bare reference to a known function that needs no arguments by a literal.
        Assignment targets, callees of call expressions and references inside the function
        itself are not treated as calls.
        """
        key = node.name.lower()
        if key == self._inside_function:
            return None
        funcdef = self._functions.get(key)
        if funcdef is None:
            return None
        if funcdef.params:
            required = [p for p in funcdef.params if not p.is_optional and p.default is None]
            if required:
                return None
        parent = node.parent
        if isinstance(parent, VbaLetStatement) and parent.target is node:
            return None
        if isinstance(parent, VbaCallExpression) and parent.callee is node:
            return None
        self._call_counts[key] = self._call_counts.get(key, 0) + 1
        return self._evaluate_to_node(key, funcdef, self._bind_parameters(funcdef, []))

    def _evaluate_to_node(
        self,
        key: str,
        funcdef: VbaFunctionDeclaration,
        bindings: dict[str, _Value] | None,
    ):
        """
        Evaluate `funcdef` with `bindings` and convert the result to a literal node. Returns
        None when binding or evaluation failed or the value has no node form; on success the
        replacement counter for `key` is incremented.
        """
        if bindings is None:
            return None
        result = self._try_evaluate(funcdef, bindings)
        if result is None:
            return None
        replacement = _value_to_node(result)
        if replacement is None:
            return None
        self._replaced_counts[key] = self._replaced_counts.get(key, 0) + 1
        return replacement

    def _try_evaluate(
        self,
        funcdef: VbaFunctionDeclaration,
        bindings: dict[str, _Value],
    ) -> _Value:
        """
        Run the function body in a fresh interpreter and return its result, or None when the
        interpreter gave up.
        """
        interpreter = _VbaInterpreter(
            function_name=funcdef.name,
            max_iterations=self.max_iterations,
            max_string_length=self.max_string_length,
        )
        try:
            return interpreter.execute(funcdef.body, bindings)
        except (_VbaInterpreterError, _UnevaluableError):
            return None

    @staticmethod
    def _extract_constant_args(node: VbaCallExpression) -> list[_Value] | None:
        """
        Return the literal values of all call arguments, with None standing in for an omitted
        argument, or None when any argument is not a literal.
        """
        args: list[_Value] = []
        for arg in node.arguments:
            if arg is None:
                args.append(None)
                continue
            if not _is_literal(arg):
                return None
            args.append(_literal_value(arg))
        return args

    @staticmethod
    def _bind_parameters(
        funcdef: VbaFunctionDeclaration,
        args: list[_Value],
    ) -> dict[str, _Value] | None:
        """
        Map positional argument values to lower-cased parameter names. An argument that is
        missing or omitted (None) falls back to the literal default of an optional parameter,
        or to None when the optional parameter has no default. Returns None when the call cannot
        be bound: a required parameter without value, a non-literal default, or more arguments
        than parameters.
        """
        params = funcdef.params
        if len(args) > len(params):
            # Surplus arguments mean something this binder does not model, so refuse to guess.
            return None
        bindings: dict[str, _Value] = {}
        for i, param in enumerate(params):
            key = param.name.lower()
            # NOTE(review): None is taken to mean "omitted"; this assumes `_literal_value`
            # never yields None for an explicitly written literal.
            supplied = args[i] if i < len(args) else None
            if supplied is not None:
                bindings[key] = supplied
            elif not param.is_optional:
                return None
            elif param.default is None:
                bindings[key] = None
            elif _is_literal(param.default):
                bindings[key] = _literal_value(param.default)
            else:
                return None
        return bindings

    def _remove_resolved_definitions(self, _root):
        """
        Remove every collected function from its module body when it was referenced at least
        once and each of those references was replaced by a literal.
        """
        for key, funcdef in self._functions.items():
            call_count = self._call_counts.get(key, 0)
            replaced_count = self._replaced_counts.get(key, 0)
            if call_count == 0 or replaced_count < call_count:
                continue
            parent = funcdef.parent
            if parent is None:
                continue
            if isinstance(parent, VbaModule):
                body = parent.body
            else:
                continue
            if funcdef in body:
                body.remove(funcdef)
                self.mark_changed()

Classes

class VbaFunctionEvaluator (max_iterations=100000, max_string_length=1000000)

Evaluate calls to user-defined VBA functions when all arguments are constants. Replaces the call expression with the computed string or integer literal. Removes function definitions once all their calls have been resolved.

Expand source code Browse git
class VbaFunctionEvaluator(Transformer):
    """
    Evaluate calls to user-defined VBA functions when all arguments are constants.
    Replaces the call expression with the computed string or integer literal.
    Removes function definitions once all their calls have been resolved.
    """

    def __init__(
        self,
        max_iterations: int = 100_000,
        max_string_length: int = 1_000_000,
    ):
        """
        Both limits are handed to the interpreter that evaluates each call: `max_iterations`
        bounds the total number of loop iterations, `max_string_length` the size of any string
        that is computed.
        """
        super().__init__()
        self.max_iterations = max_iterations
        self.max_string_length = max_string_length
        # Function declarations of the tree being processed, keyed by lower-cased name.
        self._functions: dict[str, VbaFunctionDeclaration] = {}
        # Number of references seen per function, and how many of those were replaced.
        self._call_counts: dict[str, int] = {}
        self._replaced_counts: dict[str, int] = {}
        # True while the top-level visit is running; nested visits then just recurse.
        self._entry = False
        # Lower-cased name of the function declaration currently being visited, if any.
        self._inside_function: str | None = None

    def visit(self, node):
        """
        Entry point. The outermost call collects all function declarations, transforms the tree
        and removes fully resolved declarations; it always returns None. Calls made while that
        is in progress are forwarded to the base class.
        """
        if self._entry:
            return super().visit(node)
        self._entry = True
        try:
            self._functions.clear()
            self._call_counts.clear()
            self._replaced_counts.clear()
            self._collect_functions(node)
            if not self._functions:
                return None
            super().visit(node)
            self._remove_resolved_definitions(node)
            return None
        finally:
            self._entry = False

    def _collect_functions(self, root):
        """
        Record every named function declaration below `root`; a later declaration with the same
        case-insensitive name replaces an earlier one.
        """
        for node in root.walk():
            if isinstance(node, VbaFunctionDeclaration):
                if not node.name:
                    continue
                self._functions[node.name.lower()] = node

    def visit_VbaFunctionDeclaration(self, node: VbaFunctionDeclaration):
        """
        Visit a function body while remembering its name, so that references to the function's
        own name inside the body are left alone.
        """
        key = node.name.lower() if node.name else None
        old = self._inside_function
        self._inside_function = key
        try:
            self.generic_visit(node)
        finally:
            # Restore even when visiting fails; a stale name would otherwise make later passes
            # skip identifiers that happen to share it.
            self._inside_function = old
        return None

    def visit_VbaCallExpression(self, node: VbaCallExpression):
        """
        Replace a call to a known function by a literal when all arguments are constants and the
        function can be evaluated. Returns the replacement node, or None to keep the call.
        """
        self.generic_visit(node)
        if not isinstance(node.callee, VbaIdentifier):
            return None
        key = node.callee.name.lower()
        funcdef = self._functions.get(key)
        if funcdef is None:
            return None
        self._call_counts[key] = self._call_counts.get(key, 0) + 1
        args = self._extract_constant_args(node)
        if args is None:
            return None
        return self._evaluate_to_node(key, funcdef, self._bind_parameters(funcdef, args))

    def visit_VbaIdentifier(self, node: VbaIdentifier):
        """
        Replace a bare reference to a known function that needs no arguments by a literal.
        Assignment targets, callees of call expressions and references inside the function
        itself are not treated as calls.
        """
        key = node.name.lower()
        if key == self._inside_function:
            return None
        funcdef = self._functions.get(key)
        if funcdef is None:
            return None
        if funcdef.params:
            required = [p for p in funcdef.params if not p.is_optional and p.default is None]
            if required:
                return None
        parent = node.parent
        if isinstance(parent, VbaLetStatement) and parent.target is node:
            return None
        if isinstance(parent, VbaCallExpression) and parent.callee is node:
            return None
        self._call_counts[key] = self._call_counts.get(key, 0) + 1
        return self._evaluate_to_node(key, funcdef, self._bind_parameters(funcdef, []))

    def _evaluate_to_node(
        self,
        key: str,
        funcdef: VbaFunctionDeclaration,
        bindings: dict[str, _Value] | None,
    ):
        """
        Evaluate `funcdef` with `bindings` and convert the result to a literal node. Returns
        None when binding or evaluation failed or the value has no node form; on success the
        replacement counter for `key` is incremented.
        """
        if bindings is None:
            return None
        result = self._try_evaluate(funcdef, bindings)
        if result is None:
            return None
        replacement = _value_to_node(result)
        if replacement is None:
            return None
        self._replaced_counts[key] = self._replaced_counts.get(key, 0) + 1
        return replacement

    def _try_evaluate(
        self,
        funcdef: VbaFunctionDeclaration,
        bindings: dict[str, _Value],
    ) -> _Value:
        """
        Run the function body in a fresh interpreter and return its result, or None when the
        interpreter gave up.
        """
        interpreter = _VbaInterpreter(
            function_name=funcdef.name,
            max_iterations=self.max_iterations,
            max_string_length=self.max_string_length,
        )
        try:
            return interpreter.execute(funcdef.body, bindings)
        except (_VbaInterpreterError, _UnevaluableError):
            return None

    @staticmethod
    def _extract_constant_args(node: VbaCallExpression) -> list[_Value] | None:
        """
        Return the literal values of all call arguments, with None standing in for an omitted
        argument, or None when any argument is not a literal.
        """
        args: list[_Value] = []
        for arg in node.arguments:
            if arg is None:
                args.append(None)
                continue
            if not _is_literal(arg):
                return None
            args.append(_literal_value(arg))
        return args

    @staticmethod
    def _bind_parameters(
        funcdef: VbaFunctionDeclaration,
        args: list[_Value],
    ) -> dict[str, _Value] | None:
        """
        Map positional argument values to lower-cased parameter names. An argument that is
        missing or omitted (None) falls back to the literal default of an optional parameter,
        or to None when the optional parameter has no default. Returns None when the call cannot
        be bound: a required parameter without value, a non-literal default, or more arguments
        than parameters.
        """
        params = funcdef.params
        if len(args) > len(params):
            # Surplus arguments mean something this binder does not model, so refuse to guess.
            return None
        bindings: dict[str, _Value] = {}
        for i, param in enumerate(params):
            key = param.name.lower()
            # NOTE(review): None is taken to mean "omitted"; this assumes `_literal_value`
            # never yields None for an explicitly written literal.
            supplied = args[i] if i < len(args) else None
            if supplied is not None:
                bindings[key] = supplied
            elif not param.is_optional:
                return None
            elif param.default is None:
                bindings[key] = None
            elif _is_literal(param.default):
                bindings[key] = _literal_value(param.default)
            else:
                return None
        return bindings

    def _remove_resolved_definitions(self, _root):
        """
        Remove every collected function from its module body when it was referenced at least
        once and each of those references was replaced by a literal.
        """
        for key, funcdef in self._functions.items():
            call_count = self._call_counts.get(key, 0)
            replaced_count = self._replaced_counts.get(key, 0)
            if call_count == 0 or replaced_count < call_count:
                continue
            parent = funcdef.parent
            if parent is None:
                continue
            if isinstance(parent, VbaModule):
                body = parent.body
            else:
                continue
            if funcdef in body:
                body.remove(funcdef)
                self.mark_changed()

Ancestors

Methods

def visit(self, node)
Expand source code Browse git
def visit(self, node):
    """
    Entry point of the transformer. A call made while the outermost call is still running is
    forwarded to the base class. The outermost call resets the bookkeeping tables, collects the
    function declarations and, if there are any, transforms the tree and removes the resolved
    declarations. The outermost call always returns None.
    """
    if self._entry:
        return super().visit(node)
    self._entry = True
    try:
        for table in (self._functions, self._call_counts, self._replaced_counts):
            table.clear()
        self._collect_functions(node)
        if self._functions:
            super().visit(node)
            self._remove_resolved_definitions(node)
    finally:
        # Re-arm the entry guard no matter how the pass ended.
        self._entry = False
    return None
def visit_VbaFunctionDeclaration(self, node)
Expand source code Browse git
def visit_VbaFunctionDeclaration(self, node: VbaFunctionDeclaration):
    """
    Visit the children of a function declaration while `_inside_function` holds the declaration's
    lower-cased name (None for an unnamed declaration). The previous value is restored afterwards,
    also when visiting raises, so that a failed pass cannot leave a stale name behind. Always
    returns None.
    """
    key = node.name.lower() if node.name else None
    old = self._inside_function
    self._inside_function = key
    try:
        self.generic_visit(node)
    finally:
        self._inside_function = old
    return None
def visit_VbaCallExpression(self, node)
Expand source code Browse git
def visit_VbaCallExpression(self, node: VbaCallExpression):
    """
    Visit the children first, then try to replace a call to a collected function by a literal.
    Every call to a collected function is counted; it is only replaced when all arguments are
    constants, the parameters can be bound, the evaluation yields a value and that value has a
    node form. Returns the replacement node, or None to leave the call untouched.
    """
    self.generic_visit(node)
    callee = node.callee
    if not isinstance(callee, VbaIdentifier):
        return None
    name = callee.name.lower()
    declaration = self._functions.get(name)
    if declaration is None:
        return None
    self._call_counts[name] = 1 + self._call_counts.get(name, 0)
    constants = self._extract_constant_args(node)
    bound = None if constants is None else self._bind_parameters(declaration, constants)
    if bound is None:
        return None
    value = self._try_evaluate(declaration, bound)
    literal = None if value is None else _value_to_node(value)
    if literal is None:
        return None
    self._replaced_counts[name] = 1 + self._replaced_counts.get(name, 0)
    return literal
def visit_VbaIdentifier(self, node)
Expand source code Browse git
def visit_VbaIdentifier(self, node: VbaIdentifier):
    """
    Try to replace a bare reference to a collected function by a literal. References to the
    function currently being visited, functions with a required parameter, assignment targets and
    callees of call expressions are skipped without being counted. Returns the replacement node,
    or None to leave the identifier untouched.
    """
    name = node.name.lower()
    if name == self._inside_function:
        return None
    declaration = self._functions.get(name)
    if declaration is None:
        return None
    if declaration.params and any(
        not p.is_optional and p.default is None for p in declaration.params
    ):
        return None
    owner = node.parent
    is_assignment_target = isinstance(owner, VbaLetStatement) and owner.target is node
    if is_assignment_target:
        return None
    is_callee = isinstance(owner, VbaCallExpression) and owner.callee is node
    if is_callee:
        return None
    self._call_counts[name] = 1 + self._call_counts.get(name, 0)
    bound = self._bind_parameters(declaration, [])
    if bound is None:
        return None
    value = self._try_evaluate(declaration, bound)
    if value is None:
        return None
    literal = _value_to_node(value)
    if literal is None:
        return None
    self._replaced_counts[name] = 1 + self._replaced_counts.get(name, 0)
    return literal