Module refinery.lib.scripts.js.deobfuscation.cff.statemachine

Recover original code from generator-based state-machine CFF dispatchers.

Handles the pattern where a function body is replaced with a generator function containing a while/switch state machine driven by multiple state variables whose sum is the switch discriminant. Each case updates the state via relative `+=` assignments.

Expand source code Browse git
"""
Recover original code from generator-based state-machine CFF dispatchers.

Handles the pattern where a function body is replaced with a generator function containing a
while/switch state machine driven by multiple state variables whose sum is the switch
discriminant. Each case updates the state via relative `+=` assignments.
"""
from __future__ import annotations

from collections import deque
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple

from refinery.lib.scripts import Expression, Node, Statement, _clone_node, _replace_in_parent
from refinery.lib.scripts.js.deobfuscation.helpers import (
    BodyProcessingTransformer,
    access_key,
    eval_binary_op,
    make_numeric_literal,
    member_key,
    property_key,
)
from refinery.lib.scripts.js.model import (
    JsArrayExpression,
    JsArrayPattern,
    JsAssignmentExpression,
    JsAssignmentPattern,
    JsBinaryExpression,
    JsBlockStatement,
    JsBooleanLiteral,
    JsBreakStatement,
    JsCallExpression,
    JsCatchClause,
    JsContinueStatement,
    JsExpressionStatement,
    JsFunctionDeclaration,
    JsFunctionExpression,
    JsIdentifier,
    JsIfStatement,
    JsLabeledStatement,
    JsLogicalExpression,
    JsMemberExpression,
    JsNumericLiteral,
    JsObjectExpression,
    JsObjectPattern,
    JsProperty,
    JsRestElement,
    JsReturnStatement,
    JsScript,
    JsSequenceExpression,
    JsStringLiteral,
    JsSwitchCase,
    JsSwitchStatement,
    JsUnaryExpression,
    JsVariableDeclaration,
    JsVariableDeclarator,
    JsVarKind,
    JsWhileStatement,
    JsWithStatement,
)

if TYPE_CHECKING:
    # Type-checking-only alias: maps a variable name (or member key) to its known numeric value.
    _StateEnv = dict[str, int | float]

# NOTE(review): presumably an upper bound on simulated state-machine steps; it is not used
# within this section of the file - confirm against the code that consumes it.
_MAX_STEPS = 2000


class _CallSiteInfo(NamedTuple):
    """
    Describes the call site that was located for a generator dispatcher in a statement list.
    """
    # Numeric values of the call arguments, in argument order.
    initial_state: list[int | float]
    # Name of the flag variable declared (or void-assigned) ahead of the call, if one was seen.
    did_return_var: str | None
    # Name of the variable that receives the call result, if any.
    result_var: str | None
    # Index of the last scaffolding statement: the call itself or the `if (flag)` guard after it.
    scaffolding_end: int


class _WrapperFunctionInfo(NamedTuple):
    """
    Describes a function expression that merely forwards to the generator dispatcher.
    """
    # Numeric values of the leading state arguments passed to the generator.
    initial_state: list[int | float]
    # Name of the wrapper's last parameter (rest element or plain identifier), if it has one.
    rest_param_name: str | None
    # Object literal passed right after the state arguments, if that argument is an object literal.
    scope_arg: JsObjectExpression | None


@dataclass
class _SMRawAssignment:
    """
    A single unevaluated state variable assignment: `name op= rhs`.
    """
    # Name of the state variable being assigned.
    name: str
    # Assignment operator as written in the source; the extractors in this module accept
    # `=` and `+=`.
    operator: str
    # Right-hand side expression, kept as an AST node until the transition is applied.
    rhs: Expression


@dataclass
class _SMLinearTransition:
    """
    Unconditional state transition: a sequence of `+=` or `=` assignments to state variables.
    """
    # Assignments in source order; later ones may depend on the results of earlier ones.
    assignments: list[_SMRawAssignment]


@dataclass
class _SMConditionalTransition:
    """
    Conditional state transition: an if/else where each branch sets different state values.
    """
    # Test expression of the if statement.
    condition: Expression
    # Trailing state assignments of the consequent branch.
    true_assignments: list[_SMRawAssignment]
    # Trailing state assignments of the alternate branch.
    false_assignments: list[_SMRawAssignment]
    # Statements of the consequent branch that precede its trailing state assignments.
    true_prefix: list[Statement] = field(default_factory=list)
    # Statements of the alternate branch that precede its trailing state assignments.
    false_prefix: list[Statement] = field(default_factory=list)


@dataclass
class _SMExitTransition:
    """
    Marker transition without any data: after this block the state machine is at its end state.
    """


if TYPE_CHECKING:
    # Type-checking-only alias: union of every transition kind a block can carry.
    _SMTransition = _SMLinearTransition | _SMConditionalTransition | _SMExitTransition


@dataclass
class _SMBlock:
    """
    A single state in the machine: payload statements plus a transition.
    """
    # Discriminant value that identifies this state.
    state_id: int | float
    # Statements of the case body that are not part of the state transition.
    payload: list[Statement]
    # How the state variables change once the payload has run.
    transition: _SMTransition


@dataclass
class _GeneratorCFFMatch:
    """
    Structural match result for a generator-based state-machine CFF pattern.
    """
    # Name of the generator function declaration.
    generator_name: str
    # Names of the plain identifier parameters that precede the scope parameter.
    state_var_names: list[str]
    # Numeric arguments found at the call site, one per state variable.
    initial_state: list[int | float]
    # Constant that the sum of the state variables is compared against in the while test.
    end_state: int | float
    # The dispatcher switch statement inside the while loop.
    switch_stmt: JsSwitchStatement
    # Label attached to the switch statement, if it is labeled.
    switch_label: str | None
    # Name of the parameter that is declared with a default value, if any.
    scope_param_name: str | None
    # Name of the first plain identifier parameter following the scope parameter, if any.
    arg_var_name: str | None
    # Flag variable name reported by the call site scan, if any.
    did_return_var: str | None
    # Variable receiving the generator call result, if any.
    result_var: str | None
    # Index of the generator declaration within the enclosing statement list.
    gen_decl_index: int
    # Index of the last call-site scaffolding statement within the enclosing statement list.
    scaffolding_end: int
    # Property name `W` from a `with (scope.W || scope)` wrapper, if present.
    with_redirect_var: str | None = None
    # Property names found in the scope parameter's default object literal.
    scope_default_props: list[str] = field(default_factory=list)
    # Initializer expressions of those default properties, keyed by property name.
    scope_default_inits: dict[str, Expression] = field(default_factory=dict)
    # NOTE(review): not populated anywhere in this section of the file - confirm its producer.
    arg_params: list[str] = field(default_factory=list)


def _eval_expr(node: Expression, env: _StateEnv) -> int | float | None:
    """
    Evaluate an arithmetic expression tree against *env* and return its numeric value, or `None`
    if any part of the expression cannot be resolved.

    Supported nodes are numeric literals, identifiers and member expressions (both looked up in
    *env*), prefix unary `+` and `-`, the short-circuiting `&&` and `||`, and binary operators
    handled by `eval_binary_op`. Boolean results are converted to `int`, and float results with
    an integral value are returned as `int`.
    """
    if isinstance(node, JsNumericLiteral):
        return node.value
    if isinstance(node, JsIdentifier):
        return env.get(node.name)
    if isinstance(node, JsMemberExpression):
        key = member_key(node)
        return None if key is None else env.get(key)
    if isinstance(node, JsUnaryExpression) and node.prefix and node.operand is not None:
        sign = node.operator
        if sign == '-' or sign == '+':
            operand = _eval_expr(node.operand, env)
            if sign == '+' or operand is None:
                return operand
            return -operand
    if isinstance(node, JsLogicalExpression) and node.left is not None and node.right is not None:
        op = node.operator
        if op == '&&' or op == '||':
            left = _eval_expr(node.left, env)
            if left is None:
                return None
            # `||` yields a truthy left operand, `&&` yields a falsy one; otherwise the right
            # operand decides the result.
            if bool(left) == (op == '||'):
                return left
            return _eval_expr(node.right, env)
    if isinstance(node, JsBinaryExpression) and node.left is not None and node.right is not None:
        left = _eval_expr(node.left, env)
        right = _eval_expr(node.right, env)
        if left is None or right is None:
            return None
        outcome = eval_binary_op(node.operator, left, right)
        if outcome is None:
            return None
        if isinstance(outcome, bool):
            return int(outcome)
        if isinstance(outcome, float):
            try:
                truncated = int(outcome)
            except (OverflowError, ValueError):
                # infinities and NaN have no integer counterpart
                return None
            if outcome == truncated:
                return truncated
        return outcome
    return None


def _is_discriminant_sum(node: Expression, var_names: list[str]) -> bool:
    """
    Check whether an expression is exactly the sum of the given state variable identifiers.

    The expression has to consist solely of identifiers joined by binary `+`. If any operand is
    something else (a literal, a call, a different operator), the expression is rejected, even
    when the identifiers collected up to that point happen to match *var_names*. The order of the
    operands does not matter, but each name must occur as often as it does in *var_names*.
    """
    collected: list[str] = []
    if not _collect_sum_idents(node, collected):
        # The walk hit a node that is neither an identifier nor a `+` of valid operands.
        return False
    return sorted(collected) == sorted(var_names)


def _collect_sum_idents(node: Expression, out: list[str]) -> bool:
    """
    Append the identifier names of a pure `+` chain to *out*, walking left to right.

    Returns `True` when *node* is an identifier or a `+` expression whose operands both qualify.
    Returns `False` as soon as any other node is met; names collected before that point remain
    in *out*.
    """
    if isinstance(node, JsIdentifier):
        out.append(node.name)
        return True
    if not isinstance(node, JsBinaryExpression) or node.operator != '+':
        return False
    left, right = node.left, node.right
    if left is None or right is None:
        return False
    if not _collect_sum_idents(left, out):
        return False
    return _collect_sum_idents(right, out)


def _extract_with_redirect_var(
    with_obj: Expression | None,
    scope_param_name: str | None,
) -> str | None:
    """
    Given the object expression of a `with` statement, recognise the shape `scope.W || scope`
    (or `scope["W"] || scope`) and return the property name `W`. Returns `None` when either
    argument is missing or the expression has any other shape.
    """
    if with_obj is None or scope_param_name is None:
        return None
    if not (isinstance(with_obj, JsLogicalExpression) and with_obj.operator == '||'):
        return None
    redirect, fallback = with_obj.left, with_obj.right
    # the right operand has to be the bare scope parameter
    if not (isinstance(fallback, JsIdentifier) and fallback.name == scope_param_name):
        return None
    if not isinstance(redirect, JsMemberExpression):
        return None
    owner = redirect.object
    if not (isinstance(owner, JsIdentifier) and owner.name == scope_param_name):
        return None
    prop = redirect.property
    if redirect.computed:
        return prop.value if isinstance(prop, JsStringLiteral) else None
    return prop.name if isinstance(prop, JsIdentifier) else None


class _ScopeDefaults(NamedTuple):
    """
    Property names and initializers taken from the scope parameter's default object literal.
    """
    # Property keys in the order they appear in the object literal.
    prop_names: list[str]
    # Initializer expression for each key whose property carries a value.
    initializers: dict[str, Expression]


def _extract_scope_default_props(
    params: list, scope_param_name: str | None,
) -> _ScopeDefaults:
    """
    Read the default value of the scope parameter and report the property names it defines along
    with their initializer expressions. For a parameter written as `scope = { MpAqdCF: {} }` the
    result is

        (['MpAqdCF'], {'MpAqdCF': <JsObjectExpression>})

    An empty result is returned when there is no scope parameter name, when no parameter with
    that name carries a default, or when the default is not an object literal.
    """
    if scope_param_name is None:
        return _ScopeDefaults([], {})
    # Only the first parameter of the form `<scope_param_name> = <default>` is considered.
    target = next(
        (
            p for p in params
            if isinstance(p, JsAssignmentPattern)
            and isinstance(p.left, JsIdentifier)
            and p.left.name == scope_param_name
        ),
        None,
    )
    if target is None or not isinstance(target.right, JsObjectExpression):
        return _ScopeDefaults([], {})
    prop_names: list[str] = []
    initializers: dict[str, Expression] = {}
    for entry in target.right.properties:
        if not isinstance(entry, JsProperty):
            continue
        key = property_key(entry)
        if key is None:
            continue
        prop_names.append(key)
        if entry.value is not None:
            initializers[key] = entry.value
    return _ScopeDefaults(prop_names, initializers)


