Module refinery.lib.scripts.js.deobfuscation.restunpack

Unpack rest-parameter arrays that pack multiple variables into a single parameter.

Some obfuscation transforms replace named function parameters and locals with indexed accesses on a single rest parameter array:

function(...stack) { stack.length = N; ... }

This transformer detects the pattern, builds a variable map from collected access keys, and replaces indexed accesses with fresh named identifiers.

Expand source code Browse git
"""
Unpack rest-parameter arrays that pack multiple variables into a single parameter.

Some obfuscation transforms replace named function parameters and locals with indexed accesses
on a single rest parameter array:

    function(...stack) { stack.length = N; ... }

This transformer detects the pattern, builds a variable map from collected access keys, and
replaces indexed accesses with fresh named identifiers.
"""
from __future__ import annotations

from typing import NamedTuple

from refinery.lib.scripts import Node, _replace_in_parent
from refinery.lib.scripts.js.deobfuscation.helpers import (
    ScriptLevelTransformer,
    member_key,
    numeric_value,
)
from refinery.lib.scripts.js.model import (
    JsAssignmentExpression,
    JsBlockStatement,
    JsExpressionStatement,
    JsFunctionDeclaration,
    JsFunctionExpression,
    JsIdentifier,
    JsMemberExpression,
    JsNumericLiteral,
    JsRestElement,
    JsScript,
    JsStringLiteral,
    JsUnaryExpression,
    JsVariableDeclaration,
    JsVariableDeclarator,
    JsVarKind,
)


class _TruncationInfo(NamedTuple):
    param_count: int
    stack_chain: str | None


class _NestedFrameAccess(Exception):
    pass


def _extract_truncation(
    stmts: list,
    rest_name: str,
) -> _TruncationInfo | None:
    """
    Find the `.length = N` truncation statement in the function body. Returns the param count
    and the stack chain key (None for simple case where rest param IS the stack). Returns None
    if no truncation pattern is found.
    """
    for stmt in stmts:
        if not isinstance(stmt, JsExpressionStatement):
            continue
        expr = stmt.expression
        if not isinstance(expr, JsAssignmentExpression) or expr.operator != '=':
            continue
        lhs = expr.left
        if not isinstance(lhs, JsMemberExpression):
            continue
        if lhs.computed:
            continue
        if not isinstance(lhs.property, JsIdentifier) or lhs.property.name != 'length':
            continue
        rhs = expr.right
        if rhs is None:
            continue
        n = numeric_value(rhs)
        if n is None or not isinstance(n, int) or n < 0:
            continue
        obj = lhs.object
        if isinstance(obj, JsIdentifier) and obj.name == rest_name:
            return _TruncationInfo(n, None)
        if isinstance(obj, JsMemberExpression):
            chain = member_key(obj)
            if chain is not None:
                return _TruncationInfo(n, chain)
    return None


def _collect_accesses_simple(
    body: JsBlockStatement,
    rest_name: str,
) -> dict[str, list[JsMemberExpression]] | None:
    """
    Gather every `restParam[key]` / `restParam.key` access below `body`, grouped by the string
    form of the key. The traversal is delegated to `_walk_collect_simple`; when that walk reports
    a use of the rest parameter that blocks unpacking, None is returned instead of the map.
    """
    found: dict[str, list[JsMemberExpression]] = {}
    complete = _walk_collect_simple(body, rest_name, found)
    return found if complete else None


def _rebinds_name(fn: JsFunctionExpression | JsFunctionDeclaration, name: str) -> bool:
    """
    Report whether `fn` declares a plain or rest parameter called `name`, which shadows any outer
    binding of that name inside the function.
    """
    for param in fn.params:
        target = param.argument if isinstance(param, JsRestElement) else param
        if isinstance(target, JsIdentifier) and target.name == name:
            return True
    return False


