Module refinery.lib.scripts.vba.deobfuscation.deadcode

VBA dead code removal: removes assignments to unread variables and empty uncalled procedures.

Expand source code Browse git
"""
VBA dead code removal: removes assignments to unread variables and empty uncalled procedures.
"""
from __future__ import annotations

from refinery.lib.scripts import Statement, Transformer
from refinery.lib.scripts.vba.deobfuscation.helpers import apply_removals, body_lists
from refinery.lib.scripts.vba.deobfuscation.names import SINGLE_ARG_BUILTINS, STRING_BUILTINS
from refinery.lib.scripts.vba.model import (
    VbaCallExpression,
    VbaFunctionDeclaration,
    VbaIdentifier,
    VbaLetStatement,
    VbaModule,
    VbaProcedureDeclaration,
    VbaPropertyDeclaration,
)

# Lowercased names of VBA builtins that are accepted as pure in addition to the name sets imported
# from the names module. NOTE(review): the constant's name suggests that the deobfuscator does not
# evaluate these calls itself - confirm against the evaluator.
_PURE_BUT_UNEVALUABLE = frozenset((
    # math functions
    'atn', 'cos', 'exp', 'log', 'sin', 'sqr', 'tan',
    # type conversion functions
    'ccur', 'cdate', 'cdec', 'clnglng', 'clngptr', 'csng', 'cvar',
    # conversions between strings and numbers
    'str', 'str$', 'val',
))

# Every callee name (lowercased) for which `_has_side_effects` treats a call as side-effect-free.
_PURE_BUILTINS = frozenset(SINGLE_ARG_BUILTINS) | STRING_BUILTINS | _PURE_BUT_UNEVALUABLE


def _has_side_effects(node) -> bool:
    """
    Decide whether evaluating the expression tree rooted at `node` might have side effects.

    Only call expressions yielded by `node.walk()` are inspected. A call is considered harmless
    when its callee is a plain identifier whose lowercased name is listed in `_PURE_BUILTINS`; any
    other call makes the whole expression count as having side effects. A tree that contains no
    call expressions is reported as side-effect-free.
    """
    def is_impure(call) -> bool:
        # Calls through anything but a bare identifier cannot be checked against the allow-list.
        callee = call.callee
        return not (isinstance(callee, VbaIdentifier) and callee.name.lower() in _PURE_BUILTINS)

    return any(
        is_impure(descendant)
        for descendant in node.walk()
        if isinstance(descendant, VbaCallExpression)
    )


def _enclosing_return_name(node) -> str | None:
    """
    Determine the name that acts as the return variable for the procedure containing `node`.

    The parent chain of `node` is followed up to the first procedure declaration. When that
    innermost procedure is a Function or Property with a non-empty name, the lowercased name is
    returned, because such procedures produce their return value through an assignment to their
    own name. In every other case (any other kind of procedure such as a Sub, a missing name, or
    no enclosing procedure at all) the result is `None`.
    """
    ancestor = node.parent
    # Climb to the innermost procedure declaration; outer procedures are never considered.
    while ancestor is not None and not isinstance(ancestor, VbaProcedureDeclaration):
        ancestor = ancestor.parent
    if ancestor is None:
        return None
    if not isinstance(ancestor, (VbaFunctionDeclaration, VbaPropertyDeclaration)):
        return None
    name = ancestor.name
    return name.lower() if name else None


class VbaDeadVariableRemoval(Transformer):
    """
    Transformer that deletes side-effect-free assignments to variables which are never read
    anywhere in the module.
    """

    def visit(self, node):
        """
        Run the removal on `VbaModule` nodes and mark the transformer as changed when the removal
        reports success. Always returns `None`, so the visited node itself is never replaced.
        """
        if isinstance(node, VbaModule) and self._remove_dead_variables(node):
            self.mark_changed()
        return None

    def _remove_dead_variables(self, module: VbaModule) -> bool:
        """
        Collect removable assignments in `module`, drop those whose target name is never read, and
        return the result of `apply_removals` for the collected positions.
        """
        # Maps a lowercased variable name to the (index, body) locations of its assignments.
        candidates: dict[str, list[tuple[int, list[Statement]]]] = {}
        for block in body_lists(module):
            # Statements directly in the module's top-level body are never candidates.
            if block is module.body:
                continue
            for position, statement in enumerate(block):
                if not isinstance(statement, VbaLetStatement):
                    continue
                target = statement.target
                if not isinstance(target, VbaIdentifier):
                    continue
                if statement.value is None:
                    continue
                # Dropping the assignment would also drop the evaluation of its right-hand side.
                if _has_side_effects(statement.value):
                    continue
                name = target.name.lower()
                # Assigning to the enclosing Function/Property name sets its return value.
                if name == _enclosing_return_name(statement):
                    continue
                candidates.setdefault(name, []).append((position, block))
        # Every identifier occurrence other than the direct target of a Let statement is a read.
        used = {
            ident.name.lower()
            for ident in module.walk()
            if isinstance(ident, VbaIdentifier)
            and not (isinstance(ident.parent, VbaLetStatement) and ident.parent.target is ident)
        }
        removals: list[tuple[int, list[Statement]]] = [
            location
            for name, locations in candidates.items()
            if name not in used
            for location in locations
        ]
        return apply_removals(removals)