def _match_generator_cff(body: list[Statement], idx: int) -> _GeneratorCFFMatch | None:
    """
    Starting at index *idx* in *body*, test whether the statement is a generator function
    declaration matching the state machine CFF pattern, with its call site following.

    The expected shape is a named generator with at least three parameters whose body is a single
    `while (<sum of state vars> !== <constant>)` loop. The loop body, optionally wrapped in a
    `with` statement and/or a label, must be a `switch` on the sum of the state variables.
    Returns a populated `_GeneratorCFFMatch`, or `None` if any structural check fails or no
    matching call site is found.
    """
    stmt = body[idx]
    if not isinstance(stmt, JsFunctionDeclaration):
        return None
    if not stmt.generator:
        return None
    if stmt.id is None:
        return None
    gen_name = stmt.id.name
    if stmt.body is None:
        return None
    params = stmt.params
    if len(params) < 3:
        return None
    scope_param_name: str | None = None
    arg_var_name: str | None = None
    state_var_names: list[str] = []
    # Classify the parameters: plain identifiers are state variables until a parameter with a
    # default value (the scope parameter) has been seen; the first plain identifier after that
    # is the argument variable and ends the scan.
    for p in params:
        if isinstance(p, JsIdentifier):
            if scope_param_name is not None:
                arg_var_name = p.name
                break
            state_var_names.append(p.name)
        elif isinstance(p, JsAssignmentPattern) and isinstance(p.left, JsIdentifier):
            scope_param_name = p.left.name
        else:
            return None
    if not state_var_names:
        return None
    gen_body = stmt.body.body
    if len(gen_body) != 1:
        return None
    while_stmt = gen_body[0]
    if not isinstance(while_stmt, JsWhileStatement):
        return None
    if while_stmt.test is None or while_stmt.body is None:
        return None
    if not isinstance(while_stmt.test, JsBinaryExpression):
        return None
    if while_stmt.test.operator != '!==':
        return None
    lhs = while_stmt.test.left
    rhs = while_stmt.test.right
    if lhs is None or rhs is None:
        return None
    # The loop condition compares the state sum with a constant; the sum may be on either side.
    end_state: int | float | None = None
    if _is_discriminant_sum(lhs, state_var_names):
        end_state = _eval_expr(rhs, {})
    elif _is_discriminant_sum(rhs, state_var_names):
        end_state = _eval_expr(lhs, {})
    if end_state is None:
        return None
    # Peel the optional wrappers around the switch: block, with statement, block, label.
    inner: Statement | None = while_stmt.body
    if isinstance(inner, JsBlockStatement) and len(inner.body) == 1:
        inner = inner.body[0]
    with_redirect_var: str | None = None
    scope_default_props: list[str] = []
    scope_default_inits: dict[str, Expression] = {}
    if isinstance(inner, JsWithStatement):
        with_redirect_var = _extract_with_redirect_var(inner.object, scope_param_name)
        scope_default_props, scope_default_inits = _extract_scope_default_props(params, scope_param_name)
        inner = inner.body
    if isinstance(inner, JsBlockStatement) and len(inner.body) == 1:
        inner = inner.body[0]
    switch_label: str | None = None
    if isinstance(inner, JsLabeledStatement):
        if inner.label is not None:
            switch_label = inner.label.name
        inner = inner.body
    if not isinstance(inner, JsSwitchStatement):
        return None
    if inner.discriminant is None:
        return None
    if not _is_discriminant_sum(inner.discriminant, state_var_names):
        return None
    switch_stmt = inner
    call_info = _find_generator_call_site(body, idx, gen_name)
    if call_info is None:
        return None
    # Every call argument is treated as an initial state value, so the counts have to agree.
    if len(call_info.initial_state) != len(state_var_names):
        return None
    return _GeneratorCFFMatch(
        generator_name=gen_name,
        state_var_names=state_var_names,
        initial_state=call_info.initial_state,
        end_state=end_state,
        switch_stmt=switch_stmt,
        switch_label=switch_label,
        scope_param_name=scope_param_name,
        arg_var_name=arg_var_name,
        did_return_var=call_info.did_return_var,
        result_var=call_info.result_var,
        gen_decl_index=idx,
        scaffolding_end=call_info.scaffolding_end,
        with_redirect_var=with_redirect_var,
        scope_default_props=scope_default_props,
        scope_default_inits=scope_default_inits,
    )


def _find_generator_call_site(
    body: list[Statement],
    gen_idx: int,
    gen_name: str,
) -> _CallSiteInfo | None:
    """
    Look at the statements following index *gen_idx* for the scaffolding that invokes the
    generator *gen_name*:

        var didReturn;
        var result = genName(args)["next"]()["value"];
        if (didReturn) { return result; }

    Ahead of the call, function declarations, single uninitialised `var` declarations and
    `x = void ...` statements are stepped over; the latter two set the flag variable name. The
    first statement of any other kind has to be the generator call, and every call argument has
    to evaluate to a number. Returns a `_CallSiteInfo`, or `None` if the pattern is not found.
    """
    did_return_var: str | None = None
    for pos in range(gen_idx + 1, len(body)):
        stmt = body[pos]
        if isinstance(stmt, JsVariableDeclaration):
            declared = stmt.declarations
            if (
                len(declared) == 1
                and isinstance(declared[0], JsVariableDeclarator)
                and isinstance(declared[0].id, JsIdentifier)
                and declared[0].init is None
            ):
                did_return_var = declared[0].id.name
                continue
        if isinstance(stmt, JsFunctionDeclaration):
            continue
        if isinstance(stmt, JsExpressionStatement):
            value = stmt.expression
            if (
                isinstance(value, JsAssignmentExpression)
                and value.operator == '='
                and isinstance(value.left, JsIdentifier)
                and isinstance(value.right, JsUnaryExpression)
                and value.right.operator == 'void'
            ):
                did_return_var = value.left.name
                continue
        # first statement that is not skippable scaffolding: must be the call
        break
    else:
        # ran off the end of the statement list without finding a candidate
        return None
    found = _extract_generator_call(body[pos], gen_name)
    if found is None:
        return None
    call_node, result_var = found
    initial_state: list[int | float] = []
    for argument in call_node.arguments:
        number = _eval_expr(argument, {})
        if number is None:
            return None
        initial_state.append(number)
    scaffolding_end = pos
    follower = pos + 1
    if did_return_var is not None and follower < len(body):
        guard = body[follower]
        if (
            isinstance(guard, JsIfStatement)
            and isinstance(guard.test, JsIdentifier)
            and guard.test.name == did_return_var
        ):
            # the `if (didReturn)` guard belongs to the scaffolding as well
            scaffolding_end = follower
    return _CallSiteInfo(initial_state, did_return_var, result_var, scaffolding_end)


class _GeneratorCallInfo(NamedTuple):
    """
    A located call to the generator together with the name its result is bound to.
    """
    # The call expression whose callee is the generator identifier.
    call_node: JsCallExpression
    # Name of the declared or assigned variable that receives the result, if any.
    result_var: str | None


def _extract_generator_call(
    stmt: Statement,
    gen_name: str,
) -> _GeneratorCallInfo | None:
    """
    Locate a call to the generator *gen_name* inside a single statement. Recognised forms are:
    - var X = gen(...)["next"]()["value"];
    - X = gen(...)["next"]()["value"];
    - gen(...)["next"]()["value"];
    - return gen(...)["next"]()["value"];

    A bare `gen(...)` call without the `next`/`value` access is accepted in the same positions.
    Returns the inner call to the generator along with the result variable name (if there is
    one), or `None` when the statement does not match.
    """
    result_name: str | None = None
    if isinstance(stmt, JsVariableDeclaration):
        if len(stmt.declarations) != 1:
            return None
        declarator = stmt.declarations[0]
        if not isinstance(declarator, JsVariableDeclarator):
            return None
        if isinstance(declarator.id, JsIdentifier):
            result_name = declarator.id.name
        expr = declarator.init
    elif isinstance(stmt, JsExpressionStatement):
        expr = stmt.expression
        if isinstance(expr, JsAssignmentExpression) and expr.operator == '=':
            if isinstance(expr.left, JsIdentifier):
                result_name = expr.left.name
            expr = expr.right
    elif isinstance(stmt, JsReturnStatement):
        expr = stmt.argument
    else:
        return None
    if expr is None:
        return None
    gen_call = _unwrap_next_value(expr)
    if gen_call is None and isinstance(expr, JsCallExpression):
        # no `.next().value` wrapper: treat the expression itself as the call
        gen_call = expr
    if not isinstance(gen_call, JsCallExpression):
        return None
    callee = gen_call.callee
    if not isinstance(callee, JsIdentifier) or callee.name != gen_name:
        return None
    return _GeneratorCallInfo(gen_call, result_name)


def _unwrap_next_value(node: Expression) -> JsCallExpression | None:
    """
    Given an expression of the form `gen(...).next().value`, return the innermost `gen(...)`
    call. The access keys are obtained through `access_key`, so `next` and `value` may be
    written as properties or as keys. Returns `None` for any other shape, including a `next`
    call that is given arguments.
    """
    if not isinstance(node, JsMemberExpression) or access_key(node) != 'value':
        return None
    next_call = node.object
    if not isinstance(next_call, JsCallExpression) or next_call.arguments:
        return None
    accessor = next_call.callee
    if not isinstance(accessor, JsMemberExpression) or access_key(accessor) != 'next':
        return None
    innermost = accessor.object
    return innermost if isinstance(innermost, JsCallExpression) else None


def _detect_wrapper_function(
    node: Expression,
    gen_name: str,
    num_state_vars: int,
) -> _WrapperFunctionInfo | None:
    """
    Decide whether *node* is a function expression that only forwards to the generator:

        function(...rest) { return gen(states..., scope, rest)["next"]()["value"]; }

    The call needs more than *num_state_vars* arguments, and the first *num_state_vars* of them
    have to evaluate to numbers. Returns a `_WrapperFunctionInfo` with those numbers, the name of
    the wrapper's last parameter (rest element or plain identifier), and the argument following
    the states if it is an object literal. Returns `None` when the shape does not match.
    """
    if not isinstance(node, JsFunctionExpression) or node.body is None:
        return None
    statements = node.body.body
    if len(statements) != 1:
        return None
    only = statements[0]
    if not isinstance(only, JsReturnStatement) or only.argument is None:
        return None
    gen_call = _unwrap_next_value(only.argument)
    if gen_call is None:
        return None
    callee = gen_call.callee
    if not isinstance(callee, JsIdentifier) or callee.name != gen_name:
        return None
    args = gen_call.arguments
    if len(args) <= num_state_vars:
        return None
    initial_state: list[int | float] = []
    for state_arg in args[:num_state_vars]:
        number = _eval_expr(state_arg, {})
        if number is None:
            return None
        initial_state.append(number)
    # the length check above guarantees an argument right after the state values
    after_states = args[num_state_vars]
    scope_arg = after_states if isinstance(after_states, JsObjectExpression) else None
    rest_param_name: str | None = None
    if node.params:
        tail = node.params[-1]
        if isinstance(tail, JsRestElement) and isinstance(tail.argument, JsIdentifier):
            rest_param_name = tail.argument.name
        elif isinstance(tail, JsIdentifier):
            rest_param_name = tail.name
    return _WrapperFunctionInfo(initial_state, rest_param_name, scope_arg)


@dataclass
class _StateMachine:
    """
    Complete parsed state machine with both statically-resolved and predicate-gated cases.
    """
    # Blocks keyed by the constant value of their case test.
    blocks: dict[int | float, _SMBlock]
    # Case tests that could not be evaluated statically, each paired with its block.
    predicate_cases: list[tuple[Expression, _SMBlock]]
    # Block belonging to the `default:` case, if the switch has one.
    default_block: _SMBlock | None = None


def _extract_state_blocks(
    match: _GeneratorCFFMatch,
) -> _StateMachine | None:
    """
    Parse the switch cases into a state machine. Cases with statically resolvable tests go into
    `blocks`; those with predicate tests (referencing state vars) go into `predicate_cases` for
    runtime resolution. A `default:` case becomes the fallback block.

    Consecutive cases with an empty body fall through to the next case with a body and share its
    block. A case whose body cannot be split into payload and transition is dropped together with
    the empty cases leading into it. A constant test equal to the end state, or one that is
    already registered, does not create an entry in `blocks`. The `state_id` of a block is the
    first constant test value it is registered under, and stays `0` if there is none.

    Returns `None` if a switch member is not a `JsSwitchCase` or if nothing could be extracted.
    """
    var_names = match.state_var_names
    label = match.switch_label
    end_state = match.end_state
    blocks: dict[int | float, _SMBlock] = {}
    predicate_cases: list[tuple[Expression, _SMBlock]] = []
    default_block: _SMBlock | None = None
    pending_tests: list[JsSwitchCase] = []
    for case in match.switch_stmt.cases:
        if not isinstance(case, JsSwitchCase):
            return None
        if not case.body:
            # empty case: its test falls through to the next case that has a body
            pending_tests.append(case)
            continue
        all_cases = [*pending_tests, case]
        pending_tests.clear()
        parsed = _parse_case_body(list(case.body), var_names, label)
        if parsed is None:
            continue
        payload, transition = parsed
        block_obj = _SMBlock(state_id=0, payload=payload, transition=transition)
        # Tracked with an explicit flag rather than by comparing state_id with 0, because 0 is
        # itself a valid test value and must not be replaced by a later alias.
        has_state_id = False
        for c in all_cases:
            if c.test is None:
                default_block = block_obj
                continue
            val = _eval_expr(c.test, {})
            if val is None:
                predicate_cases.append((c.test, block_obj))
            elif val != end_state and val not in blocks:
                if not has_state_id:
                    block_obj.state_id = val
                    has_state_id = True
                blocks[val] = block_obj
    if not blocks and not predicate_cases and default_block is None:
        return None
    return _StateMachine(blocks=blocks, predicate_cases=predicate_cases, default_block=default_block)


def _parse_case_body(
    stmts: list[Statement],
    var_names: list[str],
    switch_label: str | None,
) -> tuple[list[Statement], _SMTransition] | None:
    """
    Split the body of a switch case into its payload and the state transition that ends it.

    After a trailing `break` has been stripped, the forms below are tried in this order:
    a final sequence expression containing state assignments, a run of trailing state assignment
    statements, a final if/else whose branches both end in state assignments, and a final
    `return`. A body that only held the `break` yields an empty payload with an exit transition.
    Returns `None` for an empty body or when no form applies.
    """
    if not stmts:
        return None
    body = _strip_trailing_labeled_break(stmts, switch_label)
    if not body:
        return ([], _SMExitTransition())
    tail = body[-1]
    if isinstance(tail, JsExpressionStatement) and isinstance(tail.expression, JsSequenceExpression):
        sequence = tail.expression
        linear = _extract_state_assignments(sequence, var_names)
        if linear is not None:
            payload = list(body[:-1])
            # whatever the sequence contains besides state assignments remains payload
            leftover = _extract_non_state_expressions(sequence, var_names)
            if leftover:
                payload.append(JsExpressionStatement(expression=leftover))
            return (payload, _SMLinearTransition(assignments=linear))
    trailing = _collect_trailing_state_assignments(body, var_names)
    if trailing is not None:
        linear, cut = trailing
        return (body[:cut], _SMLinearTransition(assignments=linear))
    if isinstance(tail, JsIfStatement) and tail.consequent is not None and tail.alternate is not None:
        branching = _parse_conditional_transition(tail, var_names, switch_label)
        if branching is not None:
            return (body[:-1], branching)
    if isinstance(tail, JsReturnStatement):
        return (body, _SMExitTransition())
    return None