def _mentions_name(node: Node, name: str) -> bool:
    """
    Report whether any identifier node called `name` occurs anywhere below `node`.
    """
    pending = [node]
    while pending:
        for child in pending.pop().children():
            if isinstance(child, JsIdentifier) and child.name == name:
                return True
            pending.append(child)
    return False


def _walk_collect_simple(
    node: Node,
    rest_name: str,
    accesses: dict[str, list[JsMemberExpression]],
) -> bool:
    """
    Recursively record every statically keyed access on the identifier `rest_name` below `node`
    into `accesses` (key string -> list of member expression nodes).

    Returns False as soon as the rest parameter is used in a way that cannot be rewritten: an
    access whose key cannot be resolved, any other occurrence of an identifier with that name, or
    a nested function that refers to the name without declaring a parameter of its own with that
    name. Returns True when the whole subtree was collected.
    """
    for child in node.children():
        if isinstance(child, (JsFunctionExpression, JsFunctionDeclaration)):
            # Nested function bodies are never rewritten here. If one of them still refers to the
            # rest parameter of the enclosing function, removing that parameter would leave the
            # reference dangling, so unpacking has to be abandoned. A nested function with its
            # own parameter of the same name shadows the outer one and is safe to skip.
            if not _rebinds_name(child, rest_name) and _mentions_name(child, rest_name):
                return False
            continue
        if isinstance(child, JsMemberExpression):
            obj = child.object
            if isinstance(obj, JsIdentifier) and obj.name == rest_name:
                if (
                    not child.computed
                    and isinstance(child.property, JsIdentifier)
                    and child.property.name == 'length'
                ):
                    # `rest.length` is tolerated but is not a variable slot.
                    continue
                key = _extract_access_key(child)
                if key is None:
                    return False
                accesses.setdefault(key, []).append(child)
                continue
        if isinstance(child, JsIdentifier) and child.name == rest_name:
            parent = child.parent
            if isinstance(parent, JsMemberExpression) and parent.object is child:
                continue
            # Any other occurrence of the name prevents unpacking.
            return False
        if not _walk_collect_simple(child, rest_name, accesses):
            return False
    return True


def _collect_accesses_frame(
    body: JsBlockStatement,
    stack_chain: str,
) -> dict[str, list[JsMemberExpression]] | None:
    """
    Gather the keyed accesses on the member chain identified by `stack_chain` below `body`,
    grouped by key. The walk signals an access located inside a nested function by raising
    `_NestedFrameAccess`; in that case None is returned.
    """
    found: dict[str, list[JsMemberExpression]] = {}
    try:
        _walk_collect_frame(body, stack_chain, found, depth=0)
    except _NestedFrameAccess:
        return None
    else:
        return found


def _walk_collect_frame(
    node: Node,
    stack_chain: str,
    accesses: dict[str, list[JsMemberExpression]],
    depth: int,
) -> None:
    """
    Recursively record accesses of the form `<chain>[key]` / `<chain>.key` below `node`, where
    the member key of `<chain>` equals `stack_chain`. `depth` counts how many nested functions
    have been entered; a resolvable access at depth greater than zero raises
    `_NestedFrameAccess`. Member expressions that do not match, or whose key cannot be resolved,
    are simply descended into.
    """
    for child in node.children():
        if isinstance(child, (JsFunctionExpression, JsFunctionDeclaration)):
            _walk_collect_frame(child, stack_chain, accesses, depth + 1)
            continue
        key = None
        if isinstance(child, JsMemberExpression):
            owner = child.object
            if isinstance(owner, JsMemberExpression) and member_key(owner) == stack_chain:
                key = _extract_access_key(child)
        if key is None:
            _walk_collect_frame(child, stack_chain, accesses, depth)
            continue
        if depth > 0:
            raise _NestedFrameAccess
        accesses.setdefault(key, []).append(child)


