Module refinery.lib.scripts.ps1.deobfuscation.unflatten
Recover original control flow from control-flow-flattened PowerShell scripts.
Control flow flattening replaces sequential and branching code with a dispatcher loop: a while loop containing a single switch on a state variable, where each case sets the state variable to determine the next case to execute. This transformer identifies the dispatcher pattern, extracts the state machine, and recovers the original structure.
Expand source code Browse git
"""
Recover original control flow from control-flow-flattened PowerShell scripts.
Control flow flattening replaces sequential and branching code with a dispatcher loop: a while loop
containing a single switch on a state variable, where each case sets the state variable to
determine the next case to execute. This transformer identifies the dispatcher pattern, extracts
the state machine, and recovers the original structure.
"""
from __future__ import annotations
from collections import deque
from collections.abc import Callable, Generator
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple
from refinery.lib.scripts import Block, Node, Statement, Transformer
from refinery.lib.scripts.ps1.deobfuscation.data import COMPARISON_OPS
from refinery.lib.scripts.ps1.deobfuscation.emulator import evaluate_truthy
from refinery.lib.scripts.ps1.deobfuscation.helpers import (
get_body,
inside_value_producing_context,
is_builtin_variable,
unwrap_parens,
)
from refinery.lib.scripts.ps1.model import (
Expression,
Ps1AssignmentExpression,
Ps1BinaryExpression,
Ps1BreakStatement,
Ps1ExpressionStatement,
Ps1IfStatement,
Ps1IntegerLiteral,
Ps1ParenExpression,
Ps1RealLiteral,
Ps1ScopeModifier,
Ps1StringLiteral,
Ps1SwitchStatement,
Ps1UnaryExpression,
Ps1Variable,
Ps1WhileLoop,
)
_MAX_STATES = 500
_MAX_UNROLL_ITERATIONS = 500
if TYPE_CHECKING:
_VarKey = tuple[str, Ps1ScopeModifier]
_StateKey = int | float | str
def _is_bool_literal(node: Node) -> bool | None:
"""
Check if a node is a `$True` or `$False` variable literal. Returns the boolean value,
or `None` if the node is not a boolean literal.
"""
if is_builtin_variable(node, frozenset({'true', 'false'})):
return node.name.lower() == 'true'
return None
def _unwrap_constant(node) -> _StateKey | None:
"""
Extract a constant value (int, float, or string) from an AST node.
"""
node = unwrap_parens(node) if isinstance(node, Expression) else node
if isinstance(node, Ps1IntegerLiteral):
return node.value
if isinstance(node, Ps1RealLiteral):
return node.value
if isinstance(node, Ps1StringLiteral):
return node.value
if isinstance(node, Ps1UnaryExpression) and node.operator == '-':
inner = unwrap_parens(node.operand) if isinstance(node.operand, Expression) else node.operand
if isinstance(inner, Ps1IntegerLiteral):
return -inner.value
if isinstance(inner, Ps1RealLiteral):
return -inner.value
if is_builtin_variable(node, frozenset({'null'})):
return 0
return None
@dataclass
class _LinearTransition:
target: _StateKey
@property
def successors(self) -> tuple[_StateKey, ...]:
return (self.target,)
@dataclass
class _ConditionalTransition:
condition: Expression
true_target: _StateKey
false_target: _StateKey
true_prefix: list[Statement] = field(default_factory=list)
false_prefix: list[Statement] = field(default_factory=list)
@property
def successors(self) -> tuple[_StateKey, ...]:
return (self.true_target, self.false_target)
@dataclass
class _ExitTransition:
@property
def successors(self) -> tuple[_StateKey, ...]:
return ()
if TYPE_CHECKING:
_Transition = _LinearTransition | _ConditionalTransition | _ExitTransition
@dataclass
class _StateBlock:
state_id: _StateKey
statements: list[Statement]
transition: _Transition
@dataclass
class _DispatcherMatch:
state_var_name: str
state_var_scope: Ps1ScopeModifier
condition: Expression
switch: Ps1SwitchStatement
def _var_key(node: Ps1Variable) -> _VarKey:
return (node.name.lower(), node.scope)
def _collect_variables(node: Node) -> set[_VarKey]:
"""
Collect all distinct variable keys referenced in an expression tree.
"""
result: set[_VarKey] = set()
for child in node.walk():
if isinstance(child, Ps1Variable):
if child.scope != Ps1ScopeModifier.NONE:
continue
lower = child.name.lower()
if lower in ('true', 'false', 'null'):
continue
result.add(_var_key(child))
return result
def _make_exit_check(
condition: Expression,
var_name: str,
) -> Callable[[_StateKey], bool]:
"""
Return a predicate that checks whether assigning a given state value to the state variable
would make the while condition falsy (the loop would exit). Uses the emulator to evaluate the
condition with the state variable bound.
"""
def is_exit(state_value: _StateKey) -> bool:
bindings = {var_name: state_value}
result = evaluate_truthy(condition, bindings)
if result is None:
return False
return not result
return is_exit
def _is_state_assignment(
stmt: Statement,
var_name: str,
var_scope: Ps1ScopeModifier,
) -> _StateKey | None:
"""
If the statement is an assignment to the state variable with a constant value, return that
value. Otherwise return `None`.
"""
result = _get_simple_assignment(stmt)
if result is None:
return None
key, value = result
if key != (var_name, var_scope):
return None
return _unwrap_constant(value)
def _get_simple_assignment(
stmt: Statement,
) -> tuple[_VarKey, Node] | None:
"""
If the statement is a simple assignment ($var = <value>), return the variable key
and the value node. Returns `None` for non-assignments or compound operators.
"""
if not isinstance(stmt, Ps1ExpressionStatement):
return None
expr = stmt.expression
if not isinstance(expr, Ps1AssignmentExpression):
return None
if expr.operator != '=':
return None
target = expr.target
if not isinstance(target, Ps1Variable):
return None
if expr.value is None:
return None
return (_var_key(target), expr.value)
def _resolve_value(
node: Node,
env: dict[_VarKey, _StateKey | bool],
) -> _StateKey | None:
"""
Try to resolve a node to a constant value given a variable environment.
"""
if not isinstance(node, Expression):
return None
expr = unwrap_parens(node)
value = _unwrap_constant(expr)
if value is not None:
return value
if isinstance(expr, Ps1Variable):
key = _var_key(expr)
val = env.get(key)
if isinstance(val, (int, float, str)):
return val
return None
def _resolve_bool(
node: Node,
env: dict[_VarKey, _StateKey | bool],
) -> bool | None:
"""
Try to resolve a node to a boolean given a variable environment.
"""
if not isinstance(node, Expression):
return None
expr = unwrap_parens(node)
literal = _is_bool_literal(expr)
if literal is not None:
return literal
if isinstance(expr, Ps1Variable):
key = _var_key(expr)
val = env.get(key)
if isinstance(val, bool):
return val
if (
isinstance(expr, Ps1BinaryExpression)
and (_l := expr.left) is not None
and (_r := expr.right) is not None
):
if comparator := COMPARISON_OPS.get(op := expr.operator.lower()):
lhs = _resolve_value(_l, env)
rhs = _resolve_value(_r, env)
if lhs is not None and rhs is not None:
if type(lhs) is not type(rhs):
return None
return comparator(lhs, rhs)
else:
lhs = _resolve_bool(_l, env)
rhs = _resolve_bool(_r, env)
if lhs is None:
lhs, rhs = rhs, lhs
if op == '-and':
return lhs and rhs
if op == '-or':
return lhs or rhs
if (
isinstance(expr, Ps1UnaryExpression)
and (_v := expr.operand) is not None
and (op := expr.operator.lower()) in ('-not', '!')
and (bv := _resolve_bool(_v, env)) is not None
):
return not bv
return None
def _match_dispatcher(loop: Ps1WhileLoop) -> _DispatcherMatch | None:
"""
Check whether a while loop matches the CFF dispatcher pattern: a while loop whose condition
involves a single variable, and whose body is a single switch on that variable.
"""
cond = loop.condition
if cond is None:
return None
body = loop.body
if body is None or len(body.body) != 1:
return None
switch = body.body[0]
if not isinstance(switch, Ps1SwitchStatement):
return None
if switch.file:
return None
switch_val = switch.value
if not isinstance(switch_val, Ps1Variable):
return None
var_name = switch_val.name.lower()
var_scope = switch_val.scope
cond_vars = _collect_variables(cond)
if not cond_vars:
return None
state_key = (var_name, var_scope)
if state_key not in cond_vars:
return None
if len(cond_vars) > 1:
return None
return _DispatcherMatch(
state_var_name=var_name,
state_var_scope=var_scope,
condition=cond,
switch=switch,
)
def _find_state_init(
body: list[Statement],
loop_index: int,
var_name: str,
var_scope: Ps1ScopeModifier,
) -> tuple[int, _StateKey] | None:
"""
Scan backwards from the while loop to find the state variable initialization. Returns
(index_in_body, initial_state_value) or `None`.
"""
for i in range(loop_index - 1, -1, -1):
value = _is_state_assignment(body[i], var_name, var_scope)
if value is not None:
return (i, value)
return None
def _strip_trailing_break(stmts: list[Statement]) -> list[Statement]:
"""
Remove a trailing break statement (part of the switch dispatch, not the original code).
"""
if stmts and isinstance(stmts[-1], Ps1BreakStatement) and stmts[-1].label is None:
return stmts[:-1]
return stmts
def _extract_transition(
stmts: list[Statement],
var_name: str,
var_scope: Ps1ScopeModifier,
is_exit: Callable[[_StateKey], bool],
) -> tuple[list[Statement], _Transition] | None:
"""
Separate a switch case body into side-effect statements and a state transition. Returns
(side_effects, transition) or `None` if the pattern is not recognized.
"""
if not stmts:
return None
last = stmts[-1]
state_val = _is_state_assignment(last, var_name, var_scope)
if state_val is not None:
side_effects = list(stmts[:-1])
if is_exit(state_val):
return (side_effects, _ExitTransition())
return (side_effects, _LinearTransition(target=state_val))
if isinstance(last, Ps1IfStatement) and last.else_block is not None and len(last.clauses) == 1:
condition, true_block = last.clauses[0]
false_block = last.else_block
true_body = _strip_trailing_break(list(true_block.body))
false_body = _strip_trailing_break(list(false_block.body))
if not true_body or not false_body:
return None
true_state = _is_state_assignment(true_body[-1], var_name, var_scope)
false_state = _is_state_assignment(false_body[-1], var_name, var_scope)
if true_state is None or false_state is None:
return None
true_prefix = true_body[:-1]
false_prefix = false_body[:-1]
side_effects = list(stmts[:-1])
return (side_effects, _ConditionalTransition(
condition=condition,
true_target=true_state,
false_target=false_state,
true_prefix=true_prefix,
false_prefix=false_prefix,
))
return None
def _extract_state_machine(
match: _DispatcherMatch,
is_exit: Callable[[_StateKey], bool],
) -> dict[_StateKey, _StateBlock] | None:
"""
Parse all switch cases into a state machine dictionary. Returns `None` on failure.
"""
states: dict[_StateKey, _StateBlock] = {}
var_name = match.state_var_name
var_scope = match.state_var_scope
for condition, block in match.switch.clauses:
if condition is None:
continue
state_id = _unwrap_constant(condition)
if state_id is None:
return None
body = _strip_trailing_break(list(block.body))
result = _extract_transition(body, var_name, var_scope, is_exit)
if result is None:
return None
side_effects, transition = result
states[state_id] = _StateBlock(
state_id=state_id,
statements=side_effects,
transition=transition,
)
if len(states) > _MAX_STATES:
return None
return states
def _negate_condition(cond: Expression) -> Expression:
"""
Return the logical negation of a condition expression. Tries to simplify where possible (e.g.,
flipping -eq to -ne) rather than wrapping in -Not.
"""
unwrapped = unwrap_parens(cond)
if not isinstance(unwrapped, Expression):
return Ps1UnaryExpression(operator='-Not', operand=cond, prefix=True)
cond = unwrapped
if isinstance(cond, Ps1BinaryExpression):
flipped = {
'-eq': '-NE',
'-ne': '-EQ',
'-lt': '-GE',
'-ge': '-LT',
'-gt': '-LE',
'-le': '-GT',
}.get(cond.operator.lower())
if flipped is not None:
return Ps1BinaryExpression(
left=cond.left,
operator=flipped,
right=cond.right,
)
if (
isinstance(cond, Ps1UnaryExpression)
and cond.operator.lower() in ('-not', '!')
and cond.operand is not None
):
return cond.operand
return Ps1UnaryExpression(operator='-Not', operand=Ps1ParenExpression(expression=cond))
def _build_if(
condition: Expression,
true_body: list[Statement],
false_body: list[Statement],
) -> Ps1IfStatement | None:
"""
Build an if/else statement. Returns `None` if both bodies are empty. Omits the else block if the
false body is empty; negates the condition if only the true body is empty.
"""
if not true_body and not false_body:
return None
if not true_body:
return Ps1IfStatement(
clauses=[(_negate_condition(condition), Block(body=false_body))],
)
if not false_body:
return Ps1IfStatement(
clauses=[(condition, Block(body=true_body))],
)
return Ps1IfStatement(
clauses=[(condition, Block(body=true_body))],
else_block=Block(body=false_body),
)
def _find_back_edges(
states: dict[_StateKey, _StateBlock],
entry: _StateKey,
is_exit: Callable[[_StateKey], bool],
) -> dict[_StateKey, _StateKey]:
"""
Walk the state graph from the entry and find back-edges using DFS. Returns a mapping from back-
edge source state to back-edge target state (the loop header).
"""
back_edges: dict[_StateKey, _StateKey] = {}
visiting: set[_StateKey] = set()
visited: set[_StateKey] = set()
def _dfs(state: _StateKey):
if is_exit(state) or state not in states:
return
if state in visited:
return
visiting.add(state)
block = states[state]
for target in block.transition.successors:
if is_exit(target) or target not in states:
continue
if target in visiting:
back_edges[state] = target
elif target not in visited:
_dfs(target)
visiting.discard(state)
visited.add(state)
_dfs(entry)
return back_edges
def _find_join_point(
states: dict[_StateKey, _StateBlock],
true_start: _StateKey,
false_start: _StateKey,
is_exit: Callable[[_StateKey], bool],
back_edge_targets: set[_StateKey],
) -> _StateKey | None:
"""
Find the first state reachable from both arms of a conditional (the join point). Returns `None`
if no common state is found (one or both arms exit).
"""
true_reach: list[_StateKey] = []
false_reach: list[_StateKey] = []
true_queue: deque[_StateKey] = deque([true_start])
false_queue: deque[_StateKey] = deque([false_start])
true_seen: set[_StateKey] = set()
false_seen: set[_StateKey] = set()
def _expand(
queue: deque[_StateKey],
seen: set[_StateKey],
reach: list[_StateKey],
):
while queue:
s = queue.popleft()
if is_exit(s) or s not in states:
continue
if s in seen:
continue
if s in back_edge_targets:
reach.append(s)
continue
seen.add(s)
reach.append(s)
for succ in states[s].transition.successors:
queue.append(succ)
_expand(true_queue, true_seen, true_reach)
_expand(false_queue, false_seen, false_reach)
true_set = set(true_reach)
for s in false_reach:
if s in true_set:
return s
return None
def _collect_loop_states(
states: dict[_StateKey, _StateBlock],
header: _StateKey,
is_exit: Callable[[_StateKey], bool],
latches: set[_StateKey],
) -> set[_StateKey]:
"""
Collect all state IDs that belong to a loop (reachable from header without leaving through
latches).
"""
loop_states: set[_StateKey] = set()
queue: deque[_StateKey] = deque([header])
while queue:
s = queue.popleft()
if is_exit(s) or s not in states or s in loop_states:
continue
loop_states.add(s)
if s in latches and s != header:
continue
queue.extend(states[s].transition.successors)
return loop_states
def _collect_internal_vars(
states: dict[_StateKey, _StateBlock],
loop_states: set[_StateKey],
state_var_key: _VarKey,
) -> set[_VarKey]:
"""
Identify internal dispatch variables within the loop. These are variables that are only
assigned constant values (integers, floats, strings, or booleans) or copies of other internal
variables. They are artifacts of the flattening and should be suppressed in output.
"""
candidates: set[_VarKey] = {state_var_key}
changed = True
while changed:
changed = False
for sid in loop_states:
block = states[sid]
for stmts in _all_statement_lists(block):
for stmt in stmts:
result = _get_simple_assignment(stmt)
if result is None:
continue
key, value = result
if key in candidates:
continue
if isinstance(value, Expression):
value = unwrap_parens(value)
if _unwrap_constant(value) is not None:
candidates.add(key)
changed = True
elif isinstance(value, Ps1Variable):
if _is_bool_literal(value) is not None:
candidates.add(key)
changed = True
elif _var_key(value) in candidates:
candidates.add(key)
changed = True
return candidates
def _all_statement_lists(block: _StateBlock) -> Generator[list[Statement], None, None]:
"""
Yield all statement lists within a state block (main statements plus conditional transition
prefixes).
"""
yield block.statements
t = block.transition
if isinstance(t, _ConditionalTransition):
yield t.true_prefix
yield t.false_prefix
def _update_env(
env: dict[_VarKey, _StateKey | bool],
key: _VarKey,
value: Node,
):
"""
Update the variable environment for an internal assignment.
"""
if isinstance(value, Expression):
value = unwrap_parens(value)
const_val = _unwrap_constant(value)
if const_val is not None:
env[key] = const_val
return
if isinstance(value, Ps1Variable):
bool_val = _is_bool_literal(value)
if bool_val is not None:
env[key] = bool_val
return
src_key = _var_key(value)
if src_key in env:
env[key] = env[src_key]
return
env.pop(key, None)
def _simulate_statements(
stmts: list[Statement],
env: dict[_VarKey, _StateKey | bool],
internal_vars: set[_VarKey],
output: list[Statement],
):
"""
Process a list of statements during simulation. Updates env for internal assignments, and
appends non-internal statements to output.
"""
for stmt in stmts:
result = _get_simple_assignment(stmt)
if result is not None:
key, value = result
if key in internal_vars:
_update_env(env, key, value)
continue
output.append(stmt)
def _simulate_arm(
states: dict[_StateKey, _StateBlock],
start: _StateKey,
header: _StateKey,
is_exit: Callable[[_StateKey], bool],
latches: set[_StateKey],
env: dict[_VarKey, _StateKey | bool],
internal_vars: set[_VarKey],
output: list[Statement],
) -> tuple[list[Statement], _StateKey | None] | None:
"""
Simulate execution of a single arm (branch) during loop unrolling. Follows the state graph
linearly, resolving conditions where possible. Returns (statements, next_state) or `None` if a
condition cannot be resolved.
"""
current: _StateKey | None = start
step_count = 0
while current is not None and step_count < _MAX_STATES * 2:
step_count += 1
if is_exit(current):
return (output, None)
if current not in states:
return None
if current == header:
return (output, header)
if current in latches:
cur_block = states[current]
_simulate_statements(cur_block.statements, env, internal_vars, output)
t = cur_block.transition
if isinstance(t, _LinearTransition) and t.target == header:
return (output, header)
return None
cur_block = states[current]
_simulate_statements(cur_block.statements, env, internal_vars, output)
t = cur_block.transition
if isinstance(t, _ExitTransition):
return (output, None)
if isinstance(t, _LinearTransition):
current = t.target
continue
if isinstance(t, _ConditionalTransition):
branch = _resolve_bool(t.condition, env)
if branch is True:
_simulate_statements(t.true_prefix, env, internal_vars, output)
current = t.true_target
elif branch is False:
_simulate_statements(t.false_prefix, env, internal_vars, output)
current = t.false_target
else:
return None
continue
return None
return None
def _seed_env_from_preamble(
states: dict[_StateKey, _StateBlock],
entry: _StateKey,
header: _StateKey,
is_exit: Callable[[_StateKey], bool],
internal_vars: set[_VarKey],
) -> dict[_VarKey, _StateKey | bool] | None:
"""
Walk the linear chain from entry to header, simulating assignments to build the initial
variable environment. Returns the env or `None` if the path from entry to header is not a
simple linear chain.
"""
env: dict[_VarKey, _StateKey | bool] = {}
discard: list[Statement] = []
current = entry
visited: set[_StateKey] = set()
while current != header:
if is_exit(current) or current not in states or current in visited:
return None
visited.add(current)
block = states[current]
_simulate_statements(block.statements, env, internal_vars, discard)
if isinstance(block.transition, _LinearTransition):
current = block.transition.target
else:
return None
return env
def _merge_next_state(
true_next: _StateKey | None,
false_next: _StateKey | None,
) -> _StateKey | None:
if true_next == false_next:
return true_next
if true_next is None:
return false_next
if false_next is None:
return true_next
return true_next
def _try_unroll_loop(
states: dict[_StateKey, _StateBlock],
entry: _StateKey,
header: _StateKey,
is_exit: Callable[[_StateKey], bool],
back_edges: dict[_StateKey, _StateKey],
loop_headers: set[_StateKey],
) -> tuple[list[Statement], set[_VarKey]] | None:
"""
Attempt to unroll a loop by symbolically executing it. Returns
(unrolled_statements, internal_vars)
if successful, or `None` if the loop cannot be fully resolved.
"""
block = states[header]
if not isinstance(block.transition, _ConditionalTransition):
return None
cond_trans = block.transition
latches: set[_StateKey] = set()
for latch, target in back_edges.items():
if target == header:
latches.add(latch)
loop_states = _collect_loop_states(states, header, is_exit, latches)
cond = cond_trans.condition
loop_var_key: _VarKey | None = None
cond_unwrapped = unwrap_parens(cond)
if isinstance(cond_unwrapped, Ps1BinaryExpression):
left = unwrap_parens(cond_unwrapped.left) if cond_unwrapped.left is not None else None
right = (
unwrap_parens(cond_unwrapped.right) if cond_unwrapped.right is not None else None
)
if isinstance(left, Ps1Variable):
loop_var_key = _var_key(left)
elif isinstance(right, Ps1Variable):
loop_var_key = _var_key(right)
if loop_var_key is None:
return None
internal_vars = _collect_internal_vars(states, loop_states, loop_var_key)
direction = _determine_loop_direction(states, cond_trans, header, is_exit, loop_headers)
if direction is None:
return None
body_entry = direction.body_entry
body_prefix_stmts = direction.prefix
seed = _seed_env_from_preamble(states, entry, header, is_exit, internal_vars)
env: dict[_VarKey, _StateKey | bool] = seed if seed is not None else {}
result: list[Statement] = []
for iteration in range(_MAX_UNROLL_ITERATIONS):
cond_result = _resolve_bool(cond, env)
if cond_result is False:
return (result, internal_vars)
if cond_result is None and iteration > 0:
return None
iteration_stmts: list[Statement] = []
_simulate_statements(block.statements, env, internal_vars, iteration_stmts)
_simulate_statements(body_prefix_stmts, env, internal_vars, iteration_stmts)
current: _StateKey | None = body_entry
step_count = 0
while current is not None and step_count < _MAX_STATES * 2:
step_count += 1
if is_exit(current):
result.extend(iteration_stmts)
return (result, internal_vars)
if current not in states:
return None
if current == header:
break
if current in latches:
cur_block = states[current]
_simulate_statements(
cur_block.statements, env, internal_vars, iteration_stmts,
)
t = cur_block.transition
if isinstance(t, _LinearTransition) and t.target == header:
break
return None
cur_block = states[current]
_simulate_statements(cur_block.statements, env, internal_vars, iteration_stmts)
t = cur_block.transition
if isinstance(t, _ExitTransition):
result.extend(iteration_stmts)
return (result, internal_vars)
if isinstance(t, _LinearTransition):
current = t.target
continue
if isinstance(t, _ConditionalTransition):
branch = _resolve_bool(t.condition, env)
if branch is True:
_simulate_statements(
t.true_prefix, env, internal_vars, iteration_stmts,
)
current = t.true_target
continue
if branch is False:
_simulate_statements(
t.false_prefix, env, internal_vars, iteration_stmts,
)
current = t.false_target
continue
true_stmts: list[Statement] = []
false_stmts: list[Statement] = []
true_env = dict(env)
false_env = dict(env)
_simulate_statements(
t.true_prefix, true_env, internal_vars, true_stmts,
)
_simulate_statements(
t.false_prefix, false_env, internal_vars, false_stmts,
)
true_arm_result = _simulate_arm(
states,
t.true_target,
header,
is_exit,
latches,
true_env,
internal_vars,
true_stmts,
)
false_arm_result = _simulate_arm(
states,
t.false_target,
header,
is_exit,
latches,
false_env,
internal_vars,
false_stmts,
)
if true_arm_result is None or false_arm_result is None:
return None
true_arm_body, true_next = true_arm_result
false_arm_body, false_next = false_arm_result
if_stmt = _build_if(t.condition, true_arm_body, false_arm_body)
if if_stmt is not None:
iteration_stmts.append(if_stmt)
for key in set(env.keys()) | set(true_env.keys()) | set(false_env.keys()):
tv = true_env.get(key)
fv = false_env.get(key)
if tv == fv and tv is not None:
env[key] = tv
else:
env.pop(key, None)
current = _merge_next_state(true_next, false_next)
continue
return None
result.extend(iteration_stmts)
return None
def _recover_structure(
states: dict[_StateKey, _StateBlock],
entry: _StateKey,
is_exit: Callable[[_StateKey], bool],
) -> list[Statement] | None:
"""
Walk the state graph from the entry state and emit the recovered AST. Returns `None` if the
structure cannot be recovered (e.g., irreducible control flow).
"""
back_edges = _find_back_edges(states, entry, is_exit)
loop_headers: set[_StateKey] = set(back_edges.values())
def _emit_arm(
start: _StateKey,
stop: _StateKey | None,
claimed: set[_StateKey],
) -> tuple[list[Statement], _StateKey | None] | None:
"""
Emit statements from start, stopping before the stop state (if given). Returns
(statements, next_state)
where next_state is the stop state or `None`.
"""
result: list[Statement] = []
current: _StateKey | None = start
while current is not None:
if is_exit(current):
return (result, None)
if current not in states:
return None
if stop is not None and current == stop:
return (result, current)
if current in claimed:
return (result, current)
if current in loop_headers:
loop_result = _emit_loop(current, claimed)
if loop_result is None:
return None
loop_stmts, loop_next, loop_internals = loop_result
if loop_internals is not None:
filtered: list[Statement] = []
for s in result:
assignment = _get_simple_assignment(s)
if assignment is None or assignment[0] not in loop_internals:
filtered.append(s)
result = filtered
result.extend(loop_stmts)
current = loop_next
continue
claimed.add(current)
block = states[current]
result.extend(block.statements)
transition = block.transition
if isinstance(transition, _ExitTransition):
return (result, None)
if isinstance(transition, _LinearTransition):
current = transition.target
continue
if isinstance(transition, _ConditionalTransition):
true_target = transition.true_target
false_target = transition.false_target
join = _find_join_point(
states, true_target, false_target, is_exit, loop_headers,
)
effective_stop = join if join is not None else stop
true_result = _emit_arm(true_target, effective_stop, claimed)
false_result = _emit_arm(false_target, effective_stop, claimed)
if true_result is None or false_result is None:
return None
true_stmts, true_next = true_result
false_stmts, false_next = false_result
true_stmts = list(transition.true_prefix) + true_stmts
false_stmts = list(transition.false_prefix) + false_stmts
if_stmt = _build_if(transition.condition, true_stmts, false_stmts)
if if_stmt is not None:
result.append(if_stmt)
current = _merge_next_state(true_next, false_next)
continue
return (result, None)
def _emit_loop(
header: _StateKey,
outer_claimed: set[_StateKey],
) -> tuple[list[Statement], _StateKey | None, set[_VarKey] | None] | None:
"""
Emit a while loop rooted at the given header state. First attempts to fully unroll the loop
via symbolic execution. Falls back to structural recovery if unrolling fails. Returns:
(statements, next_state, internal_vars_or_None)
"""
unrolled = _try_unroll_loop(
states, entry, header, is_exit, back_edges, loop_headers,
)
if unrolled is not None:
unrolled_stmts, unrolled_internals = unrolled
for sid in _collect_loop_states(
states,
header,
is_exit,
{latch for latch, target in back_edges.items() if target == header},
):
outer_claimed.add(sid)
return (unrolled_stmts, None, unrolled_internals)
block = states[header]
loop_cond: Expression
body_start: _StateKey
exit_target: _StateKey | None = None
if isinstance(block.transition, _ConditionalTransition):
cond_trans = block.transition
direction = _determine_loop_direction(
states, cond_trans, header, is_exit, loop_headers,
)
if direction is not None:
loop_cond = (
_negate_condition(cond_trans.condition) if direction.negated
else cond_trans.condition
)
body_start_stmts = list(block.statements) + direction.prefix
body_start = direction.body_entry
exit_target = direction.exit_target
else:
loop_cond = Ps1Variable(name='True', scope=Ps1ScopeModifier.NONE)
body_start_stmts = list(block.statements)
body_start = header
outer_claimed.add(header)
inner_claimed: set[_StateKey] = set()
inner_claimed.add(header)
body_result = _emit_arm(cond_trans.true_target, header, inner_claimed)
if body_result is None:
return None
body_stmts = body_start_stmts
if_true, _ = body_result
if_false_result = _emit_arm(cond_trans.false_target, header, inner_claimed)
if if_false_result is None:
return None
if_false, _ = if_false_result
true_body = list(cond_trans.true_prefix) + if_true
false_body = list(cond_trans.false_prefix) + if_false
if_stmt = _build_if(cond_trans.condition, true_body, false_body)
if if_stmt is not None:
body_stmts.append(if_stmt)
while_stmt = Ps1WhileLoop(
condition=loop_cond,
body=Block(body=body_stmts),
)
outer_claimed.update(inner_claimed)
return ([while_stmt], exit_target, None)
else:
loop_cond = Ps1Variable(name='True', scope=Ps1ScopeModifier.NONE)
body_start_stmts = list(block.statements)
if isinstance(block.transition, _LinearTransition):
body_start = block.transition.target
else:
return None
exit_target = None
outer_claimed.add(header)
inner_claimed = set()
inner_claimed.add(header)
body_result = _emit_arm(body_start, header, inner_claimed)
if body_result is None:
return None
body_stmts, _ = body_result
all_body = body_start_stmts + body_stmts
while_stmt = Ps1WhileLoop(
condition=loop_cond,
body=Block(body=all_body),
)
outer_claimed.update(inner_claimed)
next_state: _StateKey | None = None
if exit_target is not None and not is_exit(exit_target):
next_state = exit_target
return ([while_stmt], next_state, None)
claimed: set[_StateKey] = set()
result = _emit_arm(entry, None, claimed)
if result is None:
return None
stmts, _ = result
return stmts
class _LoopDirection(NamedTuple):
body_entry: _StateKey
exit_target: _StateKey | None
prefix: list[Statement]
negated: bool
def _determine_loop_direction(
states: dict[_StateKey, _StateBlock],
cond_trans: _ConditionalTransition,
header: _StateKey,
is_exit: Callable[[_StateKey], bool],
loop_headers: set[_StateKey],
) -> _LoopDirection | None:
"""
Determine which branch of a conditional loop header leads to the loop body
and which leads to the exit. Returns `None` when direction is ambiguous (both
branches reach the header, or neither does).
"""
true_target = cond_trans.true_target
false_target = cond_trans.false_target
true_is_exit = is_exit(true_target) or true_target not in states
false_is_exit = is_exit(false_target) or false_target not in states
true_reaches = _state_reaches(
states, true_target, header, is_exit, loop_headers - {header},
)
false_reaches = _state_reaches(
states, false_target, header, is_exit, loop_headers - {header},
)
if true_reaches and (not false_reaches or false_is_exit):
return _LoopDirection(true_target, false_target, list(cond_trans.true_prefix), False)
if false_reaches and (not true_reaches or true_is_exit):
return _LoopDirection(false_target, true_target, list(cond_trans.false_prefix), True)
if false_is_exit:
return _LoopDirection(true_target, false_target, list(cond_trans.true_prefix), False)
if true_is_exit:
return _LoopDirection(false_target, true_target, list(cond_trans.false_prefix), True)
return None
def _state_reaches(
states: dict[_StateKey, _StateBlock],
start: _StateKey,
target: _StateKey,
is_exit: Callable[[_StateKey], bool],
barriers: set[_StateKey],
) -> bool:
"""
Check if the state graph can reach target from start without crossing barriers.
"""
visited: set[_StateKey] = set()
queue: deque[_StateKey] = deque([start])
while queue:
s = queue.popleft()
if s == target:
return True
if is_exit(s) or s not in states or s in visited or s in barriers:
continue
visited.add(s)
queue.extend(states[s].transition.successors)
return False
class Ps1ControlFlowDeflattening(Transformer):
"""
Recover original control flow from control-flow-flattened scripts.
"""
def visit(self, node: Node):
for parent in list(node.walk()):
if inside_value_producing_context(parent):
continue
body = get_body(parent)
if body is None:
continue
self._try_deflatten_body(body, parent)
def _try_deflatten_body(self, body: list[Statement], parent: Node):
i = 0
while i < len(body):
stmt = body[i]
if not isinstance(stmt, Ps1WhileLoop):
i += 1
continue
match = _match_dispatcher(stmt)
if match is None:
i += 1
continue
init = _find_state_init(body, i, match.state_var_name, match.state_var_scope)
if init is None:
i += 1
continue
init_index, entry_state = init
is_exit = _make_exit_check(match.condition, match.state_var_name)
machine = _extract_state_machine(match, is_exit)
if machine is None:
i += 1
continue
if entry_state not in machine:
i += 1
continue
recovered = _recover_structure(machine, entry_state, is_exit)
if recovered is None:
i += 1
continue
recovered = [
s for s in recovered
if _is_state_assignment(s, match.state_var_name, match.state_var_scope) is None
]
for s in recovered:
s.parent = parent
body[init_index:i + 1] = recovered
self.mark_changed()
i += len(recovered)
Classes
class Ps1ControlFlowDeflattening-
Recover original control flow from control-flow-flattened scripts.
Expand source code Browse git
class Ps1ControlFlowDeflattening(Transformer): """ Recover original control flow from control-flow-flattened scripts. """ def visit(self, node: Node): for parent in list(node.walk()): if inside_value_producing_context(parent): continue body = get_body(parent) if body is None: continue self._try_deflatten_body(body, parent) def _try_deflatten_body(self, body: list[Statement], parent: Node): i = 0 while i < len(body): stmt = body[i] if not isinstance(stmt, Ps1WhileLoop): i += 1 continue match = _match_dispatcher(stmt) if match is None: i += 1 continue init = _find_state_init(body, i, match.state_var_name, match.state_var_scope) if init is None: i += 1 continue init_index, entry_state = init is_exit = _make_exit_check(match.condition, match.state_var_name) machine = _extract_state_machine(match, is_exit) if machine is None: i += 1 continue if entry_state not in machine: i += 1 continue recovered = _recover_structure(machine, entry_state, is_exit) if recovered is None: i += 1 continue recovered = [ s for s in recovered if _is_state_assignment(s, match.state_var_name, match.state_var_scope) is None ] for s in recovered: s.parent = parent body[init_index:i + 1] = recovered self.mark_changed() i += len(recovered)Ancestors
Methods
def visit(self, node)-
Expand source code Browse git
def visit(self, node: Node): for parent in list(node.walk()): if inside_value_producing_context(parent): continue body = get_body(parent) if body is None: continue self._try_deflatten_body(body, parent)