def _strip_trailing_labeled_break(stmts: list[Statement], label: str | None) -> list[Statement]:
    """
    Drop a final `break` statement that leaves the switch: either an unlabeled `break;` or a
    `break label;` whose label equals *label*. The input list is returned unchanged (same object)
    when there is nothing to strip; otherwise a shortened copy is returned.
    """
    if stmts:
        tail = stmts[-1]
        if isinstance(tail, JsBreakStatement):
            target = tail.label
            if target is None or (label is not None and target.name == label):
                return stmts[:-1]
    return stmts


def _extract_state_assignments(
    seq: JsSequenceExpression,
    var_names: list[str],
) -> list[_SMRawAssignment] | None:
    """
    Gather the assignments to state variables contained in a sequence expression, in order and
    without evaluating them. Members of the sequence that are not assignments to a state variable
    identifier (scope/with updates, for example) are ignored.

    Returns `None` if the sequence has no state assignment, or if one of them lacks a right-hand
    side or uses an operator other than `=` and `+=`.
    """
    collected: list[_SMRawAssignment] = []
    for item in seq.expressions:
        if not (isinstance(item, JsAssignmentExpression) and isinstance(item.left, JsIdentifier)):
            continue
        target = item.left.name
        if target not in var_names:
            continue
        if item.right is None or item.operator not in ('=', '+='):
            return None
        collected.append(_SMRawAssignment(name=target, operator=item.operator, rhs=item.right))
    return collected or None


def _extract_non_state_expressions(
    seq: JsSequenceExpression,
    var_names: list[str],
) -> Expression | None:
    """
    Return what is left of a sequence expression once every assignment to a state variable
    identifier has been taken out: `None` if nothing is left, the expression itself if exactly
    one is left, and a new sequence expression otherwise.
    """
    kept = [
        item for item in seq.expressions
        if not (
            isinstance(item, JsAssignmentExpression)
            and isinstance(item.left, JsIdentifier)
            and item.left.name in var_names
        )
    ]
    if not kept:
        return None
    return kept[0] if len(kept) == 1 else JsSequenceExpression(expressions=kept)


def _collect_trailing_state_assignments(
    stmts: list[Statement],
    var_names: list[str],
) -> tuple[list[_SMRawAssignment], int] | None:
    """
    Walk the statement list from its end and pick up the unbroken run of statements that assign
    to a state variable with `=` or `+=`. Returns those assignments in source order together with
    the index at which the run starts (which is where the payload ends), or `None` if the list
    does not end in such a statement.
    """
    found: list[_SMRawAssignment] = []
    for stmt in reversed(stmts):
        if not isinstance(stmt, JsExpressionStatement):
            break
        inner = stmt.expression
        if not isinstance(inner, JsAssignmentExpression):
            break
        target = inner.left
        if not isinstance(target, JsIdentifier) or target.name not in var_names:
            break
        if inner.right is None or inner.operator not in ('=', '+='):
            break
        found.append(_SMRawAssignment(name=target.name, operator=inner.operator, rhs=inner.right))
    if not found:
        return None
    # collected back to front, so restore source order
    found.reverse()
    return (found, len(stmts) - len(found))


def _extract_single_assignment(
    expr: JsAssignmentExpression,
    var_names: list[str],
) -> list[_SMRawAssignment] | None:
    """
    Wrap one assignment expression as a single-element list of raw state assignments. Returns
    `None` unless the target is a state variable identifier, a right-hand side is present, and
    the operator is `=` or `+=`.
    """
    target = expr.left
    if not isinstance(target, JsIdentifier) or target.name not in var_names:
        return None
    if expr.right is None or expr.operator not in ('=', '+='):
        return None
    return [_SMRawAssignment(name=target.name, operator=expr.operator, rhs=expr.right)]


def _apply_raw_transition(
    assignments: list[_SMRawAssignment],
    current: _StateEnv,
) -> _StateEnv | None:
    """
    Compute the state that results from running *assignments* on *current*. The assignments are
    applied one after another to a copy of *current*, so each right-hand side sees the effect of
    the ones before it. Returns `None` if a right-hand side cannot be evaluated; *current* itself
    is never modified.
    """
    state: _StateEnv = dict(current)
    for step in assignments:
        value = _eval_expr(step.rhs, state)
        if value is None:
            return None
        if step.operator == '+=':
            # a variable without a known value counts as 0
            value = state.get(step.name, 0) + value
        state[step.name] = value
    return state


def _block_stmts(node: Statement) -> list[Statement] | None:
    """
    Return the statements of *node* as a new list: the body of a block statement, or a
    single-element list holding *node* itself for any other statement.
    """
    return list(node.body) if isinstance(node, JsBlockStatement) else [node]


def _extract_trailing_assignments(
    stmts: list[Statement],
    var_names: list[str],
) -> tuple[list[_SMRawAssignment], list[Statement]] | None:
    """
    Split off the state assignment that ends a statement list. The last statement has to be an
    expression statement holding either a single state assignment or a sequence expression that
    contains state assignments. Returns `(raw_assignments, prefix_statements)`; for a sequence,
    its non-state expressions are appended to the prefix as a new expression statement. Returns
    `None` if the list is empty or does not end that way.
    """
    if not stmts:
        return None
    tail = stmts[-1]
    if not isinstance(tail, JsExpressionStatement):
        return None
    inner = tail.expression
    if isinstance(inner, JsSequenceExpression):
        assigns = _extract_state_assignments(inner, var_names)
        if assigns is None:
            return None
        prefix = list(stmts[:-1])
        leftover = _extract_non_state_expressions(inner, var_names)
        if leftover:
            prefix.append(JsExpressionStatement(expression=leftover))
        return (assigns, prefix)
    if isinstance(inner, JsAssignmentExpression):
        assigns = _extract_single_assignment(inner, var_names)
        if assigns is not None:
            return (assigns, stmts[:-1])
    return None


def _parse_conditional_transition(
    if_stmt: JsIfStatement,
    var_names: list[str],
    switch_label: str | None,
) -> _SMConditionalTransition | None:
    """
    Turn an if/else into a conditional transition. Both branches must be present and, once a
    trailing `break` has been stripped, each must end in a state assignment. The statements that
    come before that assignment are kept as the branch's prefix. Returns `None` if the statement
    has no test, lacks a branch, or a branch does not end in a state assignment.
    """
    condition = if_stmt.test
    if condition is None:
        return None
    if if_stmt.consequent is None or if_stmt.alternate is None:
        return None
    branches = []
    for branch in (if_stmt.consequent, if_stmt.alternate):
        statements = _block_stmts(branch)
        if statements is None:
            return None
        branches.append(_strip_trailing_labeled_break(statements, switch_label))
    on_true = _extract_trailing_assignments(branches[0], var_names)
    on_false = _extract_trailing_assignments(branches[1], var_names)
    if on_true is None or on_false is None:
        return None
    return _SMConditionalTransition(
        condition=condition,
        true_assignments=on_true[0],
        false_assignments=on_false[0],
        true_prefix=on_true[1],
        false_prefix=on_false[1],
    )


def _compute_discriminant(state: _StateEnv, var_names: list[str]) -> int | float:
    """
    Add up the values of the named state variables; a name missing from *state* counts as 0.
    """
    total = 0
    for name in var_names:
        total += state.get(name, 0)
    return total


def _lookup_block(machine: _StateMachine, disc: int | float, state: _StateEnv) -> _SMBlock | None:
    """
    Resolve the block that handles discriminant *disc*. A statically keyed block wins; otherwise
    the predicate tests are evaluated against *state* in order and the first one equal to *disc*
    selects its block; if none matches, the default block (possibly `None`) is returned.
    """
    try:
        return machine.blocks[disc]
    except KeyError:
        pass
    for test_expr, candidate in machine.predicate_cases:
        outcome = _eval_expr(test_expr, state)
        if outcome is None:
            continue
        if outcome == disc:
            return candidate
    return machine.default_block


def _apply_initial_state(var_names: list[str], values: list[int | float]) -> _StateEnv:
    """
    Pair each state variable name with the value at the same position. Pairing stops at the
    shorter of the two lists.
    """
    return {name: value for name, value in zip(var_names, values)}


def _is_state_var_assignment(expr: Expression, var_set: set[str]) -> bool:
    """
    Tell whether `expr` is an assignment expression whose target is a plain identifier named in
    `var_set`. The assignment operator is not inspected.
    """
    if not isinstance(expr, JsAssignmentExpression):
        return False
    target = expr.left
    return isinstance(target, JsIdentifier) and target.name in var_set


def _apply_prefix_state_changes(
    prefix: list[Statement],
    var_names: list[str],
    env: _StateEnv,
) -> _StateEnv:
    """
    Scan prefix statements for assignments to state variables and apply them sequentially. This
    handles cases where a conditional's prefix modifies state variables before the trailing
    transition assignment.
    """
    result = dict(env)
    var_set = set(var_names)
    for stmt in prefix:
        if not isinstance(stmt, JsExpressionStatement):
            continue
        expr = stmt.expression
        exprs = expr.expressions if isinstance(expr, JsSequenceExpression) else [expr]
        for e in exprs:
            if not isinstance(e, JsAssignmentExpression):
                continue
            if not isinstance(e.left, JsIdentifier):
                continue
            if e.left.name not in var_set:
                continue
            if e.right is None:
                continue
            rhs_val = _eval_expr(e.right, result)
            if rhs_val is None:
                continue
            name = e.left.name
            if e.operator == '=':
                result[name] = rhs_val
            elif e.operator == '+=':
                result[name] = result.get(name, 0) + rhs_val
            elif e.operator == '-=':
                result[name] = result.get(name, 0) - rhs_val
    return result


def _strip_state_var_assignments(stmts: list[Statement], var_names: list[str]) -> list[Statement]:
    """
    Remove statements that are pure assignments to state variables. These are routing bookkeeping
    that should not appear in the recovered output. For sequence expressions, state var assignments
    are removed while preserving remaining payload expressions.
    """
    var_set = set(var_names)
    result: list[Statement] = []
    for stmt in stmts:
        if not isinstance(stmt, JsExpressionStatement):
            result.append(stmt)
            continue
        expr = stmt.expression
        if expr is None:
            result.append(stmt)
            continue
        if isinstance(expr, JsSequenceExpression):
            remaining = [e for e in expr.expressions if not _is_state_var_assignment(e, var_set)]
            if not remaining:
                continue
            if len(remaining) == 1:
                result.append(JsExpressionStatement(expression=remaining[0]))
            else:
                result.append(JsExpressionStatement(
                    expression=JsSequenceExpression(expressions=remaining),
                ))
        elif _is_state_var_assignment(expr, var_set):
            continue
        else:
            result.append(stmt)
    return result


def _process_branch_prefix(
    prefix: list[Statement],
    var_names: list[str],
    state: _StateEnv,
    match: _GeneratorCFFMatch,
    strip_ns: str | None,
    redirect_target: str | None,
) -> list[Statement]:
    """
    Turn the prefix statements of one conditional branch into emit-ready payload.

    The steps run in this order: drop state-variable assignments, substitute state variables
    (on clones), record `match.arg_params` from the first statement that yields parameter names
    if none are known yet, filter scope bookkeeping, strip the scope parameter prefix, qualify
    identifiers, and remove redirect variable assignments. Returns an empty list when nothing is
    left after the first step.
    """
    stmts = _strip_state_var_assignments(prefix, var_names)
    if not stmts:
        return []
    stmts = _substitute_state_vars(stmts, state)
    if not match.arg_params:
        for candidate in stmts:
            found = _extract_arg_param_names(candidate, match.arg_var_name)
            if found is None:
                continue
            match.arg_params = found
            break
    kept: list[Statement] = []
    for stmt in stmts:
        if _is_scope_bookkeeping(stmt, match.scope_param_name, match.arg_var_name):
            continue
        kept.append(_strip_scope_bookkeeping_from_sequence(stmt, match.scope_param_name))
    unscoped = _strip_scope_param_prefix(kept, match.scope_param_name, strip_ns)
    qualified = _qualify_with_identifiers(unscoped, match, redirect_target)
    return _filter_redirect_var_assignments(qualified, match)


# Node id of the synthetic exit node that `_build_cfg` adds to the graph. It is used as the
# successor of exit transitions and of transitions whose discriminant equals the end state.
_VIRTUAL_EXIT: int = -1


@dataclass
class _CFGNode:
    """
    A node in the control flow graph derived from the state machine. Keyed by block object
    identity (`id(block)`) so that the same logical block visited with different discriminants
    is recognized as a single CFG node — enabling loop detection.
    """
    # `id()` of the originating block, or `_VIRTUAL_EXIT` for the synthetic exit node.
    node_id: int
    # Statements emitted for this node, already processed by `_build_cfg`.
    payload: list[Statement]
    # Branch condition; only set by `_build_cfg` for conditional transitions.
    condition: Expression | None
    # Successor node ids. For a conditional transition the order is [true target, false target].
    successors: list[int] = field(default_factory=list)
    # Predecessor node ids; filled in by `_build_cfg` after all nodes have been created.
    predecessors: list[int] = field(default_factory=list)
    # Processed prefix statements of the true / false branch of a conditional transition.
    true_prefix_payload: list[Statement] = field(default_factory=list)
    false_prefix_payload: list[Statement] = field(default_factory=list)