def _extract_access_key(node: JsMemberExpression) -> str | None:
    """
    Extract the key from a stack access expression. Returns a string representation of the key
    or None if the key cannot be statically resolved.

    Computed accesses are resolved for integral numeric literals, string literals, and a unary
    minus applied to an integral numeric literal. Non-integral or non-finite numbers yield None
    in both the positive and the negated form. Non-computed accesses yield the property name,
    except for `length`, which is never treated as a slot.
    """
    def integral(value) -> int | None:
        # int() raises for infinities and NaN; those cannot name a slot either.
        try:
            whole = int(value)
        except (OverflowError, ValueError):
            return None
        return whole if whole == value else None

    prop = node.property
    if node.computed:
        if isinstance(prop, JsNumericLiteral):
            index = integral(prop.value)
            return None if index is None else str(index)
        if isinstance(prop, JsStringLiteral):
            return prop.value
        if (
            isinstance(prop, JsUnaryExpression)
            and prop.operator == '-'
            and isinstance(prop.operand, JsNumericLiteral)
        ):
            # Apply the same integrality test as for positive literals; truncating here would
            # merge e.g. `s[-1.5]` with `s[-1]`.
            index = integral(prop.operand.value)
            return None if index is None else str(-index)
        return None
    if isinstance(prop, JsIdentifier):
        if prop.name == 'length':
            return None
        return prop.name
    return None


def _generate_names(param_count: int, keys: set[str]) -> dict[str, str]:
    """
    Generate identifier names for stack keys. A key that is the decimal string of an index in
    0..param_count-1 is named after that index (`p<index>`); every other key receives the next
    free local name (`v0`, `v1`, ...) in the order defined by `_sort_key`.

    Naming parameters by their slot index rather than by encounter order keeps the names unique
    even when some slot is never accessed: a caller that fills such a gap with `p<slot>` cannot
    collide with a name handed out here.
    """
    param_keys = {str(i) for i in range(param_count)}
    mapping: dict[str, str] = {}
    local_idx = 0
    for key in sorted(keys, key=_sort_key):
        if key in param_keys:
            # `key` is already the canonical decimal form of its slot index.
            mapping[key] = F'p{key}'
        else:
            mapping[key] = F'v{local_idx}'
            local_idx += 1
    return mapping


def _sort_key(key: str) -> tuple[int, int | str]:
    try:
        n = int(key)
        return (0, n)
    except ValueError:
        return (1, key)


def _remove_truncation(body: JsBlockStatement, rest_name: str, stack_chain: str | None) -> None:
    """
    Delete the first top-level `X.length = ...` expression statement from `body`. With
    `stack_chain` set to None, X must be the identifier `rest_name`; otherwise X must be a member
    expression whose member key equals `stack_chain`. Nothing happens if no statement matches.
    """
    def is_truncation(stmt) -> bool:
        if not isinstance(stmt, JsExpressionStatement):
            return False
        assignment = stmt.expression
        if not isinstance(assignment, JsAssignmentExpression) or assignment.operator != '=':
            return False
        target = assignment.left
        if not isinstance(target, JsMemberExpression) or target.computed:
            return False
        prop = target.property
        if not isinstance(prop, JsIdentifier) or prop.name != 'length':
            return False
        owner = target.object
        if stack_chain is None:
            return isinstance(owner, JsIdentifier) and owner.name == rest_name
        return isinstance(owner, JsMemberExpression) and member_key(owner) == stack_chain

    statements = body.body
    for position, stmt in enumerate(statements):
        if is_truncation(stmt):
            # Only the first match is removed; returning right away keeps the iteration safe.
            statements.pop(position)
            return