class VbaEmptyProcedureRemoval(Transformer):
    """
    Transformer that deletes top-level procedure declarations which have an empty body and whose
    name is not mentioned by any identifier in the module.
    """

    def visit(self, node):
        """
        Run the removal on `VbaModule` nodes and mark the transformer as changed when the removal
        reports success. Always returns `None`, so the visited node itself is never replaced.
        """
        if isinstance(node, VbaModule) and self._remove_empty_procedures(node):
            self.mark_changed()
        return None

    def _remove_empty_procedures(self, module: VbaModule) -> bool:
        """
        Find empty procedures in the top-level body of `module` and remove the unreferenced ones.
        Returns `False` when there is no empty procedure at all, otherwise the result of
        `apply_removals` for the collected positions.
        """
        # Maps a lowercased procedure name to the module body indices of its empty declarations.
        positions_by_name: dict[str, list[int]] = {}
        for position, declaration in enumerate(module.body):
            if not isinstance(declaration, VbaProcedureDeclaration):
                continue
            if declaration.body:
                continue
            positions_by_name.setdefault(declaration.name.lower(), []).append(position)
        if not positions_by_name:
            return False
        # Any identifier with a matching name, compared case-insensitively, counts as a reference.
        mentioned = {
            ident.name.lower()
            for ident in module.walk()
            if isinstance(ident, VbaIdentifier)
        }
        removals: list[tuple[int, list[Statement]]] = [
            (position, module.body)
            for name, positions in positions_by_name.items()
            if name not in mentioned
            for position in positions
        ]
        return apply_removals(removals)

Classes

class VbaDeadVariableRemoval

In-place tree rewriter. Each visit method may return a replacement node or None to keep the original. Tracks whether any transformation was applied via the changed flag.

Expand source code Browse git
class VbaDeadVariableRemoval(Transformer):
    """
    Transformer that deletes side-effect-free assignments to variables which are never read
    anywhere in the module.
    """

    def visit(self, node):
        """
        Run the removal on `VbaModule` nodes and mark the transformer as changed when the removal
        reports success. Always returns `None`, so the visited node itself is never replaced.
        """
        if isinstance(node, VbaModule) and self._remove_dead_variables(node):
            self.mark_changed()
        return None

    def _remove_dead_variables(self, module: VbaModule) -> bool:
        """
        Collect removable assignments in `module`, drop those whose target name is never read, and
        return the result of `apply_removals` for the collected positions.
        """
        # Maps a lowercased variable name to the (index, body) locations of its assignments.
        candidates: dict[str, list[tuple[int, list[Statement]]]] = {}
        for block in body_lists(module):
            # Statements directly in the module's top-level body are never candidates.
            if block is module.body:
                continue
            for position, statement in enumerate(block):
                if not isinstance(statement, VbaLetStatement):
                    continue
                target = statement.target
                if not isinstance(target, VbaIdentifier):
                    continue
                if statement.value is None:
                    continue
                # Dropping the assignment would also drop the evaluation of its right-hand side.
                if _has_side_effects(statement.value):
                    continue
                name = target.name.lower()
                # Assigning to the enclosing Function/Property name sets its return value.
                if name == _enclosing_return_name(statement):
                    continue
                candidates.setdefault(name, []).append((position, block))
        # Every identifier occurrence other than the direct target of a Let statement is a read.
        used = {
            ident.name.lower()
            for ident in module.walk()
            if isinstance(ident, VbaIdentifier)
            and not (isinstance(ident.parent, VbaLetStatement) and ident.parent.target is ident)
        }
        removals: list[tuple[int, list[Statement]]] = [
            location
            for name, locations in candidates.items()
            if name not in used
            for location in locations
        ]
        return apply_removals(removals)

Ancestors

Methods

def visit(self, node)
Expand source code Browse git
def visit(self, node):
    """
    Run dead-variable removal on `VbaModule` nodes and mark the transformer as changed when the
    removal reports success. Always returns `None`, so the visited node is never replaced.
    """
    if isinstance(node, VbaModule) and self._remove_dead_variables(node):
        self.mark_changed()
    return None

Inherited members

class VbaEmptyProcedureRemoval

In-place tree rewriter. Each visit method may return a replacement node or None to keep the original. Tracks whether any transformation was applied via the changed flag.

Expand source code Browse git
class VbaEmptyProcedureRemoval(Transformer):
    """
    Transformer that deletes top-level procedure declarations which have an empty body and whose
    name is not mentioned by any identifier in the module.
    """

    def visit(self, node):
        """
        Run the removal on `VbaModule` nodes and mark the transformer as changed when the removal
        reports success. Always returns `None`, so the visited node itself is never replaced.
        """
        if isinstance(node, VbaModule) and self._remove_empty_procedures(node):
            self.mark_changed()
        return None

    def _remove_empty_procedures(self, module: VbaModule) -> bool:
        """
        Find empty procedures in the top-level body of `module` and remove the unreferenced ones.
        Returns `False` when there is no empty procedure at all, otherwise the result of
        `apply_removals` for the collected positions.
        """
        # Maps a lowercased procedure name to the module body indices of its empty declarations.
        positions_by_name: dict[str, list[int]] = {}
        for position, declaration in enumerate(module.body):
            if not isinstance(declaration, VbaProcedureDeclaration):
                continue
            if declaration.body:
                continue
            positions_by_name.setdefault(declaration.name.lower(), []).append(position)
        if not positions_by_name:
            return False
        # Any identifier with a matching name, compared case-insensitively, counts as a reference.
        mentioned = {
            ident.name.lower()
            for ident in module.walk()
            if isinstance(ident, VbaIdentifier)
        }
        removals: list[tuple[int, list[Statement]]] = [
            (position, module.body)
            for name, positions in positions_by_name.items()
            if name not in mentioned
            for position in positions
        ]
        return apply_removals(removals)

Ancestors

Methods

def visit(self, node)
Expand source code Browse git
def visit(self, node):
    """
    Run empty-procedure removal on `VbaModule` nodes and mark the transformer as changed when the
    removal reports success. Always returns `None`, so the visited node is never replaced.
    """
    if isinstance(node, VbaModule) and self._remove_empty_procedures(node):
        self.mark_changed()
    return None

Inherited members