@dataclass
class _CFG:
    """
    Control flow graph built from symbolic execution of state machine transitions.
    """
    # All nodes by node id, including the synthetic exit node added by `_build_cfg`.
    nodes: dict[int, _CFGNode]
    # Node id of the entry block.
    entry: int
    # Node id of the exit node (`_build_cfg` passes `_VIRTUAL_EXIT`).
    exit: int


@dataclass
class _NaturalLoop:
    """
    A natural loop identified by a back-edge in the CFG.
    """
    # Target of the back-edge(s); dominates every tail.
    header: int
    # Node ids belonging to the loop, header included.
    body: set[int]
    # Sources of the back-edges that lead to `header`.
    tails: list[int]
    # Body nodes that have at least one successor outside `body`.
    exits: set[int]


def _build_cfg(
    machine: _StateMachine,
    initial_state: _StateEnv,
    var_names: list[str],
    end_state: int | float,
    match: _GeneratorCFFMatch,
) -> tuple[_CFG, _StateEnv] | None:
    """
    Build a control flow graph by BFS from the initial state. Nodes are keyed by the identity
    of the `_SMBlock` object they correspond to, so the same block reached with different
    discriminants (as happens in loops with relative `+=` transitions) creates a single node
    with a back-edge. Returns the CFG and the accumulated state (including scope routing values).

    A block is processed only the first time it is dequeued, using the state it was reached
    with. Returns None when the entry block or any transition target cannot be resolved, or
    when a transition cannot be applied. At most `_MAX_STEPS` queue entries are processed.
    As a side effect, `match.arg_params` is filled in from the first statement that yields
    parameter names if it was empty.
    """
    entry_state = dict(initial_state)
    entry_disc = _compute_discriminant(entry_state, var_names)
    entry_block = _lookup_block(machine, entry_disc, entry_state)
    if entry_block is None:
        return None

    nodes: dict[int, _CFGNode] = {}
    # Separate environment that accumulates scope routing values across all visited blocks.
    routing_state: _StateEnv = dict(initial_state)
    # Queue entries: (block, state on arrival, redirect target inherited from the predecessor).
    queue: deque[tuple[_SMBlock, _StateEnv, str | None]] = deque()
    queue.append((entry_block, entry_state, None))
    steps = 0

    while queue and steps < _MAX_STEPS:
        steps += 1
        block, state, redirect_target = queue.popleft()
        node_id = id(block)

        # The same block may have been queued by several predecessors; first arrival wins.
        if node_id in nodes:
            continue

        payload = _substitute_state_vars(block.payload, state)
        _track_scope_routing(payload, state)
        _track_scope_routing(payload, routing_state)
        new_redirect = _extract_redirect_target(
            payload, match.scope_param_name, match.with_redirect_var,
        )
        if not match.arg_params:
            for s in payload:
                params = _extract_arg_param_names(s, match.arg_var_name)
                if params is not None:
                    match.arg_params = params
                    break
        payload = [
            _strip_scope_bookkeeping_from_sequence(s, match.scope_param_name)
            for s in payload
            if not _is_scope_bookkeeping(s, match.scope_param_name, match.arg_var_name)
        ]
        strip_ns = match.scope_default_props[0] if redirect_target and match.scope_default_props else None
        payload = _strip_scope_param_prefix(payload, match.scope_param_name, strip_ns)
        payload = _qualify_with_identifiers(payload, match, redirect_target)
        payload = _filter_redirect_var_assignments(payload, match)
        # Successors inherit this block's redirect target unless the payload set a new one.
        next_redirect = new_redirect if new_redirect is not None else redirect_target

        condition: Expression | None = None
        successors: list[int] = []
        true_prefix_payload: list[Statement] = []
        false_prefix_payload: list[Statement] = []
        transition = block.transition

        if isinstance(transition, _SMExitTransition):
            successors = [_VIRTUAL_EXIT]
        elif isinstance(transition, _SMLinearTransition):
            new_env = _apply_raw_transition(transition.assignments, state)
            if new_env is None:
                return None
            next_disc = _compute_discriminant(new_env, var_names)
            if next_disc == end_state:
                successors = [_VIRTUAL_EXIT]
            else:
                next_block = _lookup_block(machine, next_disc, new_env)
                if next_block is None:
                    return None
                next_id = id(next_block)
                successors = [next_id]
                if next_id not in nodes:
                    queue.append((next_block, new_env, next_redirect))
        elif isinstance(transition, _SMConditionalTransition):
            condition = transition.condition
            # Each branch prefix may itself change state variables before the transition.
            true_base = _apply_prefix_state_changes(transition.true_prefix, var_names, state)
            false_base = _apply_prefix_state_changes(transition.false_prefix, var_names, state)
            true_env = _apply_raw_transition(transition.true_assignments, true_base)
            false_env = _apply_raw_transition(transition.false_assignments, false_base)
            if true_env is None or false_env is None:
                return None
            true_disc = _compute_discriminant(true_env, var_names)
            false_disc = _compute_discriminant(false_env, var_names)

            true_redirect = (
                _extract_redirect_target(
                    transition.true_prefix, match.scope_param_name, match.with_redirect_var,
                ) or next_redirect
            )
            false_redirect = (
                _extract_redirect_target(
                    transition.false_prefix, match.scope_param_name, match.with_redirect_var,
                ) or next_redirect
            )

            if true_disc == end_state:
                true_id = _VIRTUAL_EXIT
            else:
                true_block = _lookup_block(machine, true_disc, true_env)
                if true_block is None:
                    return None
                true_id = id(true_block)
                if true_id not in nodes:
                    queue.append((true_block, true_env, true_redirect))

            if false_disc == end_state:
                false_id = _VIRTUAL_EXIT
            else:
                false_block = _lookup_block(machine, false_disc, false_env)
                if false_block is None:
                    return None
                false_id = id(false_block)
                if false_id not in nodes:
                    queue.append((false_block, false_env, false_redirect))

            # Order matters: consumers unpack successors as (true target, false target).
            successors = [true_id, false_id]
            true_prefix_payload = _process_branch_prefix(
                transition.true_prefix, var_names, state, match, strip_ns, redirect_target,
            )
            false_prefix_payload = _process_branch_prefix(
                transition.false_prefix, var_names, state, match, strip_ns, redirect_target,
            )
            if match.with_redirect_var and match.scope_default_props:
                condition = _qualify_condition(condition, state, match, redirect_target)
            else:
                # Wrap a clone in a statement so the in-place helpers can replace the root node.
                wrapper = JsExpressionStatement(expression=_clone_node(condition))
                _substitute_in_scope(wrapper, state)
                if match.scope_param_name:
                    _strip_scope_prefix_walk(wrapper, match.scope_param_name, strip_ns)
                condition = wrapper.expression  # type: ignore[assignment]

        nodes[node_id] = _CFGNode(
            node_id=node_id,
            payload=payload,
            condition=condition,
            successors=successors,
            true_prefix_payload=true_prefix_payload,
            false_prefix_payload=false_prefix_payload,
        )

    entry_id = id(entry_block)
    if entry_id not in nodes:
        return None

    exit_node = _CFGNode(node_id=_VIRTUAL_EXIT, payload=[], condition=None)
    nodes[_VIRTUAL_EXIT] = exit_node

    # Predecessor lists are derived once all nodes exist; edges to unprocessed blocks are skipped.
    for n in nodes.values():
        for succ_id in n.successors:
            if succ_id in nodes:
                nodes[succ_id].predecessors.append(n.node_id)

    return (_CFG(nodes=nodes, entry=entry_id, exit=_VIRTUAL_EXIT), routing_state)


def _compute_idom(cfg: _CFG) -> dict[int, int | None]:
    """
    Compute the immediate dominator of every node reachable from the entry, using the iterative
    Cooper-Harvey-Kennedy scheme over a reverse postorder. The entry maps to None.
    """
    rpo = _reverse_postorder(cfg)
    rank = {node_id: position for position, node_id in enumerate(rpo)}
    root = cfg.entry
    idom: dict[int, int | None] = {root: None}

    def intersect(left: int, right: int) -> int:
        # Walk whichever finger sits later in reverse postorder up the dominator tree until
        # both fingers meet.
        while rank[left] != rank[right]:
            while rank[left] > rank[right]:
                left = idom[left]  # type: ignore
            while rank[right] > rank[left]:
                right = idom[right]  # type: ignore
        return left

    while True:
        updated = False
        for current in rpo:
            if current == root:
                continue
            # Only predecessors that already have a dominator can take part.
            known = [p for p in cfg.nodes[current].predecessors if p in idom]
            if not known:
                continue
            candidate = known[0]
            for other in known[1:]:
                candidate = intersect(candidate, other)
            if idom.get(current) != candidate:
                idom[current] = candidate
                updated = True
        if not updated:
            return idom


def _reverse_postorder(cfg: _CFG) -> list[int]:
    """
    Compute reverse postorder traversal of the CFG from entry.
    """
    visited: set[int] = set()
    order: list[int] = []

    def dfs(disc: int):
        stack: list[tuple[int, int]] = [(disc, 0)]
        while stack:
            current, idx = stack.pop()
            if idx == 0:
                if current in visited:
                    continue
                visited.add(current)
            node = cfg.nodes.get(current)
            if node is None:
                order.append(current)
                continue
            succs = [s for s in node.successors if s in cfg.nodes]
            if idx < len(succs):
                stack.append((current, idx + 1))
                s = succs[idx]
                if s not in visited:
                    stack.append((s, 0))
            else:
                order.append(current)

    dfs(cfg.entry)
    order.reverse()
    return order


def _dominates(idom: dict[int, int | None], a: int, b: int) -> bool:
    """
    Check if node `a` dominates node `b`.
    """
    current = b
    while current is not None:
        if current == a:
            return True
        current = idom.get(current)
    return False


def _compute_ipdom(
    cfg: _CFG,
    exit_id: int,
    region: set[int] | None = None,
) -> dict[int, int | None]:
    """
    Compute immediate post-dominators using Cooper-Harvey-Kennedy on the reverse CFG.
    Post-dominator of X = first node Y that ALL paths from X to exit must pass through.
    """
    exit_preds: list[int] = []
    if exit_id not in cfg.nodes:
        for nid, node in cfg.nodes.items():
            if region is not None and nid not in region:
                continue
            if exit_id in node.successors:
                exit_preds.append(nid)

    visited: set[int] = set()
    rpo: list[int] = []

    def _get_reverse_succs(nid: int) -> list[int]:
        node = cfg.nodes.get(nid)
        if node is None:
            if nid == exit_id:
                return exit_preds
            return []
        preds = node.predecessors
        if region is not None:
            preds = [p for p in preds if p in region]
        return preds

    stack: list[tuple[int, int]] = [(exit_id, 0)]
    while stack:
        current, idx = stack.pop()
        if idx == 0:
            if current in visited:
                continue
            visited.add(current)
        preds = _get_reverse_succs(current)
        if idx < len(preds):
            stack.append((current, idx + 1))
            p = preds[idx]
            if p not in visited:
                stack.append((p, 0))
        else:
            rpo.append(current)

    rpo.reverse()
    node_to_idx = {d: i for i, d in enumerate(rpo)}
    ipdom: dict[int, int | None] = {exit_id: None}

    def intersect(a: int, b: int) -> int:
        ai: int = node_to_idx[a]
        bi: int = node_to_idx[b]
        while ai != bi:
            while ai > bi:
                a = ipdom[a]  # type: ignore
                ai = node_to_idx[a]
            while bi > ai:
                b = ipdom[b]  # type: ignore
                bi = node_to_idx[b]
        return a

    changed = True
    while changed:
        changed = False
        for disc in rpo:
            if disc == exit_id:
                continue
            node = cfg.nodes.get(disc)
            if node is None:
                continue
            succs = [s for s in node.successors if s in ipdom]
            if region is not None:
                succs = [s for s in succs if s in region or s == exit_id]
            if not succs:
                continue
            new_ipdom = succs[0]
            for s in succs[1:]:
                new_ipdom = intersect(new_ipdom, s)
            if ipdom.get(disc) != new_ipdom:
                ipdom[disc] = new_ipdom
                changed = True

    return ipdom


def _find_loops(cfg: _CFG, idom: dict[int, int | None]) -> list[_NaturalLoop]:
    """
    Identify natural loops from back-edges. A back-edge is (tail -> header) where header
    dominates tail. The loop body is the set of nodes that can reach the tail without leaving
    the header's dominance.

    Back-edges that share a header are merged into one loop whose body is the union of the
    individual bodies. The exit set is computed once per loop from the final, merged body, so
    a node whose outside successor was absorbed by a later back-edge is not reported as an exit.
    Loops are returned in the order in which their headers were first encountered.
    """
    back_edges: list[tuple[int, int]] = [
        (tail, header)
        for tail, node in cfg.nodes.items()
        for header in node.successors
        if header in cfg.nodes and _dominates(idom, header, tail)
    ]

    loops_by_header: dict[int, _NaturalLoop] = {}
    for tail, header in back_edges:
        body = _compute_loop_body(cfg, header, tail)
        loop = loops_by_header.get(header)
        if loop is None:
            loops_by_header[header] = _NaturalLoop(
                header=header, body=body, tails=[tail], exits=set(),
            )
        else:
            loop.tails.append(tail)
            loop.body |= body

    # Exits depend on the complete body, so derive them only after all back-edges are merged.
    for loop in loops_by_header.values():
        loop.exits = {
            member for member in loop.body
            if any(
                succ not in loop.body and succ in cfg.nodes
                for succ in cfg.nodes[member].successors
            )
        }

    return list(loops_by_header.values())