class JsRestArrayUnpacking(ScriptLevelTransformer):
    """
    Unpack rest-param arrays back into named identifiers. Detects functions where all parameters
    and locals are packed into a single rest parameter accessed by index, and replaces indexed
    accesses with fresh named variables.
    """

    def _process_script(self, node: JsScript) -> None:
        """
        Attempt to unpack every function found in the script and mark the transformer as changed
        when at least one function was rewritten.
        """
        # Collect the candidates before touching the tree: unpacking replaces nodes and inserts
        # statements, which should not happen while the walk is still being consumed.
        functions = [
            candidate for candidate in node.walk()
            if isinstance(candidate, (JsFunctionExpression, JsFunctionDeclaration))
        ]
        changed = False
        for fn_node in functions:
            if self._demask_function(fn_node):
                changed = True
        if changed:
            self.mark_changed()

    def _demask_function(self, fn: JsFunctionExpression | JsFunctionDeclaration) -> bool:
        """
        Rewrite one function if it matches the packed pattern: exactly one rest parameter bound
        to an identifier, a block body containing a `.length = N` truncation, and only statically
        keyed accesses to the packed array. Returns True when the function was rewritten and
        False when it was left untouched.
        """
        if len(fn.params) != 1:
            return False
        param = fn.params[0]
        if not isinstance(param, JsRestElement) or not isinstance(param.argument, JsIdentifier):
            return False
        rest_name = param.argument.name
        body = fn.body
        if not isinstance(body, JsBlockStatement) or not body.body:
            return False
        result = _extract_truncation(body.body, rest_name)
        if result is None:
            return False
        param_count, stack_chain = result
        if stack_chain is None:
            accesses = _collect_accesses_simple(body, rest_name)
        else:
            accesses = _collect_accesses_frame(body, stack_chain)
        if accesses is None:
            return False
        # A function that claims parameters but never reads any of them is not unpacked.
        if param_count > 0 and not any(str(i) in accesses for i in range(param_count)):
            return False
        # With no accesses at all (only possible for param_count == 0) the mapping is empty and
        # the steps below reduce to dropping the truncation and the rest parameter.
        mapping = _generate_names(param_count, set(accesses.keys())) if accesses else {}
        for key, nodes in accesses.items():
            name = mapping[key]
            for access_node in nodes:
                _replace_in_parent(access_node, JsIdentifier(name=name))
        _remove_truncation(body, rest_name, stack_chain)
        fn.params.clear()
        taken = set(mapping.values())
        for i in range(param_count):
            name = mapping.get(str(i))
            if name is None:
                # Slot never accessed: it still needs a placeholder to keep later parameters in
                # position, and that placeholder must not repeat a name already handed out.
                name = F'p{i}'
                while name in taken:
                    name = F'_{name}'
                taken.add(name)
            identifier = JsIdentifier(name=name)
            # Link the new node to its owner, mirroring what _add_local_declarations does for
            # the nodes it creates.
            identifier.parent = fn
            fn.params.append(identifier)
        if stack_chain is None:
            self._add_local_declarations(body, mapping, param_count)
        return True

    def _add_local_declarations(
        self,
        body: JsBlockStatement,
        mapping: dict[str, str],
        param_count: int,
    ) -> None:
        """
        Insert a single `var` declaration at the top of `body` for every mapped name whose key
        is not a parameter slot. Does nothing when there are no such names.
        """
        # Parameter slots are exactly the canonical decimal strings 0..param_count-1. Comparing
        # strings (instead of parsing the key) keeps keys such as "01" classified as locals.
        param_keys = {str(i) for i in range(param_count)}
        locals_ = [name for key, name in mapping.items() if key not in param_keys]
        if not locals_:
            return
        declarators = [
            JsVariableDeclarator(id=JsIdentifier(name=n), init=None)
            for n in locals_
        ]
        decl = JsVariableDeclaration(declarations=declarators, kind=JsVarKind.VAR)
        decl.parent = body
        for d in declarators:
            d.parent = decl
            if d.id is not None:
                d.id.parent = d
        body.body.insert(0, decl)

Classes

class JsRestArrayUnpacking

Unpack rest-param arrays back into named identifiers. Detects functions where all parameters and locals are packed into a single rest parameter accessed by index, and replaces indexed accesses with fresh named variables.

