Module refinery.lib.scripts.vba.deobfuscation.deadcode

VBA dead variable removal: removes assignments to variables that are never read, provided the right-hand side has no side effects.

Expand source code Browse git
"""
VBA dead variable removal: removes assignments to variables that are never read, provided the
right-hand side has no side effects.
"""
from __future__ import annotations

from refinery.lib.scripts import Statement, Transformer
from refinery.lib.scripts.vba.deobfuscation._helpers import _body_lists
from refinery.lib.scripts.vba.model import (
    VbaCallExpression,
    VbaIdentifier,
    VbaLetStatement,
    VbaModule,
)

# Lower-case names of VBA builtin functions that are treated as side-effect-free
# when they appear as the callee of a call expression. Callee names are
# lower-cased before the membership test, so every entry here must be lower-case.
_PURE_BUILTINS = frozenset((
    # type conversion
    'cbool', 'cbyte', 'ccur', 'cdate', 'cdbl', 'cdec', 'cint',
    'clng', 'clnglng', 'clngptr', 'csng', 'cstr', 'cvar',
    # character codes and number/string formatting
    'chr', 'chr$', 'chrw', 'chrw$',
    'hex', 'hex$', 'oct', 'oct$',
    'str', 'str$', 'val',
    # numeric functions
    'abs', 'atn', 'cos', 'exp', 'fix', 'int',
    'log', 'sgn', 'sin', 'sqr', 'tan',
))


def _has_side_effects(node) -> bool:
    """
    Decide whether evaluating an expression tree might have side effects. Only call expressions
    found while walking the tree are inspected: a call is considered harmless when its callee is
    a plain identifier whose lower-cased name is one of the known pure VBA builtins. Any other
    call makes the whole expression count as having side effects.
    """
    return any(
        not isinstance(call.callee, VbaIdentifier)
        or call.callee.name.lower() not in _PURE_BUILTINS
        for call in node.walk()
        if isinstance(call, VbaCallExpression)
    )


class VbaDeadVariableRemoval(Transformer):
    """
    Transformer that deletes assignments to variables which are never read anywhere in the module,
    provided the assigned expression has no side effects.
    """

    def visit(self, node):
        """
        Run dead variable removal when the visited node is a `VbaModule` and mark the transformer
        as changed if anything was removed. Other nodes are ignored. Always returns None.
        """
        if isinstance(node, VbaModule) and self._remove_dead_variables(node):
            self.mark_changed()
        return None

    def _remove_dead_variables(self, module: VbaModule) -> bool:
        """
        Delete every side-effect-free assignment whose target name is never read in the module and
        return whether at least one statement was deleted. Names are compared case-insensitively.
        Statements that sit directly in the module's own top-level body are left untouched.
        """
        # Pass 1: collect removable assignments, remembering the list and index they live at.
        candidates: dict[str, list[tuple[list[Statement], int]]] = {}
        for block in _body_lists(module):
            if block is module.body:
                continue
            for position, statement in enumerate(block):
                if not isinstance(statement, VbaLetStatement):
                    continue
                if not isinstance(statement.target, VbaIdentifier) or statement.value is None:
                    continue
                if _has_side_effects(statement.value):
                    continue
                name = statement.target.name.lower()
                candidates.setdefault(name, []).append((block, position))

        # Pass 2: every identifier that is not the direct target of an assignment is a read.
        used: set[str] = set()
        for ident in module.walk():
            if isinstance(ident, VbaIdentifier):
                owner = ident.parent
                is_write = isinstance(owner, VbaLetStatement) and owner.target is ident
                if not is_write:
                    used.add(ident.name.lower())

        # Pass 3: drop the assignments of all names that were never read.
        doomed = [
            slot
            for name, slots in candidates.items()
            if name not in used
            for slot in slots
        ]
        # Highest index first, so earlier deletions never shift a pending index in the same list.
        doomed.sort(key=lambda slot: slot[1], reverse=True)
        for block, position in doomed:
            del block[position]
        return bool(doomed)

Classes

class VbaDeadVariableRemoval

In-place tree rewriter. Each visit method may return a replacement node or None to keep the original. Tracks whether any transformation was applied via the changed flag.

Expand source code Browse git
class VbaDeadVariableRemoval(Transformer):
    """
    Transformer that deletes assignments to variables which are never read anywhere in the module,
    provided the assigned expression has no side effects.
    """

    def visit(self, node):
        """
        Run dead variable removal when the visited node is a `VbaModule` and mark the transformer
        as changed if anything was removed. Other nodes are ignored. Always returns None.
        """
        if isinstance(node, VbaModule) and self._remove_dead_variables(node):
            self.mark_changed()
        return None

    def _remove_dead_variables(self, module: VbaModule) -> bool:
        """
        Delete every side-effect-free assignment whose target name is never read in the module and
        return whether at least one statement was deleted. Names are compared case-insensitively.
        Statements that sit directly in the module's own top-level body are left untouched.
        """
        # Pass 1: collect removable assignments, remembering the list and index they live at.
        candidates: dict[str, list[tuple[list[Statement], int]]] = {}
        for block in _body_lists(module):
            if block is module.body:
                continue
            for position, statement in enumerate(block):
                if not isinstance(statement, VbaLetStatement):
                    continue
                if not isinstance(statement.target, VbaIdentifier) or statement.value is None:
                    continue
                if _has_side_effects(statement.value):
                    continue
                name = statement.target.name.lower()
                candidates.setdefault(name, []).append((block, position))

        # Pass 2: every identifier that is not the direct target of an assignment is a read.
        used: set[str] = set()
        for ident in module.walk():
            if isinstance(ident, VbaIdentifier):
                owner = ident.parent
                is_write = isinstance(owner, VbaLetStatement) and owner.target is ident
                if not is_write:
                    used.add(ident.name.lower())

        # Pass 3: drop the assignments of all names that were never read.
        doomed = [
            slot
            for name, slots in candidates.items()
            if name not in used
            for slot in slots
        ]
        # Highest index first, so earlier deletions never shift a pending index in the same list.
        doomed.sort(key=lambda slot: slot[1], reverse=True)
        for block, position in doomed:
            del block[position]
        return bool(doomed)

Ancestors

Methods

def visit(self, node)
Expand source code Browse git
def visit(self, node):
    """
    Run dead variable removal when the visited node is a `VbaModule` and mark the transformer as
    changed if anything was removed. Other nodes are ignored. Always returns None.
    """
    if isinstance(node, VbaModule) and self._remove_dead_variables(node):
        self.mark_changed()
    return None