def _compute_loop_body(cfg: _CFG, header: int, tail: int) -> set[int]:
    """
    Compute the natural loop body: all nodes that can reach `tail` without going through
    `header`, plus `header` itself.
    """
    body: set[int] = {header}
    if tail == header:
        return body
    body.add(tail)
    worklist: list[int] = [tail]
    while worklist:
        node_disc = worklist.pop()
        n = cfg.nodes.get(node_disc)
        if n is None:
            continue
        for pred in n.predecessors:
            if pred not in body and pred in cfg.nodes:
                body.add(pred)
                worklist.append(pred)
    return body


def _structural_analysis(
    cfg: _CFG,
    idom: dict[int, int | None],
    loops: list[_NaturalLoop],
) -> list[Statement]:
    """
    Recover structured control flow from the CFG using region-based structural analysis.
    Process loops innermost-first, then structure acyclic regions.

    After a loop has been structured, its statements are stored under the loop header in
    `collapsed` and every other body node is collapsed to an empty list. This includes headers
    of nested loops: their statements are already embedded in the enclosing loop, and keeping
    them would make the final top-level pass emit the inner loop a second time.
    """
    sorted_loops = _sort_loops_innermost_first(loops)
    collapsed: dict[int, list[Statement]] = {}
    loop_headers: set[int] = set()

    for loop in sorted_loops:
        loop_headers.add(loop.header)
        # Structuring reads `collapsed` for inner loops, so store the result only afterwards.
        body_stmts = _structure_loop(cfg, loop, idom, collapsed)
        collapsed[loop.header] = body_stmts
        for body_node in loop.body:
            if body_node != loop.header:
                collapsed[body_node] = []

    return _structure_acyclic_region(cfg, cfg.entry, cfg.exit, idom, collapsed, loop_headers)


def _sort_loops_innermost_first(loops: list[_NaturalLoop]) -> list[_NaturalLoop]:
    """
    Sort loops so that inner (smaller body) loops are processed before outer ones.
    """
    return sorted(loops, key=lambda lp: len(lp.body))


def _structure_loop(
    cfg: _CFG,
    loop: _NaturalLoop,
    idom: dict[int, int | None],
    collapsed: dict[int, list[Statement]],
) -> list[Statement]:
    """
    Turn one natural loop into statements.

    When the header is a pure two-way test (a condition, two successors, no payload) and one
    of its successors leaves the loop, the result is a `while` loop on that test (negated if
    the true edge is the one that leaves), followed by the prefix statements of the leaving
    branch. Every other shape is delegated to `_structure_loop_infinite`.
    """
    header = loop.header
    head = cfg.nodes[header]

    is_plain_test = (
        head.condition is not None
        and len(head.successors) == 2
        and not head.payload
    )
    if not is_plain_test:
        return _structure_loop_infinite(cfg, loop, idom, collapsed)

    on_true, on_false = head.successors
    if on_true not in loop.body and on_true in cfg.nodes:
        # The true edge leaves the loop, so iteration continues while the test is false.
        body_entry = on_false
        test = JsUnaryExpression(operator='!', operand=head.condition, prefix=True)
        enter_stmts = head.false_prefix_payload
        leave_stmts = head.true_prefix_payload
    elif on_false not in loop.body and on_false in cfg.nodes:
        body_entry = on_true
        test = head.condition
        enter_stmts = head.true_prefix_payload
        leave_stmts = head.false_prefix_payload
    else:
        return _structure_loop_infinite(cfg, loop, idom, collapsed)

    inner = _structure_acyclic_region(
        cfg, body_entry, header, idom, collapsed, set(),
        loop_body=loop.body,
    )
    # The header payload is empty on this path, so the body is branch prefix plus region.
    loop_stmt = JsWhileStatement(
        test=test,
        body=JsBlockStatement(body=list(enter_stmts) + inner),
    )
    return [loop_stmt, *leave_stmts]


def _structure_loop_infinite(
    cfg: _CFG,
    loop: _NaturalLoop,
    idom: dict[int, int | None],
    collapsed: dict[int, list[Statement]],
) -> list[Statement]:
    """
    Emit the loop as a single `while (true)` statement. The body is produced by
    `_structure_region_nodes`, starting at the loop header and restricted to the loop body.
    """
    inner = _structure_region_nodes(cfg, loop.header, idom, collapsed, loop.body)
    always = JsBooleanLiteral(value=True)
    return [JsWhileStatement(test=always, body=JsBlockStatement(body=inner))]


def _structure_acyclic_region(
    cfg: _CFG,
    entry: int,
    exit_disc: int,
    idom: dict[int, int | None],
    collapsed: dict[int, list[Statement]],
    loop_headers: set[int],
    loop_body: set[int] | None = None,
    _visited: set[int] | None = None,
) -> list[Statement]:
    """
    Structure an acyclic region from `entry` to `exit_disc` into a statement sequence.
    Handles if/else patterns using post-dominator-based join detection.

    Nodes are taken from a worklist seeded with `entry`. `exit_disc` and the virtual exit are
    never emitted. A node found in `collapsed` contributes its collapsed statements instead of
    its own payload. When `loop_body` is given, nodes outside it are not visited, and a linear
    node with a successor outside it emits a `break`. `idom` and `loop_headers` are only passed
    on to recursive calls. `_visited` lets recursive calls start from the caller's visited
    nodes; the set passed in is mutated.
    """
    result: list[Statement] = []
    visited: set[int] = _visited if _visited is not None else set()
    worklist: deque[int] = deque([entry])

    while worklist:
        disc = worklist.popleft()
        if disc == exit_disc or disc == _VIRTUAL_EXIT:
            continue
        if disc in visited:
            continue
        if loop_body is not None and disc not in loop_body:
            continue
        visited.add(disc)

        # Already structured elsewhere: emit the stored statements and keep walking.
        if disc in collapsed:
            result.extend(collapsed[disc])
            node = cfg.nodes[disc]
            for s in node.successors:
                if s not in visited and s != exit_disc and s != _VIRTUAL_EXIT:
                    if loop_body is None or s in loop_body:
                        worklist.append(s)
            continue

        node = cfg.nodes.get(disc)
        if node is None:
            continue

        if node.condition is not None and len(node.successors) == 2:
            result.extend(node.payload)
            true_succ, false_succ = node.successors
            join = _find_acyclic_join(cfg, disc, exit_disc, loop_body)

            # Each branch is structured up to the join with its own copy of the visited set.
            true_stmts: list[Statement] = list(node.true_prefix_payload)
            true_visited = set(visited)
            if true_succ != join and true_succ not in visited:
                true_stmts.extend(_structure_acyclic_region(
                    cfg, true_succ, join, idom, collapsed, loop_headers, loop_body, true_visited,
                ))
            false_stmts: list[Statement] = list(node.false_prefix_payload)
            false_visited = set(visited)
            if false_succ != join and false_succ not in visited:
                false_stmts.extend(_structure_acyclic_region(
                    cfg, false_succ, join, idom, collapsed, loop_headers, loop_body, false_visited,
                ))
            # Nodes consumed by either branch must not be emitted again after the if.
            visited.update(true_visited)
            visited.update(false_visited)

            if_stmt = _build_js_if(node.condition, true_stmts, false_stmts)
            if if_stmt is not None:
                result.append(if_stmt)

            # Continue with the join before anything else that is still queued.
            if join != _VIRTUAL_EXIT and join != exit_disc and join not in visited:
                worklist.appendleft(join)
        else:
            result.extend(node.payload)
            for s in node.successors:
                if s == exit_disc or s == _VIRTUAL_EXIT:
                    continue
                if s in visited:
                    continue
                if loop_body is not None and s not in loop_body:
                    result.append(JsBreakStatement())
                    continue
                worklist.append(s)

    return result


def _find_acyclic_join(
    cfg: _CFG,
    cond_disc: int,
    region_exit: int,
    loop_body: set[int] | None,
) -> int:
    """
    Determine where the two branches of the conditional at `cond_disc` meet again.

    The nodes reachable from the conditional are collected, stopping at `region_exit`, at the
    virtual exit and (when `loop_body` is given) at nodes outside the loop. The join is the
    immediate post-dominator of the conditional within that region, relative to `region_exit`.
    `region_exit` is returned when no post-dominator is found, when the conditional is not part
    of the region, or when the post-dominator lies outside `loop_body`.
    """
    reachable: set[int] = set()
    pending: deque[int] = deque([cond_disc])
    while pending:
        current = pending.popleft()
        if current == _VIRTUAL_EXIT or current in reachable:
            continue
        if current == region_exit:
            # The region exit belongs to the region but is not expanded further.
            reachable.add(current)
            continue
        if loop_body is not None and current not in loop_body:
            continue
        reachable.add(current)
        node = cfg.nodes.get(current)
        if node is None:
            continue
        for succ in node.successors:
            if succ not in reachable:
                pending.append(succ)

    if cond_disc not in reachable:
        return region_exit

    reachable.add(region_exit)
    join = _compute_ipdom(cfg, region_exit, reachable).get(cond_disc)
    if join is None:
        return region_exit
    if loop_body is not None and join not in loop_body and join != region_exit:
        return region_exit
    return join


def _structure_region_nodes(
    cfg: _CFG,
    header: int,
    idom: dict[int, int | None],
    collapsed: dict[int, list[Statement]],
    loop_body: set[int],
) -> list[Statement]:
    """
    Structure a set of CFG nodes that form a loop body, starting from the header.

    The returned statements are used as the body of a `while (true)` loop. An edge back to
    `header` becomes a `continue` (or is left implicit), an edge that leaves `loop_body`
    becomes a `break`. A node found in `collapsed` contributes its collapsed statements instead
    of its own payload. `idom` is only passed on to `_structure_acyclic_region`.
    """
    result: list[Statement] = []
    visited: set[int] = set()
    worklist: deque[int] = deque([header])

    while worklist:
        disc = worklist.popleft()
        if disc in visited:
            continue
        if disc not in loop_body:
            result.append(JsBreakStatement())
            continue
        visited.add(disc)

        # Already structured elsewhere: emit the stored statements and keep walking.
        if disc in collapsed:
            result.extend(collapsed[disc])
            node = cfg.nodes[disc]
            for s in node.successors:
                if s not in visited and s in loop_body:
                    worklist.append(s)
            continue

        node = cfg.nodes.get(disc)
        if node is None:
            continue

        if node.condition is not None and len(node.successors) == 2:
            result.extend(node.payload)
            true_succ, false_succ = node.successors

            true_in_loop = true_succ in loop_body
            false_in_loop = false_succ in loop_body

            # One branch jumps straight back to the header.
            if true_succ == header:
                if false_succ not in loop_body:
                    # The other branch leaves the loop: break on the negated condition.
                    neg = JsUnaryExpression(operator='!', operand=node.condition, prefix=True)
                    break_body = list(node.false_prefix_payload) + [JsBreakStatement()]
                    result.append(JsIfStatement(
                        test=neg,
                        consequent=JsBlockStatement(body=break_body),
                    ))
                    result.extend(node.true_prefix_payload)
                else:
                    # The other branch stays in the loop: continue explicitly, then go on.
                    continue_body = list(node.true_prefix_payload) + [JsContinueStatement()]
                    result.append(JsIfStatement(
                        test=node.condition,
                        consequent=JsBlockStatement(body=continue_body),
                    ))
                    result.extend(node.false_prefix_payload)
                    worklist.append(false_succ)
                continue
            elif false_succ == header:
                if true_succ not in loop_body:
                    break_body = list(node.true_prefix_payload) + [JsBreakStatement()]
                    result.append(JsIfStatement(
                        test=node.condition,
                        consequent=JsBlockStatement(body=break_body),
                    ))
                    result.extend(node.false_prefix_payload)
                else:
                    neg = JsUnaryExpression(operator='!', operand=node.condition, prefix=True)
                    continue_body = list(node.false_prefix_payload) + [JsContinueStatement()]
                    result.append(JsIfStatement(
                        test=neg,
                        consequent=JsBlockStatement(body=continue_body),
                    ))
                    result.extend(node.true_prefix_payload)
                    worklist.append(true_succ)
                continue

            # Both branches leave the loop.
            if not true_in_loop and not false_in_loop:
                if node.true_prefix_payload or node.false_prefix_payload:
                    true_body = list(node.true_prefix_payload) + [JsBreakStatement()]
                    false_body = list(node.false_prefix_payload) + [JsBreakStatement()]
                    if_stmt = _build_js_if(node.condition, true_body, false_body)
                    if if_stmt is not None:
                        result.append(if_stmt)
                    else:
                        result.append(JsBreakStatement())
                else:
                    result.append(JsBreakStatement())
                continue
            # Exactly one branch leaves the loop: guard a break, continue with the other.
            if not true_in_loop:
                break_body = list(node.true_prefix_payload) + [JsBreakStatement()]
                result.append(JsIfStatement(
                    test=node.condition,
                    consequent=JsBlockStatement(body=break_body),
                ))
                result.extend(node.false_prefix_payload)
                worklist.append(false_succ)
                continue
            if not false_in_loop:
                neg = JsUnaryExpression(operator='!', operand=node.condition, prefix=True)
                break_body = list(node.false_prefix_payload) + [JsBreakStatement()]
                result.append(JsIfStatement(
                    test=neg,
                    consequent=JsBlockStatement(body=break_body),
                ))
                result.extend(node.true_prefix_payload)
                worklist.append(true_succ)
                continue

            # Both branches stay inside the loop: structure an if/else up to their join.
            join = _find_acyclic_join(cfg, disc, header, loop_body)
            true_stmts: list[Statement] = list(node.true_prefix_payload)
            true_visited = set(visited)
            if true_succ != join and true_succ not in visited:
                true_stmts.extend(_structure_acyclic_region(
                    cfg, true_succ, join, idom, collapsed, set(), loop_body, true_visited,
                ))
            false_stmts: list[Statement] = list(node.false_prefix_payload)
            false_visited = set(visited)
            if false_succ != join and false_succ not in visited:
                false_stmts.extend(_structure_acyclic_region(
                    cfg, false_succ, join, idom, collapsed, set(), loop_body, false_visited,
                ))
            # Nodes consumed by either branch must not be emitted again after the if.
            visited.update(true_visited)
            visited.update(false_visited)
            if_stmt = _build_js_if(node.condition, true_stmts, false_stmts)
            if if_stmt is not None:
                result.append(if_stmt)
            if join != header and join in loop_body and join not in visited:
                worklist.appendleft(join)
        else:
            result.extend(node.payload)
            for s in node.successors:
                # Falling back to the header needs no statement: the loop body simply ends.
                if s == header:
                    continue
                if s not in loop_body:
                    result.append(JsBreakStatement())
                    continue
                if s in visited:
                    continue
                worklist.append(s)

    return result


