Module refinery.lib.scripts.js.deobfuscation.wrappers

Inline trivial function call wrappers.

A call wrapper is a small function whose only purpose is to forward a call to another function after rearranging or arithmetically transforming its arguments. This is a common obfuscation technique that adds a layer of indirection around every call site. The transformer detects these wrappers and substitutes each call site with the inlined return expression.

Expand source code Browse git
"""
Inline trivial function call wrappers.

A call wrapper is a small function whose only purpose is to forward a call to another function
after rearranging or arithmetically transforming its arguments. This is a common obfuscation
technique that adds a layer of indirection around every call site. The transformer detects these
wrappers and substitutes each call site with the inlined return expression.
"""
from __future__ import annotations

from refinery.lib.scripts import (
    Node,
    Transformer,
    _remove_from_parent,
    _replace_in_parent,
)
from refinery.lib.scripts.js.deobfuscation.helpers import (
    extract_identifier_params,
    is_closed_expression,
    is_simple_expression,
    substitute_params,
)
from refinery.lib.scripts.js.model import (
    JsCallExpression,
    JsFunctionDeclaration,
    JsIdentifier,
    JsReturnStatement,
    JsScript,
)

from typing import NamedTuple


class _WrapperInfo(NamedTuple):
    """
    Describes a detected call wrapper function.
    """
    node: JsFunctionDeclaration
    name: str
    param_names: list[str]
    return_expression: Node


def _detect_wrapper(node: JsFunctionDeclaration) -> _WrapperInfo | None:
    """
    Test whether a function declaration is a call wrapper and describe it if so.

    A call wrapper is a named function with one or more parameters, all of which must be
    accepted by `extract_identifier_params`. Its body must consist of a single `return`
    statement that returns a call whose callee is a plain identifier. Every argument of that
    call must pass `is_closed_expression` with respect to the wrapper's parameter names plus the
    callee's own name.

    A function that forwards to itself by name is rejected. Inlining it would only reproduce an
    identical call at every call site, so the rewrite would never make progress.

    Returns a `_WrapperInfo` describing the wrapper, or `None` if the declaration does not
    qualify.
    """
    if node.id is None or node.body is None:
        return None
    if not node.params:
        return None
    param_names = extract_identifier_params(node.params)
    if param_names is None:
        return None
    body = node.body.body
    if len(body) != 1:
        return None
    stmt = body[0]
    if not isinstance(stmt, JsReturnStatement) or stmt.argument is None:
        return None
    call = stmt.argument
    if not isinstance(call, JsCallExpression):
        return None
    callee = call.callee
    if not isinstance(callee, JsIdentifier):
        return None
    name = node.id.name
    # `function w(a) { return w(a); }` would be rewritten into the very same call forever.
    # When a parameter carries the same name it shadows the function, so the callee is then an
    # argument supplied by the caller and the wrapper remains acceptable.
    if callee.name == name and callee.name not in param_names:
        return None
    allowed_names = set(param_names)
    allowed_names.add(callee.name)
    if not all(is_closed_expression(arg, allowed_names) for arg in call.arguments):
        return None
    return _WrapperInfo(node, name, param_names, call)


def _collect_wrappers(root: Node) -> dict[str, _WrapperInfo]:
    """
    Walk the entire AST and collect all function declarations that qualify as call wrappers,
    keyed by function name.

    Call sites are later matched to wrappers by name alone. Two kinds of candidates are
    therefore discarded to keep that matching safe:

    - any name used by more than one function declaration in the tree. Otherwise a call could
      be inlined with the body of a different function that merely shares its name.
    - wrappers that forward, directly or through other wrappers, back to themselves. Inlining
      such a cycle only ever rewrites one wrapper call into another and never settles.
    """
    wrappers: dict[str, _WrapperInfo] = {}
    declarations: dict[str, int] = {}
    for node in root.walk():
        if not isinstance(node, JsFunctionDeclaration):
            continue
        if node.id is not None:
            declared_name = node.id.name
            declarations[declared_name] = declarations.get(declared_name, 0) + 1
        info = _detect_wrapper(node)
        if info is not None:
            wrappers[info.name] = info
    for declared_name, count in declarations.items():
        if count > 1:
            wrappers.pop(declared_name, None)

    def forward_target(info: _WrapperInfo) -> str | None:
        # Name of the collected wrapper that `info` forwards to. Returns None when the callee is
        # one of the wrapper's own parameters (the target then depends on the call site) or
        # when it is not a collected wrapper.
        target = info.return_expression.callee.name
        if target in info.param_names or target not in wrappers:
            return None
        return target

    cyclic: set[str] = set()
    for start, info in wrappers.items():
        seen = {start}
        target = forward_target(info)
        while target is not None and target not in seen:
            seen.add(target)
            target = forward_target(wrappers[target])
        # The chain stopped on a name it had already visited; it is a cycle through `start`
        # only if that name is `start` itself.
        if target == start:
            cyclic.add(start)
    for cyclic_name in cyclic:
        wrappers.pop(cyclic_name)
    return wrappers


class JsCallWrapperInliner(Transformer):
    """
    Detect trivial call wrapper functions and inline them at every call site.

    All work is done for the whole script at once in `visit_JsScript`. `generic_visit` is
    overridden as a no-op, so no other node is visited through the generic traversal.
    """

    def visit_JsScript(self, node: JsScript):
        """
        Inline calls to every wrapper found in the script, then delete wrapper declarations
        that are no longer needed.

        A call site is inlined when all of the following hold:

        - its callee names a collected wrapper;
        - it passes exactly as many arguments as the wrapper has parameters;
        - every argument passes `is_simple_expression`.

        If at least one call was inlined, every wrapper declaration that is no longer
        referenced by surviving code is removed, and the change is reported via
        `mark_changed`. Always returns None.
        """
        wrappers = _collect_wrappers(node)
        if not wrappers:
            return None
        inlined = False
        # Iterate over a snapshot, because call sites are replaced in the tree as we go.
        for ast_node in list(node.walk()):
            if not isinstance(ast_node, JsCallExpression):
                continue
            callee = ast_node.callee
            if not isinstance(callee, JsIdentifier):
                continue
            info = wrappers.get(callee.name)
            if info is None:
                continue
            arguments = ast_node.arguments
            if len(arguments) != len(info.param_names):
                continue
            if not all(is_simple_expression(a) for a in arguments):
                continue
            replacement = substitute_params(
                info.return_expression,
                info.param_names,
                arguments,
            )
            _replace_in_parent(ast_node, replacement)
            inlined = True
        if not inlined:
            return None
        # References inside a wrapper that is about to be deleted keep nothing alive. References
        # inside a wrapper that stays do keep their target alive; otherwise a kept wrapper could
        # be left calling a deleted one. So shrink the candidate set until no surviving code
        # mentions any remaining candidate. Each round either stops or drops at least one name.
        removable = set(wrappers)
        while removable:
            excluded: set[int] = set()
            for name in removable:
                for n in wrappers[name].node.walk():
                    excluded.add(id(n))
            still_referenced = {
                n.name
                for n in node.walk()
                if id(n) not in excluded
                and isinstance(n, JsIdentifier)
                and n.name in removable
            }
            if not still_referenced:
                break
            removable -= still_referenced
        for name, info in wrappers.items():
            if name in removable:
                _remove_from_parent(info.node)
        self.mark_changed()
        return None

    def generic_visit(self, node: Node):
        # Intentionally a no-op: see the class docstring.
        pass

Classes

class JsCallWrapperInliner

Detect trivial call wrapper functions and inline them at every call site.

Expand source code Browse git
class JsCallWrapperInliner(Transformer):
    """
    Detect trivial call wrapper functions and inline them at every call site.

    All work is done for the whole script at once in `visit_JsScript`. `generic_visit` is
    overridden as a no-op, so no other node is visited through the generic traversal.
    """

    def visit_JsScript(self, node: JsScript):
        """
        Inline calls to every wrapper found in the script, then delete wrapper declarations
        that are no longer needed.

        A call site is inlined when all of the following hold:

        - its callee names a collected wrapper;
        - it passes exactly as many arguments as the wrapper has parameters;
        - every argument passes `is_simple_expression`.

        If at least one call was inlined, every wrapper declaration that is no longer
        referenced by surviving code is removed, and the change is reported via
        `mark_changed`. Always returns None.
        """
        wrappers = _collect_wrappers(node)
        if not wrappers:
            return None
        inlined = False
        # Iterate over a snapshot, because call sites are replaced in the tree as we go.
        for ast_node in list(node.walk()):
            if not isinstance(ast_node, JsCallExpression):
                continue
            callee = ast_node.callee
            if not isinstance(callee, JsIdentifier):
                continue
            info = wrappers.get(callee.name)
            if info is None:
                continue
            arguments = ast_node.arguments
            if len(arguments) != len(info.param_names):
                continue
            if not all(is_simple_expression(a) for a in arguments):
                continue
            replacement = substitute_params(
                info.return_expression,
                info.param_names,
                arguments,
            )
            _replace_in_parent(ast_node, replacement)
            inlined = True
        if not inlined:
            return None
        # References inside a wrapper that is about to be deleted keep nothing alive. References
        # inside a wrapper that stays do keep their target alive; otherwise a kept wrapper could
        # be left calling a deleted one. So shrink the candidate set until no surviving code
        # mentions any remaining candidate. Each round either stops or drops at least one name.
        removable = set(wrappers)
        while removable:
            excluded: set[int] = set()
            for name in removable:
                for n in wrappers[name].node.walk():
                    excluded.add(id(n))
            still_referenced = {
                n.name
                for n in node.walk()
                if id(n) not in excluded
                and isinstance(n, JsIdentifier)
                and n.name in removable
            }
            if not still_referenced:
                break
            removable -= still_referenced
        for name, info in wrappers.items():
            if name in removable:
                _remove_from_parent(info.node)
        self.mark_changed()
        return None

    def generic_visit(self, node: Node):
        # Intentionally a no-op: see the class docstring.
        pass

Ancestors

refinery.lib.scripts.Transformer

Methods

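def visit_JsScript(self, node)

Inline calls to detected call wrappers throughout the script. If anything was inlined, also remove the wrapper declarations that are no longer referenced.
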
Expand source code Browse git
def visit_JsScript(self, node: JsScript):
    """
    Inline calls to every wrapper found in the script, then delete wrapper declarations that
    are no longer needed.

    A call site is inlined when all of the following hold:

    - its callee names a collected wrapper;
    - it passes exactly as many arguments as the wrapper has parameters;
    - every argument passes `is_simple_expression`.

    If at least one call was inlined, every wrapper declaration that is no longer referenced by
    surviving code is removed, and the change is reported via `mark_changed`. Always returns
    None.
    """
    wrappers = _collect_wrappers(node)
    if not wrappers:
        return None
    inlined = False
    # Iterate over a snapshot, because call sites are replaced in the tree as we go.
    for ast_node in list(node.walk()):
        if not isinstance(ast_node, JsCallExpression):
            continue
        callee = ast_node.callee
        if not isinstance(callee, JsIdentifier):
            continue
        info = wrappers.get(callee.name)
        if info is None:
            continue
        arguments = ast_node.arguments
        if len(arguments) != len(info.param_names):
            continue
        if not all(is_simple_expression(a) for a in arguments):
            continue
        replacement = substitute_params(
            info.return_expression,
            info.param_names,
            arguments,
        )
        _replace_in_parent(ast_node, replacement)
        inlined = True
    if not inlined:
        return None
    # References inside a wrapper that is about to be deleted keep nothing alive. References
    # inside a wrapper that stays do keep their target alive; otherwise a kept wrapper could be
    # left calling a deleted one. So shrink the candidate set until no surviving code mentions
    # any remaining candidate. Each round either stops or drops at least one name.
    removable = set(wrappers)
    while removable:
        excluded: set[int] = set()
        for name in removable:
            for n in wrappers[name].node.walk():
                excluded.add(id(n))
        still_referenced = {
            n.name
            for n in node.walk()
            if id(n) not in excluded
            and isinstance(n, JsIdentifier)
            and n.name in removable
        }
        if not still_referenced:
            break
        removable -= still_referenced
    for name, info in wrappers.items():
        if name in removable:
            _remove_from_parent(info.node)
    self.mark_changed()
    return None
def generic_visit(self, node)

No-op override: no node is visited through the generic traversal.

Expand source code Browse git
def generic_visit(self, node: Node):
    """
    No-op override of the generic traversal hook; no node is visited through it.
    """
    return None