Expand source code Browse git
class JsRestArrayUnpacking(ScriptLevelTransformer):
    """
    Unpack rest-param arrays back into named identifiers. Detects functions where all parameters
    and locals are packed into a single rest parameter accessed by index, and replaces indexed
    accesses with fresh named variables.
    """

    def _process_script(self, node: JsScript) -> None:
        """
        Attempt to unpack every function found in the script and mark the transformer as changed
        when at least one function was rewritten.
        """
        # Collect the candidates before touching the tree: unpacking replaces nodes and inserts
        # statements, which should not happen while the walk is still being consumed.
        functions = [
            candidate for candidate in node.walk()
            if isinstance(candidate, (JsFunctionExpression, JsFunctionDeclaration))
        ]
        changed = False
        for fn_node in functions:
            if self._demask_function(fn_node):
                changed = True
        if changed:
            self.mark_changed()

    def _demask_function(self, fn: JsFunctionExpression | JsFunctionDeclaration) -> bool:
        """
        Rewrite one function if it matches the packed pattern: exactly one rest parameter bound
        to an identifier, a block body containing a `.length = N` truncation, and only statically
        keyed accesses to the packed array. Returns True when the function was rewritten and
        False when it was left untouched.
        """
        if len(fn.params) != 1:
            return False
        param = fn.params[0]
        if not isinstance(param, JsRestElement) or not isinstance(param.argument, JsIdentifier):
            return False
        rest_name = param.argument.name
        body = fn.body
        if not isinstance(body, JsBlockStatement) or not body.body:
            return False
        result = _extract_truncation(body.body, rest_name)
        if result is None:
            return False
        param_count, stack_chain = result
        if stack_chain is None:
            accesses = _collect_accesses_simple(body, rest_name)
        else:
            accesses = _collect_accesses_frame(body, stack_chain)
        if accesses is None:
            return False
        # A function that claims parameters but never reads any of them is not unpacked.
        if param_count > 0 and not any(str(i) in accesses for i in range(param_count)):
            return False
        # With no accesses at all (only possible for param_count == 0) the mapping is empty and
        # the steps below reduce to dropping the truncation and the rest parameter.
        mapping = _generate_names(param_count, set(accesses.keys())) if accesses else {}
        for key, nodes in accesses.items():
            name = mapping[key]
            for access_node in nodes:
                _replace_in_parent(access_node, JsIdentifier(name=name))
        _remove_truncation(body, rest_name, stack_chain)
        fn.params.clear()
        taken = set(mapping.values())
        for i in range(param_count):
            name = mapping.get(str(i))
            if name is None:
                # Slot never accessed: it still needs a placeholder to keep later parameters in
                # position, and that placeholder must not repeat a name already handed out.
                name = F'p{i}'
                while name in taken:
                    name = F'_{name}'
                taken.add(name)
            identifier = JsIdentifier(name=name)
            # Link the new node to its owner, mirroring what _add_local_declarations does for
            # the nodes it creates.
            identifier.parent = fn
            fn.params.append(identifier)
        if stack_chain is None:
            self._add_local_declarations(body, mapping, param_count)
        return True

    def _add_local_declarations(
        self,
        body: JsBlockStatement,
        mapping: dict[str, str],
        param_count: int,
    ) -> None:
        """
        Insert a single `var` declaration at the top of `body` for every mapped name whose key
        is not a parameter slot. Does nothing when there are no such names.
        """
        # Parameter slots are exactly the canonical decimal strings 0..param_count-1. Comparing
        # strings (instead of parsing the key) keeps keys such as "01" classified as locals.
        param_keys = {str(i) for i in range(param_count)}
        locals_ = [name for key, name in mapping.items() if key not in param_keys]
        if not locals_:
            return
        declarators = [
            JsVariableDeclarator(id=JsIdentifier(name=n), init=None)
            for n in locals_
        ]
        decl = JsVariableDeclaration(declarations=declarators, kind=JsVarKind.VAR)
        decl.parent = body
        for d in declarators:
            d.parent = decl
            if d.id is not None:
                d.id.parent = d
        body.body.insert(0, decl)

Ancestors