def _substitute_state_vars(stmts: list[Statement], env: _StateEnv) -> list[Statement]:
    """
    Return fresh copies of `stmts` in which every state variable identifier found in `env` has
    been replaced by its numeric literal. The input statements are left untouched; nested
    function bodies are not rewritten because their names may shadow the state variables.
    """
    copies: list[Statement] = [_clone_node(original) for original in stmts]
    for copy in copies:
        # Each clone is an independent tree, so substitution can run after all clones exist.
        _substitute_in_scope(copy, env)
    return copies


def _substitute_in_scope(node: Node, env: _StateEnv) -> None:
    """
    Walk the children of `node` and swap each identifier whose name is a key of `env` for a
    numeric literal carrying the mapped value. Function expressions and declarations are not
    descended into.
    """
    for child in node.children():
        if isinstance(child, (JsFunctionExpression, JsFunctionDeclaration)):
            # Nested function: its own scope may reuse the state variable names.
            continue
        is_state_reference = isinstance(child, JsIdentifier) and child.name in env
        if not is_state_reference:
            _substitute_in_scope(child, env)
            continue
        literal = make_numeric_literal(env[child.name])
        _replace_in_parent(child, literal)


def _strip_scope_param_prefix(
    stmts: list[Statement],
    scope_param_name: str | None,
    namespace: str | None = None,
) -> list[Statement]:
    """
    Rewrite member chains rooted at the scope parameter, in place, and return the same list.

    With a `namespace`, `scope.X` is rewritten to `namespace.X` (and `scope.namespace` to the
    bare `namespace`), which keeps the root-level qualification intact so that later
    bare-identifier qualification only affects names that resolved through the `with`
    statement. Without a `namespace`, `scope.X` collapses to the bare identifier `X`.
    When there is no scope parameter at all, the statements are returned unmodified.
    """
    if scope_param_name is not None:
        for statement in stmts:
            _strip_scope_prefix_walk(statement, scope_param_name, namespace)
    return stmts


def _strip_scope_prefix_walk(node: Node, scope_param_name: str, namespace: str | None = None) -> None:
    """
    Recursively replace `scope.X` / `scope["X"]` member expressions below `node`.

    A member expression qualifies when its object is the identifier `scope_param_name` and its
    property is either a non-computed identifier or a computed string literal. Qualifying nodes
    become `namespace.X` when a namespace is given and `X` differs from it, otherwise the bare
    identifier `X`. Everything else is descended into unchanged.
    """
    for child in node.children():
        slot_name: str | None = None
        resolved = False
        if (
            isinstance(child, JsMemberExpression)
            and isinstance(child.object, JsIdentifier)
            and child.object.name == scope_param_name
        ):
            key = child.property
            if child.computed:
                # Only constant string keys can be turned into a static name.
                if isinstance(key, JsStringLiteral):
                    slot_name, resolved = key.value, True
            elif isinstance(key, JsIdentifier):
                slot_name, resolved = key.name, True
        if not resolved:
            _strip_scope_prefix_walk(child, scope_param_name, namespace)
            continue
        if namespace is None or slot_name == namespace:
            replacement = JsIdentifier(name=slot_name)
        else:
            replacement = JsMemberExpression(
                object=JsIdentifier(name=namespace),
                property=JsIdentifier(name=slot_name),
                computed=False,
            )
        _replace_in_parent(child, replacement)


def _qualify_condition(
    condition: Expression,
    state: _StateEnv,
    match: _GeneratorCFFMatch,
    redirect_target: str | None,
) -> Expression:
    """
    Run a transition condition through the same clone / substitute / strip / qualify steps that
    are applied to block payloads and return the resulting expression. The clone is placed in a
    throwaway expression statement so that the walkers, which only rewrite children, can also
    replace the root of the condition.
    """
    holder = JsExpressionStatement(expression=_clone_node(condition))
    _substitute_in_scope(holder, state)

    scope_name = match.scope_param_name
    if scope_name:
        # Keep the namespace qualification only while a redirect is active.
        if redirect_target and match.scope_default_props:
            strip_ns = match.scope_default_props[0]
        else:
            strip_ns = None
        _strip_scope_prefix_walk(holder, scope_name, strip_ns)

    if not (match.with_redirect_var and len(match.scope_default_props) == 1):
        return holder.expression  # type: ignore[return-value]

    namespace = match.scope_default_props[0]
    redirected = bool(redirect_target) and redirect_target != namespace

    # Names that must never be prefixed with the namespace path.
    exempt: set[str] = set(match.state_var_names) | _JS_BUILTIN_GLOBALS
    exempt.add(namespace)
    exempt.add(match.generator_name)
    for optional_name in (scope_name, match.arg_var_name, match.did_return_var):
        if optional_name:
            exempt.add(optional_name)

    ns_path: list[str] = [namespace]
    if redirected:
        exempt.add(redirect_target)
        ns_path.append(redirect_target)

    _qualify_bare_walk(holder, ns_path, exempt)
    return holder.expression  # type: ignore[return-value]


def _extract_redirect_target(
    payload: list[Statement],
    scope_param_name: str | None,
    redirect_var: str | None,
) -> str | None:
    """
    Look through the (not yet stripped) payload for plain assignments of the shape

        scope.redirect_var = scope.TARGET

    in dotted or computed-string-literal form, either as standalone expression statements or
    as members of a sequence expression. The TARGET of the last such assignment wins, since a
    later assignment overwrites an earlier one. Returns `None` when no assignment matches or
    when either the scope parameter or the redirect variable is unknown.
    """
    if scope_param_name is None or redirect_var is None:
        return None

    def scope_slot(expr: Expression | None) -> str | None:
        # Name X of a direct `scope.X` / `scope["X"]` access, or None for anything else.
        if not isinstance(expr, JsMemberExpression):
            return None
        owner = expr.object
        if not isinstance(owner, JsIdentifier) or owner.name != scope_param_name:
            return None
        key = expr.property
        if expr.computed:
            return key.value if isinstance(key, JsStringLiteral) else None
        return key.name if isinstance(key, JsIdentifier) else None

    target: str | None = None
    for stmt in payload:
        if not isinstance(stmt, JsExpressionStatement):
            continue
        root = stmt.expression
        candidates = root.expressions if isinstance(root, JsSequenceExpression) else [root]
        for candidate in candidates:
            if not isinstance(candidate, JsAssignmentExpression) or candidate.operator != '=':
                continue
            if scope_slot(candidate.left) != redirect_var:
                continue
            source = scope_slot(candidate.right)
            if source is not None:
                target = source
    return target


# Identifier names that are never prefixed with the with-scope namespace during qualification.
_JS_BUILTIN_GLOBALS: frozenset[str] = frozenset({
    # Global object aliases and value properties.
    'globalThis', 'global', 'self', 'window',
    'undefined', 'NaN', 'Infinity',
    # Global functions.
    'eval', 'isNaN', 'isFinite', 'parseInt', 'parseFloat',
    'decodeURI', 'decodeURIComponent', 'encodeURI', 'encodeURIComponent',
    # Fundamental objects, numbers, dates, and text.
    'Object', 'Function', 'Boolean', 'Symbol', 'Number', 'BigInt',
    'Math', 'Date', 'String', 'RegExp',
    # Collections and binary data.
    'Array', 'Map', 'Set', 'WeakMap', 'WeakSet',
    'ArrayBuffer', 'SharedArrayBuffer', 'DataView',
    # Structured data, control abstraction, and reflection.
    'JSON', 'Promise', 'Reflect', 'Proxy',
    # Error constructors.
    'Error', 'TypeError', 'RangeError', 'ReferenceError',
    'SyntaxError', 'URIError', 'EvalError',
    # Host-provided names: console, timers, module system, process, URLs.
    'console', 'setTimeout', 'setInterval', 'clearTimeout', 'clearInterval',
    'require', 'module', 'exports', 'process', 'Buffer',
    'URL', 'URLSearchParams',
    # Remaining standard namespaces.
    'Intl', 'Atomics', 'WebAssembly',
})


def _qualify_with_identifiers(
    stmts: list[Statement],
    match: _GeneratorCFFMatch,
    redirect_target: str | None = None,
) -> list[Statement]:
    """
    Prefix bare identifiers that resolved through a with-scope redirect with their namespace.
    Nothing happens unless the match carries a with-redirect variable and exactly one default
    scope property (the namespace). With an active `redirect_target` that differs from the
    namespace, identifiers are qualified as

        NS.redirect_target.X

    and otherwise as `NS.X`. Non-exempt function declarations are afterwards turned into
    namespace property assignments. The statements are modified in place and the same list is
    returned.
    """
    if not (match.with_redirect_var and len(match.scope_default_props) == 1):
        return stmts

    namespace = match.scope_default_props[0]
    redirected = bool(redirect_target) and redirect_target != namespace

    # Dispatcher plumbing and well-known globals are never namespace members.
    exempt: set[str] = set(match.state_var_names) | _JS_BUILTIN_GLOBALS
    exempt.add(namespace)
    exempt.add(match.generator_name)
    for optional_name in (match.scope_param_name, match.arg_var_name, match.did_return_var):
        if optional_name:
            exempt.add(optional_name)

    ns_path: list[str] = [namespace]
    if redirected:
        exempt.add(redirect_target)
        ns_path.append(redirect_target)

    for statement in stmts:
        _qualify_bare_walk(statement, ns_path, exempt)
    _convert_function_declarations(stmts, ns_path, exempt)
    return stmts


def _convert_function_declarations(
    stmts: list[Statement],
    ns_path: list[str],
    exempt: set[str],
    owner: Node | None = None,
) -> None:
    """
    Replace, in place, every function declaration in `stmts` whose name is not exempt by an
    assignment to the namespace, so that `function foo(...)` turns into
    `NS.foo = function(...)` and all references to it go through the namespace. When `owner` is
    given, it becomes the parent of each replacement statement. Block statements that are direct
    children of the other statements are processed recursively; function bodies, and block
    statements that appear directly in `stmts`, are left alone.
    """
    for position, current in enumerate(stmts):
        if not isinstance(current, JsFunctionDeclaration):
            if isinstance(current, (JsFunctionExpression, JsBlockStatement)):
                continue
            for nested in current.children():
                if isinstance(nested, JsBlockStatement):
                    _convert_function_declarations(nested.body, ns_path, exempt, owner=nested)
            continue
        ident = current.id
        if ident is None or ident.name in exempt:
            continue
        anonymous = JsFunctionExpression(
            id=None,
            params=current.params,
            body=current.body,
        )
        slot = JsMemberExpression(
            object=_make_namespace_node(ns_path),
            property=JsIdentifier(name=ident.name),
            computed=False,
        )
        replacement = JsExpressionStatement(
            expression=JsAssignmentExpression(operator='=', left=slot, right=anonymous),
        )
        stmts[position] = replacement
        if owner is not None:
            replacement.parent = owner


def _make_namespace_node(ns_path: list[str]) -> Expression:
    """
    Turn a namespace path into an expression: `['A']` yields the identifier `A`, while
    `['A', 'B', 'C']` yields the left-nested, non-computed member chain `A.B.C`.
    """
    root, rest = ns_path[0], ns_path[1:]
    chain: Expression = JsIdentifier(name=root)
    for part in rest:
        accessor = JsIdentifier(name=part)
        chain = JsMemberExpression(object=chain, property=accessor, computed=False)
    return chain


def _qualify_bare_walk(node: Node, ns_path: list[str], exempt: set[str]) -> None:
    """
    Recursively replace bare identifiers below `node` with `<ns_path>.<name>` member expressions.

    Identifiers are left alone when their name is in `exempt`, or when they are not variable
    references: non-computed member property names, non-computed object property keys, and
    statement labels. Names introduced by a variable declarator, a rest element, or a catch
    parameter are added to `exempt` (the caller's set is mutated), so later uses in the same
    walk stay unqualified. Nested functions are walked with a private copy of `exempt` that is
    extended by the names they declare.

    Only the binding position (`id`) of a variable declarator counts as a declaration. An
    identifier used as the initializer, as in `var x = y`, is an ordinary reference and is
    qualified like any other expression.
    """
    for child in node.children():
        if isinstance(child, (JsFunctionExpression, JsFunctionDeclaration)):
            # `|` builds a new set, so names declared inside do not leak into the outer scope.
            inner_exempt = exempt | _collect_declared_names(child)
            _qualify_bare_walk(child, ns_path, inner_exempt)
            continue
        if isinstance(child, JsIdentifier) and child.name not in exempt:
            parent = child.parent
            if isinstance(parent, JsMemberExpression) and parent.property is child and not parent.computed:
                continue
            if isinstance(parent, JsProperty) and parent.key is child and not parent.computed:
                continue
            if isinstance(parent, JsVariableDeclarator) and parent.id is child:
                # Declared name: becomes a local from here on.
                exempt.add(child.name)
                continue
            if isinstance(parent, JsRestElement):
                exempt.add(child.name)
                continue
            if isinstance(parent, JsCatchClause) and parent.param is child:
                exempt.add(child.name)
                continue
            if isinstance(parent, (JsLabeledStatement, JsContinueStatement, JsBreakStatement)):
                if getattr(parent, 'label', None) is child:
                    continue
            replacement = JsMemberExpression(
                object=_make_namespace_node(ns_path),
                property=JsIdentifier(name=child.name),
                computed=False,
            )
            _replace_in_parent(child, replacement)
            continue
        _qualify_bare_walk(child, ns_path, exempt)


def _collect_declared_names(func: JsFunctionExpression | JsFunctionDeclaration) -> set[str]:
    """
    Collect the names that are local to `func` so they can be exempted from qualification:

    - the function's own name, for declarations and for named function expressions (the name
      of a named function expression is bound inside its own body);
    - all names bound by its parameters, including destructuring patterns;
    - names bound by variable declarations anywhere in the body, and the names of function
      declarations in the body.

    Nested functions are not descended into; only their declaration name is recorded.
    """
    names: set[str] = set()
    if func.id is not None:
        names.add(func.id.name)
    for p in (func.params or []):
        _collect_binding_names(p, names)
    if func.body is not None:
        # Breadth-first walk over the body that stops at nested function boundaries.
        queue: deque[Node] = deque(func.body.body)
        while queue:
            node = queue.popleft()
            if isinstance(node, (JsFunctionExpression, JsFunctionDeclaration)):
                if isinstance(node, JsFunctionDeclaration) and node.id is not None:
                    names.add(node.id.name)
                continue
            if isinstance(node, JsVariableDeclaration):
                for decl in node.declarations:
                    if isinstance(decl, JsVariableDeclarator):
                        _collect_binding_names(decl.id, names)
            queue.extend(node.children())
    return names


def _collect_binding_names(pattern: Expression | None, out: set[str]) -> None:
    """
    Add every identifier name bound by `pattern` to `out`. Handles plain identifiers, rest
    elements, assignment patterns (the default value is ignored), array patterns, and object
    patterns. `None` and any other node type contribute nothing.
    """
    if pattern is None:
        return
    if isinstance(pattern, JsIdentifier):
        out.add(pattern.name)
        return
    if isinstance(pattern, JsRestElement):
        _collect_binding_names(pattern.argument, out)
        return
    if isinstance(pattern, JsAssignmentPattern):
        _collect_binding_names(pattern.left, out)
        return
    if isinstance(pattern, JsArrayPattern):
        for element in pattern.elements:
            _collect_binding_names(element, out)
        return
    if isinstance(pattern, JsObjectPattern):
        for member in pattern.properties:
            if isinstance(member, JsRestElement):
                _collect_binding_names(member.argument, out)
            elif isinstance(member, JsProperty) and member.value is not None:
                # The value side of a property is the binding target.
                _collect_binding_names(member.value, out)


def _is_did_return_assignment(expr: Expression, did_return_var: str | None) -> bool:
    """
    Tell whether `expr` is a plain `=` assignment whose target is the identifier
    `did_return_var`, as in `didReturnVar = true`. The assigned value itself is not inspected.
    Always `False` when no did-return variable is known.
    """
    if did_return_var is None:
        return False
    return (
        isinstance(expr, JsAssignmentExpression)
        and isinstance(expr.left, JsIdentifier)
        and expr.left.name == did_return_var
        and expr.operator == '='
    )


def _recover_returns(stmts: list[Statement], did_return_var: str | None) -> list[Statement]:
    """
    Rebuild return statements from the dispatcher's return encoding. A sequence expression

        (didReturn = true, value)

    used as an expression statement, or as the argument of an existing return statement, is
    replaced by `return value`; when more than one expression follows the marker, the remainder
    is returned as a sequence expression. Block bodies of if statements (including `else if`
    chains) and while statements are processed recursively and updated in place. Without a
    did-return variable, the input list is returned unchanged; otherwise a new list is built.
    """
    if did_return_var is None:
        return stmts

    def as_return(candidate: Expression) -> JsReturnStatement | None:
        # Build the recovered return for an encoded sequence, or None if it is not one.
        if not isinstance(candidate, JsSequenceExpression):
            return None
        parts = candidate.expressions
        if len(parts) < 2 or not _is_did_return_assignment(parts[0], did_return_var):
            return None
        if len(parts) == 2:
            value = parts[1]
        else:
            value = JsSequenceExpression(expressions=parts[1:])
        return JsReturnStatement(argument=value)

    def recover_block(block: Node | None) -> None:
        # Only block statements carry a statement list that can be rewritten.
        if isinstance(block, JsBlockStatement):
            block.body = _recover_returns(block.body, did_return_var)

    recovered: list[Statement] = []
    for stmt in stmts:
        if isinstance(stmt, JsReturnStatement) and stmt.argument is not None:
            converted = as_return(stmt.argument)
            recovered.append(stmt if converted is None else converted)
            continue
        if isinstance(stmt, JsExpressionStatement):
            converted = as_return(stmt.expression)
            if converted is not None:
                recovered.append(converted)
                continue
        if isinstance(stmt, JsIfStatement):
            recover_block(stmt.consequent)
            alternate = stmt.alternate
            if isinstance(alternate, JsBlockStatement):
                recover_block(alternate)
            elif isinstance(alternate, JsIfStatement):
                chained = _recover_returns([alternate], did_return_var)
                if chained:
                    stmt.alternate = chained[0]
        elif isinstance(stmt, JsWhileStatement):
            recover_block(stmt.body)
        recovered.append(stmt)
    return recovered


def _is_scope_bookkeeping(
    stmt: Statement,
    scope_param_name: str | None,
    arg_var_name: str | None = None,
) -> bool:
    """
    Decide whether an expression statement only maintains scope / with-discriminant state and
    can therefore be dropped from the output. This is the case for:

    - an assignment to a direct slot of the scope parameter (`scope.X = ...`);
    - an assignment to a non-empty array target whose elements are all direct scope slots;
    - an assignment of the identifier `arg_var_name` to any array target;
    - a sequence expression that consists solely of assignments to direct scope slots.

    Writes to deeper member chains such as `scope.X.Y` go into nested objects and are treated
    as real data initialization, not routing. Always `False` without a scope parameter.
    """
    if scope_param_name is None or not isinstance(stmt, JsExpressionStatement):
        return False
    expr = stmt.expression
    if isinstance(expr, JsAssignmentExpression):
        target = expr.left
        if _is_direct_scope_member(target, scope_param_name):
            return True
        if isinstance(target, (JsArrayExpression, JsArrayPattern)):
            slots = target.elements
            if slots and all(
                _is_direct_scope_member(slot, scope_param_name)
                for slot in slots if slot is not None
            ):
                return True
            source = expr.right
            if (
                arg_var_name is not None
                and isinstance(source, JsIdentifier)
                and source.name == arg_var_name
            ):
                return True
    if isinstance(expr, JsSequenceExpression):
        return all(
            isinstance(part, JsAssignmentExpression)
            and _is_direct_scope_member(part.left, scope_param_name)
            for part in expr.expressions
        )
    return False


def _strip_scope_bookkeeping_from_sequence(
    stmt: Statement,
    scope_param_name: str | None,
) -> Statement:
    """
    Drop assignments to direct scope slots from a sequence expression statement and return a
    new expression statement holding what is left: the single surviving expression, or a
    shorter sequence. The original statement is returned as is when it is not a sequence
    expression statement, when nothing would be removed, when everything would be removed, or
    when there is no scope parameter.
    """
    if scope_param_name is None or not isinstance(stmt, JsExpressionStatement):
        return stmt
    sequence = stmt.expression
    if not isinstance(sequence, JsSequenceExpression):
        return stmt
    kept: list[Expression] = [
        part for part in sequence.expressions
        if not (
            isinstance(part, JsAssignmentExpression)
            and _is_direct_scope_member(part.left, scope_param_name)
        )
    ]
    if not kept or len(kept) == len(sequence.expressions):
        return stmt
    if len(kept) == 1:
        return JsExpressionStatement(expression=kept[0])
    return JsExpressionStatement(expression=JsSequenceExpression(expressions=kept))


def _is_direct_scope_member(node: Expression | None, scope_param_name: str) -> bool:
    """
    Tell whether `node` accesses a slot directly on the scope parameter: `scope.X`, or
    `scope["X"]` with a string literal key. A chain like `scope.X.Y` does not count, because
    its object is a member expression and not the scope identifier itself. Direct slots are
    treated as CFF routing state; deeper chains are semantic writes.
    """
    if not isinstance(node, JsMemberExpression):
        return False
    owner = node.object
    if not (isinstance(owner, JsIdentifier) and owner.name == scope_param_name):
        return False
    # Computed access only qualifies with a constant string key.
    return isinstance(node.property, JsStringLiteral) if node.computed else True


def _extract_arg_param_names(
    stmt: Statement,
    arg_var_name: str | None,
) -> list[str] | None:
    """
    Recognize the argument-destructuring statement

        [elem1, elem2, ...] = argVar

    and return one parameter name per element of the left-hand side, as computed by
    `_deepest_property_name`. Returns `None` when `stmt` does not have this shape, when the
    right-hand side is not the identifier `arg_var_name`, or when any element is a hole or
    yields no name.
    """
    if arg_var_name is None or not isinstance(stmt, JsExpressionStatement):
        return None
    assignment = stmt.expression
    if not isinstance(assignment, JsAssignmentExpression):
        return None
    targets, source = assignment.left, assignment.right
    if not isinstance(targets, (JsArrayExpression, JsArrayPattern)):
        return None
    if not isinstance(source, JsIdentifier) or source.name != arg_var_name:
        return None
    collected: list[str] = []
    for element in targets.elements:
        resolved = None if element is None else _deepest_property_name(element)
        if resolved is None:
            return None
        collected.append(resolved)
    return collected


def _deepest_property_name(node: Node) -> str | None:
    """
    Return the rightmost name of an identifier or member-expression chain: the identifier's own
    name, or the outermost property of a member expression when that property is an identifier
    or a string literal. Returns `None` for any other node.

    NOTE(review): the `computed` flag is not consulted, so `a[b]` yields `'b'`.
    """
    if isinstance(node, JsIdentifier):
        return node.name
    if not isinstance(node, JsMemberExpression):
        return None
    key = node.property
    if isinstance(key, JsIdentifier):
        return key.name
    if isinstance(key, JsStringLiteral):
        return key.value
    return None


def _is_redirect_var_write(stmt: Statement, namespace: str, redirect_var: str) -> bool:
    """
    Tell whether `stmt` is an expression statement that assigns to the non-computed member
    `namespace.redirect_var`. The assignment operator and the assigned value are not checked.
    """
    if not isinstance(stmt, JsExpressionStatement):
        return False
    assignment = stmt.expression
    if not isinstance(assignment, JsAssignmentExpression):
        return False
    target = assignment.left
    if not isinstance(target, JsMemberExpression) or target.computed:
        return False
    owner, key = target.object, target.property
    if not isinstance(owner, JsIdentifier) or owner.name != namespace:
        return False
    return isinstance(key, JsIdentifier) and key.name == redirect_var


def _filter_redirect_var_assignments(
    stmts: list[Statement],
    match: _GeneratorCFFMatch,
) -> list[Statement]:
    """
    Return the statements without writes to the redirect variable on the namespace (the first
    default scope property). The input list is returned unchanged when the match has no
    redirect variable or no default scope properties.
    """
    if not (match.with_redirect_var and match.scope_default_props):
        return stmts
    namespace, redirect_var = match.scope_default_props[0], match.with_redirect_var
    kept: list[Statement] = []
    for statement in stmts:
        if _is_redirect_var_write(statement, namespace, redirect_var):
            continue
        kept.append(statement)
    return kept


def _track_scope_routing(payload: list[Statement], state: _StateEnv) -> None:
    """
    Record routing values that the payload stores on member expressions. For every plain `=`
    assignment to a member expression (standalone or inside a sequence expression) whose
    target has a member key and whose right-hand side can be evaluated against `state`, the
    value is stored in `state` under that key.
    """
    for stmt in payload:
        if not isinstance(stmt, JsExpressionStatement):
            continue
        root = stmt.expression
        candidates = root.expressions if isinstance(root, JsSequenceExpression) else [root]
        for candidate in candidates:
            is_member_store = (
                isinstance(candidate, JsAssignmentExpression)
                and isinstance(candidate.left, JsMemberExpression)
                and candidate.operator == '='
            )
            if not is_member_store:
                continue
            slot = member_key(candidate.left)
            if slot is None or candidate.right is None:
                continue
            value = _eval_expr(candidate.right, state)
            if value is not None:
                state[slot] = value


def _execute_machine(
    machine: _StateMachine,
    match: _GeneratorCFFMatch,
    inherited_state: _StateEnv | None = None,
) -> tuple[list[Statement], _StateEnv] | None:
    """
    Turn the state machine back into structured statements. The state environment is seeded
    from the match's initial state, plus all entries of `inherited_state` that are not state
    variables themselves. A control flow graph is built from it, loops are identified through
    dominator analysis, structured control flow is emitted, and encoded returns are recovered.

    Returns the recovered statements together with the final state environment, or `None` when
    no control flow graph could be built.
    """
    state_vars = match.state_var_names
    state = _apply_initial_state(state_vars, match.initial_state)
    for key, value in (inherited_state or {}).items():
        # Inherited values must not override the dispatcher's own state variables.
        if key not in state_vars:
            state[key] = value

    built = _build_cfg(machine, state, state_vars, match.end_state, match)
    if built is None:
        return None
    cfg, final_state = built

    idom = _compute_idom(cfg)
    structured = _structural_analysis(cfg, idom, _find_loops(cfg, idom))
    return (_recover_returns(structured, match.did_return_var), final_state)


def _build_js_if(
    condition: Expression,
    true_body: list[Statement],
    false_body: list[Statement],
) -> JsIfStatement | None:
    """
    Assemble an if statement from a condition and two branch bodies without emitting empty
    branches. If only the false branch has statements, the condition is negated with `!` and
    the false branch becomes the consequent. Returns `None` when both branches are empty.
    """
    if true_body:
        then_block = JsBlockStatement(body=true_body)
        if false_body:
            return JsIfStatement(
                test=condition,
                consequent=then_block,
                alternate=JsBlockStatement(body=false_body),
            )
        return JsIfStatement(test=condition, consequent=then_block)
    if false_body:
        negated = JsUnaryExpression(operator='!', operand=condition, prefix=True)
        return JsIfStatement(
            test=negated,
            consequent=JsBlockStatement(body=false_body),
        )
    return None


def _extract_sub_namespace_inits(
    scope_arg: JsObjectExpression | None,
    known_props: list[str],
) -> dict[str, Expression]:
    """
    Find sub-namespace initializations in the scope argument of a wrapper: properties whose
    value is an empty object literal and whose key is not listed in *known_props*. Returns a
    mapping from property name to the init expression, which is empty when there is no scope
    argument.
    """
    found: dict[str, Expression] = {}
    if scope_arg is None:
        return found
    for member in scope_arg.properties:
        if not isinstance(member, JsProperty):
            continue
        name = property_key(member)
        if name is None or name in known_props:
            continue
        value = member.value
        if isinstance(value, JsObjectExpression) and not value.properties:
            found[name] = value
    return found


def _resolve_shared_wrappers(
    stmts: list[Statement],
    machine: _StateMachine,
    match: _GeneratorCFFMatch,
    outer_state: _StateEnv,
) -> list[Statement]:
    """
    Walk recovered statements looking for function expressions that are wrappers around the same
    shared generator. For each wrapper found, execute the state machine from its entry point and
    replace the wrapper with a proper function containing the recovered body. The *outer_state*
    carries scope routing values from the primary execution so that predicate-gated cases in
    wrapper paths can resolve. Iterates until no more wrappers are resolved (handles nesting).

    Empty-object properties found in the wrappers' scope arguments (other than the known default
    scope properties) are emitted as `NS.<name> = {}` assignments at the front of *stmts*, where
    NS is the first default scope property. The statement list is modified in place and returned.
    """
    gen_name = match.generator_name
    num_vars = len(match.state_var_names)
    # id()s of wrapper nodes whose execution failed; they are skipped on later passes.
    attempted: set[int] = set()
    # Sub-namespace initializers collected from all wrappers seen so far.
    sub_ns_inits: dict[str, Expression] = {}

    while True:
        resolved_any = False
        # Snapshot the walk: function bodies are replaced while iterating.
        for node in list(_walk_all(stmts)):
            if not isinstance(node, JsFunctionExpression):
                continue
            node_id = id(node)
            if node_id in attempted:
                continue
            wrapper_info = _detect_wrapper_function(node, gen_name, num_vars)
            if wrapper_info is None:
                continue
            new_subs = _extract_sub_namespace_inits(
                wrapper_info.scope_arg, match.scope_default_props,
            )
            sub_ns_inits.update(new_subs)
            # Same generator and switch as the primary match, but entered at the wrapper's
            # initial state; the scaffolding-related fields are not meaningful here.
            synthetic = _GeneratorCFFMatch(
                generator_name=gen_name,
                state_var_names=match.state_var_names,
                initial_state=wrapper_info.initial_state,
                end_state=match.end_state,
                switch_stmt=match.switch_stmt,
                switch_label=match.switch_label,
                scope_param_name=match.scope_param_name,
                arg_var_name=match.arg_var_name,
                did_return_var=match.did_return_var,
                result_var=None,
                gen_decl_index=0,
                scaffolding_end=0,
                with_redirect_var=match.with_redirect_var,
                scope_default_props=match.scope_default_props,
            )
            result = _execute_machine(machine, synthetic, inherited_state=outer_state)
            if result is None:
                attempted.add(node_id)
                continue
            recovered, _ = result
            # The generator's argument variable is exposed under the wrapper's rest parameter name.
            if (
                match.arg_var_name
                and wrapper_info.rest_param_name
                and match.arg_var_name != wrapper_info.rest_param_name
            ):
                recovered = _rename_identifier(
                    recovered, match.arg_var_name, wrapper_info.rest_param_name,
                )
            # Replace the wrapper body with the recovered code and re-parent the statements.
            node.body = JsBlockStatement(body=recovered)
            node.body.parent = node
            for s in recovered:
                s.parent = node.body
            # NOTE(review): arg_params is not passed to the constructor above; presumably it is
            # filled in while the machine executes -- confirm against _build_cfg.
            if synthetic.arg_params:
                node.params = [JsIdentifier(name=n) for n in synthetic.arg_params]
            resolved_any = True
        if not resolved_any:
            break

    if sub_ns_inits and match.scope_default_props:
        namespace = match.scope_default_props[0]
        # Each assignment is inserted at index 0, so the emitted statements end up in reverse
        # of the sorted iteration order.
        for sub_name in sorted(sub_ns_inits):
            assign = JsExpressionStatement(expression=JsAssignmentExpression(
                operator='=',
                left=JsMemberExpression(
                    object=JsIdentifier(name=namespace),
                    property=JsIdentifier(name=sub_name),
                    computed=False,
                ),
                right=_clone_node(sub_ns_inits[sub_name]),
            ))
            stmts.insert(0, assign)

    return stmts


def _walk_all(stmts: list[Statement]):
    """
    Lazily yield every node produced by `walk()` for each statement, in list order.
    """
    for statement in stmts:
        for node in statement.walk():
            yield node


def _rename_identifier(stmts: list[Statement], old_name: str, new_name: str) -> list[Statement]:
    """
    Rename, in place, every identifier called `old_name` in the given statements to `new_name`
    and return the same list.

    Identifiers in the property position of a non-computed member expression (`obj.old_name`)
    are left alone: they name a property, not the variable being renamed. Detecting this relies
    on the node's `parent` link; an identifier without a parent link is renamed.
    """
    for stmt in stmts:
        for node in stmt.walk():
            if not isinstance(node, JsIdentifier) or node.name != old_name:
                continue
            parent = getattr(node, 'parent', None)
            if (
                isinstance(parent, JsMemberExpression)
                and not parent.computed
                and parent.property is node
            ):
                continue
            node.name = new_name
    return stmts


def _emit_scope_namespace_declarations(match: _GeneratorCFFMatch) -> list[Statement]:
    """
    Build one `var NAME = INIT` declaration per default scope property of the match. INIT is a
    clone of the recorded default initializer, or an empty object literal if none is recorded.
    """
    def declare(name: str) -> Statement:
        initializer = match.scope_default_inits.get(name)
        if initializer is None:
            initializer = JsObjectExpression(properties=[])
        declarator = JsVariableDeclarator(
            id=JsIdentifier(name=name),
            init=_clone_node(initializer),
        )
        return JsVariableDeclaration(declarations=[declarator], kind=JsVarKind.VAR)

    return [declare(name) for name in match.scope_default_props]


def _emit_arg_param_declarations(match: _GeneratorCFFMatch) -> list[Statement]:
    """
    Build a single `var` statement that declares every name in `match.arg_params` without an
    initializer. An empty list is returned when there are no such names.
    """
    declarators = [
        JsVariableDeclarator(id=JsIdentifier(name=param), init=None)
        for param in match.arg_params
    ]
    if declarators:
        return [JsVariableDeclaration(declarations=declarators, kind=JsVarKind.VAR)]
    return []


class JsGeneratorCFFUnflattening(BodyProcessingTransformer):
    """
    Undo generator-based control-flow flattening. Statement lists are scanned for a generator
    function whose while/switch state machine is driven by several state variables; wherever one
    is found and can be executed, the dispatcher scaffolding is replaced by the recovered
    statements.
    """

    def _process_body(self, parent: Node, body: list[Statement]) -> None:
        """
        Scan `body` for generator CFF dispatchers and splice the recovered statements in place of
        each one that can be resolved. Positions where recovery fails are skipped.
        """
        at_script_level = isinstance(parent, JsScript)
        cursor = 0
        while cursor < len(body):
            outcome = self._recover_at(body, cursor, at_script_level)
            if outcome is None:
                cursor += 1
                continue
            match, recovered = outcome
            for stmt in recovered:
                stmt.parent = parent
            first = match.gen_decl_index
            resume = match.scaffolding_end + 1
            self._replace_body(parent, body, body[:first] + recovered + body[resume:])
            # Continue scanning right behind the statements that were just spliced in.
            cursor = first + len(recovered)

    def _recover_at(
        self,
        body: list[Statement],
        index: int,
        at_script_level: bool,
    ) -> tuple[_GeneratorCFFMatch, list[Statement]] | None:
        """
        Attempt to recover the dispatcher that starts at `body[index]`. Returns the match together
        with the replacement statements, or `None` as soon as any stage gives up.
        """
        match = _match_generator_cff(body, index)
        if match is None:
            return None
        machine = _extract_state_blocks(match)
        if machine is None:
            return None
        executed = _execute_machine(machine, match)
        if executed is None:
            return None
        statements, outer_state = executed
        if match.arg_var_name is not None:
            statements = _resolve_shared_wrappers(statements, machine, match, outer_state)
        # Each group of declarations is prepended, so the argument parameters end up first.
        if match.scope_default_props:
            statements = _emit_scope_namespace_declarations(match) + statements
        if match.arg_params:
            statements = _emit_arg_param_declarations(match) + statements
        if at_script_level:
            statements = self._sanitize_for_script_scope(statements)
            if statements is None:
                return None
        return match, statements

    @staticmethod
    def _sanitize_for_script_scope(stmts: list[Statement]) -> list[Statement] | None:
        """
        Adapt recovered statements for placement directly in a script body. A return statement
        anywhere but at the end of the list makes the result unusable and `None` is returned. A
        trailing return statement is dropped, or turned into an expression statement when it
        carries an argument. Only the statements of the list itself are inspected.
        """
        if not stmts:
            return stmts
        *leading, final = stmts
        if any(isinstance(stmt, JsReturnStatement) for stmt in leading):
            return None
        if not isinstance(final, JsReturnStatement):
            return stmts
        if final.argument is None:
            return leading
        return leading + [JsExpressionStatement(expression=final.argument)]

Classes

class JsGeneratorCFFUnflattening

Recover original code from generator-based state-machine CFF dispatchers. Handles the pattern where a function body is replaced with a generator function containing a while/switch state machine driven by multiple state variables.

Expand source code Browse git
class JsGeneratorCFFUnflattening(BodyProcessingTransformer):
    """
    Undo generator-based control-flow flattening. Statement lists are scanned for a generator
    function whose while/switch state machine is driven by several state variables; wherever one
    is found and can be executed, the dispatcher scaffolding is replaced by the recovered
    statements.
    """

    def _process_body(self, parent: Node, body: list[Statement]) -> None:
        """
        Scan `body` for generator CFF dispatchers and splice the recovered statements in place of
        each one that can be resolved. Positions where recovery fails are skipped.
        """
        at_script_level = isinstance(parent, JsScript)
        cursor = 0
        while cursor < len(body):
            outcome = self._recover_at(body, cursor, at_script_level)
            if outcome is None:
                cursor += 1
                continue
            match, recovered = outcome
            for stmt in recovered:
                stmt.parent = parent
            first = match.gen_decl_index
            resume = match.scaffolding_end + 1
            self._replace_body(parent, body, body[:first] + recovered + body[resume:])
            # Continue scanning right behind the statements that were just spliced in.
            cursor = first + len(recovered)

    def _recover_at(
        self,
        body: list[Statement],
        index: int,
        at_script_level: bool,
    ) -> tuple[_GeneratorCFFMatch, list[Statement]] | None:
        """
        Attempt to recover the dispatcher that starts at `body[index]`. Returns the match together
        with the replacement statements, or `None` as soon as any stage gives up.
        """
        match = _match_generator_cff(body, index)
        if match is None:
            return None
        machine = _extract_state_blocks(match)
        if machine is None:
            return None
        executed = _execute_machine(machine, match)
        if executed is None:
            return None
        statements, outer_state = executed
        if match.arg_var_name is not None:
            statements = _resolve_shared_wrappers(statements, machine, match, outer_state)
        # Each group of declarations is prepended, so the argument parameters end up first.
        if match.scope_default_props:
            statements = _emit_scope_namespace_declarations(match) + statements
        if match.arg_params:
            statements = _emit_arg_param_declarations(match) + statements
        if at_script_level:
            statements = self._sanitize_for_script_scope(statements)
            if statements is None:
                return None
        return match, statements

    @staticmethod
    def _sanitize_for_script_scope(stmts: list[Statement]) -> list[Statement] | None:
        """
        Adapt recovered statements for placement directly in a script body. A return statement
        anywhere but at the end of the list makes the result unusable and `None` is returned. A
        trailing return statement is dropped, or turned into an expression statement when it
        carries an argument. Only the statements of the list itself are inspected.
        """
        if not stmts:
            return stmts
        *leading, final = stmts
        if any(isinstance(stmt, JsReturnStatement) for stmt in leading):
            return None
        if not isinstance(final, JsReturnStatement):
            return stmts
        if final.argument is None:
            return leading
        return leading + [JsExpressionStatement(expression=final.argument)]

Ancestors