Module refinery.lib.scripts.ps1.deobfuscation.deadcode
Eliminate dead code from PowerShell scripts after constant folding.
Expand source code Browse git
"""
Eliminate dead code from PowerShell scripts after constant folding.
"""
from __future__ import annotations
from refinery.lib.scripts import Block, Expression, Node, Statement, Transformer
from refinery.lib.scripts.ps1.deobfuscation.data import COMPARISON_OPS
from refinery.lib.scripts.ps1.deobfuscation.helpers import (
get_body,
inside_value_producing_context,
is_builtin_variable,
is_truthy,
switch_matches,
unwrap_integer,
unwrap_parens,
)
from refinery.lib.scripts.ps1.model import (
Ps1AssignmentExpression,
Ps1BinaryExpression,
Ps1BreakStatement,
Ps1ContinueStatement,
Ps1DoLoop,
Ps1ExpressionStatement,
Ps1ForLoop,
Ps1IfStatement,
Ps1IntegerLiteral,
Ps1ParenExpression,
Ps1RealLiteral,
Ps1ScopeModifier,
Ps1Script,
Ps1StringLiteral,
Ps1SwitchStatement,
Ps1UnaryExpression,
Ps1Variable,
Ps1WhileLoop,
)
def _evaluate_for_condition(node: Ps1ForLoop) -> bool | None:
    """
    Statically decide whether a for-loop condition holds on loop entry.

    The loop must have the shape `for ($v = <integer>; <a> <cmp> <b>; ...)` where the comparison
    operator is listed in `COMPARISON_OPS` and each operand is either the loop variable or an
    integer constant. The initial value is substituted for the loop variable and the comparison
    is computed. Returns the boolean outcome, or `None` when the loop does not have this shape.
    """
    initializer = node.initializer
    condition = node.condition
    # Only a plain assignment of an integer to a variable is understood.
    if not (isinstance(initializer, Ps1AssignmentExpression) and initializer.operator == '='):
        return None
    target = initializer.target
    if not isinstance(target, Ps1Variable):
        return None
    start = unwrap_integer(initializer.value)
    if start is None:
        return None
    if not isinstance(condition, Ps1BinaryExpression):
        return None
    compare = COMPARISON_OPS.get(condition.operator.lower())
    if compare is None:
        return None
    name = target.name.lower()
    scope = target.scope
    operands = [
        _resolve_side(side, name, scope, start.value)
        for side in (condition.left, condition.right)
    ]
    if any(operand is None for operand in operands):
        return None
    lhs, rhs = operands
    return bool(compare(lhs, rhs))
def _resolve_side(
    node, var_name: str, var_scope: Ps1ScopeModifier, init_val: int,
) -> int | None:
    """
    Turn one operand of a for-loop condition into an integer.

    An operand that is the loop variable (same lower-cased name and same scope) resolves to
    `init_val`. An operand that `unwrap_integer` accepts resolves to that integer's value. Any
    other operand yields `None`.
    """
    if isinstance(node, Expression):
        node = unwrap_parens(node)
    is_loop_variable = (
        isinstance(node, Ps1Variable)
        and node.name.lower() == var_name
        and node.scope == var_scope
    )
    if is_loop_variable:
        return init_val
    literal = unwrap_integer(node)
    if literal is None:
        return None
    return literal.value
def _body_breaks_unconditionally(body: list[Statement]) -> bool:
    """
    Return `True` if the loop body can be replaced by its statements minus the trailing `break`.

    This requires that the last statement is an unlabeled `break` and that none of the preceding
    statements is or contains a `break` or `continue` at any nesting depth. Such a body executes
    exactly once. Jumps in the preceding statements are rejected because callers inline those
    statements in place of the loop: a `continue` would re-enter the loop, and a nested `break`
    (for example inside an `if`) would be retargeted to an enclosing loop or to the script once
    the loop around it is gone. The check is deliberately conservative and also rejects jumps
    that belong to a loop or switch nested inside the body.
    """
    if not body:
        return False
    last = body[-1]
    if not isinstance(last, Ps1BreakStatement) or last.label is not None:
        return False
    jumps = (Ps1BreakStatement, Ps1ContinueStatement)
    return not any(
        isinstance(stmt, jumps) or any(isinstance(node, jumps) for node in stmt.walk())
        for stmt in body[:-1]
    )
# Sentinel for "not a constant". A dedicated object is needed because falsy values such as
# `False` are legitimate results of `_switch_literal`.
_NO_LITERAL = object()


def _switch_literal(node):
    """
    Return the constant that a switch value or clause condition stands for.

    Parentheses are unwrapped first. Integer, real, and string literals yield their `value`; the
    built-in variables `$true` and `$false` yield `True` and `False`. Anything else yields the
    `_NO_LITERAL` sentinel.
    """
    inner = unwrap_parens(node)
    if isinstance(inner, (Ps1IntegerLiteral, Ps1RealLiteral, Ps1StringLiteral)):
        return inner.value
    for name, constant in (('true', True), ('false', False)):
        if is_builtin_variable(inner, {name}):
            return constant
    return _NO_LITERAL
def _switch_clause_body(body: list[Statement]) -> tuple[list[Statement], bool] | None:
"""
Return the statements of a matched switch clause together with a flag indicating whether the
clause terminates the switch (a trailing `break`). Returns `None` when the body contains a
top-level `break`/`continue` that is not a single trailing `break`, since inlining it would
retarget the jump to an enclosing loop.
"""
stmts = list(body)
stop = False
if stmts and isinstance(stmts[-1], Ps1BreakStatement):
stmts = stmts[:-1]
stop = True
for stmt in stmts:
if isinstance(stmt, (Ps1BreakStatement, Ps1ContinueStatement)):
return None
return stmts, stop
def _is_pure_constant(node) -> bool:
    """
    Decide whether an expression is a constant without side effects, so that a statement
    consisting only of it can be dropped.

    Accepted are integer and real literals and whatever `is_builtin_variable` accepts by
    default, optionally wrapped in parentheses or prefixed with unary `+` or `-`. String
    literals are not accepted because they may represent intentional pipeline output.
    """
    current = node
    while True:
        if isinstance(current, (Ps1IntegerLiteral, Ps1RealLiteral)):
            return True
        if is_builtin_variable(current):
            return True
        # Peel one wrapper layer and look at what it encloses.
        if isinstance(current, Ps1ParenExpression):
            current = current.expression
        elif isinstance(current, Ps1UnaryExpression) and current.operator in ('+', '-'):
            current = current.operand
        else:
            return False
class Ps1DeadCodeElimination(Transformer):
    """
    Remove unreachable code guarded by constant boolean conditions and resolve switch statements
    on constant values.
    """

    def visit(self, node: Node):
        """
        Prune every statement list found below `node` in place and mark the transformer as
        changed for each list that was rewritten. Nodes for which
        `inside_value_producing_context` is true are skipped.
        """
        # Snapshot the traversal first: statement lists are mutated while we iterate.
        for parent in list(node.walk()):
            if inside_value_producing_context(parent):
                continue
            body = get_body(parent)
            if body is None:
                continue
            new_body = self._prune_body(body, isinstance(parent, Ps1Script))
            if new_body is not body:
                # Replace the contents in place so the parent keeps using the same list object.
                body.clear()
                body.extend(new_body)
                for stmt in new_body:
                    stmt.parent = parent
                self.mark_changed()

    def _prune_body(
        self, body: list[Statement], is_script_level: bool = False,
    ) -> list[Statement]:
        """
        Return a pruned copy of `body`, or `body` itself (the same object) if nothing changed.

        Statements that consist only of a pure constant are dropped. At script level this only
        happens when the script also contains at least one statement that is not a pure
        constant. Loops, `if` and `switch` statements are replaced by whatever `_try_prune`
        returns for them.
        """
        result: list[Statement] = []
        changed = False
        # NOTE(review): presumably a script made up solely of constants keeps them because they
        # are its entire output - confirm.
        prune_constants = not is_script_level or any(
            not (isinstance(s, Ps1ExpressionStatement) and _is_pure_constant(s.expression))
            for s in body
        )
        for stmt in body:
            if (
                prune_constants
                and isinstance(stmt, Ps1ExpressionStatement)
                and _is_pure_constant(stmt.expression)
            ):
                changed = True
                continue
            replacement = self._try_prune(stmt)
            if replacement is not None:
                result.extend(replacement)
                changed = True
            else:
                result.append(stmt)
        return result if changed else body

    def _try_prune(self, stmt: Statement) -> list[Statement] | None:
        """
        Dispatch `stmt` to the pruning routine for its statement type. Returns the replacement
        statements (possibly empty), or `None` when the statement must be kept unchanged.
        """
        if isinstance(stmt, Ps1WhileLoop):
            return self._prune_while(stmt)
        if isinstance(stmt, Ps1DoLoop):
            return self._prune_do_loop(stmt)
        if isinstance(stmt, Ps1ForLoop):
            return self._prune_for(stmt)
        if isinstance(stmt, Ps1IfStatement):
            return self._prune_if(stmt)
        if isinstance(stmt, Ps1SwitchStatement):
            return self._prune_switch(stmt)
        return None

    @staticmethod
    def _contains_loop_jump(body: list[Statement]) -> bool:
        """
        Return `True` if any statement in `body` is, or contains at any nesting depth, a `break`
        or `continue` statement.
        """
        jumps = (Ps1BreakStatement, Ps1ContinueStatement)
        return any(
            isinstance(stmt, jumps) or any(isinstance(n, jumps) for n in stmt.walk())
            for stmt in body
        )

    @staticmethod
    def _prune_while(node: Ps1WhileLoop) -> list[Statement] | None:
        """
        Simplify a `while` loop. A constantly false condition removes the loop. A body that
        `_body_breaks_unconditionally` accepts is inlined without its trailing `break`, wrapped
        in an `if` on the loop condition unless that condition is constantly true or absent.
        Returns `None` when the loop cannot be simplified.
        """
        truth = is_truthy(node.condition)
        if truth is False:
            return []
        if node.body is not None and _body_breaks_unconditionally(node.body.body):
            body = list(node.body.body[:-1])
            if truth is True or node.condition is None:
                return body
            return [Ps1IfStatement(clauses=[(node.condition, Block(body=body))])]
        return None

    @staticmethod
    def _prune_do_loop(node: Ps1DoLoop) -> list[Statement] | None:
        """
        Simplify a `do` loop whose body runs exactly once. This is the case when the condition
        ends the loop after the first pass (constantly true for `until`, constantly false for
        `while`) or when `_body_breaks_unconditionally` accepts the body. Returns the inlined
        body statements, or `None` when the loop cannot be simplified.
        """
        if node.body is None:
            return None
        statements = node.body.body
        truth = is_truthy(node.condition)
        trivially_exits = truth is True if node.is_until else truth is False
        # The whole body may only be inlined when it contains no `break`/`continue`: those
        # target this loop and would be retargeted to an enclosing construct once it is gone.
        # A body that merely ends in `break` is still handled by the check below.
        if trivially_exits and not Ps1DeadCodeElimination._contains_loop_jump(statements):
            return list(statements)
        if _body_breaks_unconditionally(statements):
            return list(statements[:-1])
        return None

    @staticmethod
    def _prune_for(node: Ps1ForLoop) -> list[Statement] | None:
        """
        Simplify a `for` loop. The condition is evaluated at loop entry with
        `_evaluate_for_condition`, falling back to `is_truthy`. If it is false, only the
        initializer survives. If `_body_breaks_unconditionally` accepts the body, the loop
        becomes the initializer followed by the body without its trailing `break`, wrapped in
        an `if` on the condition unless that is known to be true or absent. Returns `None`
        when the loop cannot be simplified.
        """
        truth = _evaluate_for_condition(node)
        if truth is None:
            truth = is_truthy(node.condition)
        if truth is False:
            result: list[Statement] = []
            # The initializer still runs once even though the body never does.
            if node.initializer is not None:
                result.append(Ps1ExpressionStatement(expression=node.initializer))
            return result
        if node.body is not None and _body_breaks_unconditionally(node.body.body):
            result = []
            if node.initializer is not None:
                result.append(Ps1ExpressionStatement(expression=node.initializer))
            body = list(node.body.body[:-1])
            if truth is True or node.condition is None:
                result.extend(body)
            else:
                result.append(Ps1IfStatement(clauses=[(node.condition, Block(body=body))]))
            return result
        return None

    @staticmethod
    def _prune_if(node: Ps1IfStatement) -> list[Statement] | None:
        """
        Simplify an `if` statement by resolving its leading clauses.

        Clauses are examined in order. Constantly false ones are dropped. The first constantly
        true one replaces the whole statement with its block. The first clause whose truth is
        unknown stops the scan and is kept together with all clauses after it. If every clause
        is false, the `else` block (or nothing) replaces the statement. Returns `None` when no
        clause could be removed; otherwise `node.clauses` is updated in place and `[node]` is
        returned.
        """
        kept_clauses: list[tuple] = []
        for index, (condition, block) in enumerate(node.clauses):
            truth = is_truthy(condition)
            if truth is True:
                return list(block.body)
            if truth is False:
                continue
            # Use the position rather than an equality-based `list.index` lookup to find
            # where the surviving clauses start.
            kept_clauses = list(node.clauses[index:])
            break
        else:
            if node.else_block is not None:
                return list(node.else_block.body)
            return []
        if len(kept_clauses) == len(node.clauses):
            return None
        node.clauses[:] = kept_clauses
        return [node]

    @staticmethod
    def _prune_switch(node: Ps1SwitchStatement) -> list[Statement] | None:
        """
        Resolve a `switch` on a constant value to the statements of its matching clauses.

        Switches using the regex, wildcard, or file mode are left alone, as are switches whose
        value or any clause condition examined before the result is decided is not a constant.
        Matching clauses are concatenated in order until one of them ends the switch with a
        trailing `break`. Without any match, the default clause is used if present. Returns
        `None` when the switch cannot be resolved statically.
        """
        if node.regex or node.wildcard or node.file:
            return None
        value = _switch_literal(node.value)
        if value is _NO_LITERAL:
            return None
        default_body: list[Statement] | None = None
        result: list[Statement] = []
        matched = False
        for condition, block in node.clauses:
            # A clause without a condition is the default clause.
            if condition is None:
                default_body = block.body
                continue
            cond_val = _switch_literal(condition)
            if cond_val is _NO_LITERAL:
                # A non-constant clause condition might match at runtime; cannot resolve
                # statically.
                return None
            if switch_matches(value, cond_val, case_sensitive=node.case_sensitive):
                body = _switch_clause_body(block.body)
                if body is None:
                    return None
                stmts, stop = body
                result.extend(stmts)
                matched = True
                if stop:
                    return result
        if matched:
            return result
        if default_body is not None:
            body = _switch_clause_body(default_body)
            if body is None:
                return None
            return body[0]
        return []
Classes
class Ps1DeadCodeElimination-
Remove unreachable code guarded by constant boolean conditions and resolve switch statements on constant values.
Expand source code Browse git
class Ps1DeadCodeElimination(Transformer):
    """
    Remove unreachable code guarded by constant boolean conditions and resolve switch statements
    on constant values.
    """

    def visit(self, node: Node):
        """
        Prune every statement list found below `node` in place and mark the transformer as
        changed for each list that was rewritten. Nodes for which
        `inside_value_producing_context` is true are skipped.
        """
        # Snapshot the traversal first: statement lists are mutated while we iterate.
        for parent in list(node.walk()):
            if inside_value_producing_context(parent):
                continue
            body = get_body(parent)
            if body is None:
                continue
            new_body = self._prune_body(body, isinstance(parent, Ps1Script))
            if new_body is not body:
                # Replace the contents in place so the parent keeps using the same list object.
                body.clear()
                body.extend(new_body)
                for stmt in new_body:
                    stmt.parent = parent
                self.mark_changed()

    def _prune_body(
        self, body: list[Statement], is_script_level: bool = False,
    ) -> list[Statement]:
        """
        Return a pruned copy of `body`, or `body` itself (the same object) if nothing changed.

        Statements that consist only of a pure constant are dropped. At script level this only
        happens when the script also contains at least one statement that is not a pure
        constant. Loops, `if` and `switch` statements are replaced by whatever `_try_prune`
        returns for them.
        """
        result: list[Statement] = []
        changed = False
        # NOTE(review): presumably a script made up solely of constants keeps them because they
        # are its entire output - confirm.
        prune_constants = not is_script_level or any(
            not (isinstance(s, Ps1ExpressionStatement) and _is_pure_constant(s.expression))
            for s in body
        )
        for stmt in body:
            if (
                prune_constants
                and isinstance(stmt, Ps1ExpressionStatement)
                and _is_pure_constant(stmt.expression)
            ):
                changed = True
                continue
            replacement = self._try_prune(stmt)
            if replacement is not None:
                result.extend(replacement)
                changed = True
            else:
                result.append(stmt)
        return result if changed else body

    def _try_prune(self, stmt: Statement) -> list[Statement] | None:
        """
        Dispatch `stmt` to the pruning routine for its statement type. Returns the replacement
        statements (possibly empty), or `None` when the statement must be kept unchanged.
        """
        if isinstance(stmt, Ps1WhileLoop):
            return self._prune_while(stmt)
        if isinstance(stmt, Ps1DoLoop):
            return self._prune_do_loop(stmt)
        if isinstance(stmt, Ps1ForLoop):
            return self._prune_for(stmt)
        if isinstance(stmt, Ps1IfStatement):
            return self._prune_if(stmt)
        if isinstance(stmt, Ps1SwitchStatement):
            return self._prune_switch(stmt)
        return None

    @staticmethod
    def _contains_loop_jump(body: list[Statement]) -> bool:
        """
        Return `True` if any statement in `body` is, or contains at any nesting depth, a `break`
        or `continue` statement.
        """
        jumps = (Ps1BreakStatement, Ps1ContinueStatement)
        return any(
            isinstance(stmt, jumps) or any(isinstance(n, jumps) for n in stmt.walk())
            for stmt in body
        )

    @staticmethod
    def _prune_while(node: Ps1WhileLoop) -> list[Statement] | None:
        """
        Simplify a `while` loop. A constantly false condition removes the loop. A body that
        `_body_breaks_unconditionally` accepts is inlined without its trailing `break`, wrapped
        in an `if` on the loop condition unless that condition is constantly true or absent.
        Returns `None` when the loop cannot be simplified.
        """
        truth = is_truthy(node.condition)
        if truth is False:
            return []
        if node.body is not None and _body_breaks_unconditionally(node.body.body):
            body = list(node.body.body[:-1])
            if truth is True or node.condition is None:
                return body
            return [Ps1IfStatement(clauses=[(node.condition, Block(body=body))])]
        return None

    @staticmethod
    def _prune_do_loop(node: Ps1DoLoop) -> list[Statement] | None:
        """
        Simplify a `do` loop whose body runs exactly once. This is the case when the condition
        ends the loop after the first pass (constantly true for `until`, constantly false for
        `while`) or when `_body_breaks_unconditionally` accepts the body. Returns the inlined
        body statements, or `None` when the loop cannot be simplified.
        """
        if node.body is None:
            return None
        statements = node.body.body
        truth = is_truthy(node.condition)
        trivially_exits = truth is True if node.is_until else truth is False
        # The whole body may only be inlined when it contains no `break`/`continue`: those
        # target this loop and would be retargeted to an enclosing construct once it is gone.
        # A body that merely ends in `break` is still handled by the check below.
        if trivially_exits and not Ps1DeadCodeElimination._contains_loop_jump(statements):
            return list(statements)
        if _body_breaks_unconditionally(statements):
            return list(statements[:-1])
        return None

    @staticmethod
    def _prune_for(node: Ps1ForLoop) -> list[Statement] | None:
        """
        Simplify a `for` loop. The condition is evaluated at loop entry with
        `_evaluate_for_condition`, falling back to `is_truthy`. If it is false, only the
        initializer survives. If `_body_breaks_unconditionally` accepts the body, the loop
        becomes the initializer followed by the body without its trailing `break`, wrapped in
        an `if` on the condition unless that is known to be true or absent. Returns `None`
        when the loop cannot be simplified.
        """
        truth = _evaluate_for_condition(node)
        if truth is None:
            truth = is_truthy(node.condition)
        if truth is False:
            result: list[Statement] = []
            # The initializer still runs once even though the body never does.
            if node.initializer is not None:
                result.append(Ps1ExpressionStatement(expression=node.initializer))
            return result
        if node.body is not None and _body_breaks_unconditionally(node.body.body):
            result = []
            if node.initializer is not None:
                result.append(Ps1ExpressionStatement(expression=node.initializer))
            body = list(node.body.body[:-1])
            if truth is True or node.condition is None:
                result.extend(body)
            else:
                result.append(Ps1IfStatement(clauses=[(node.condition, Block(body=body))]))
            return result
        return None

    @staticmethod
    def _prune_if(node: Ps1IfStatement) -> list[Statement] | None:
        """
        Simplify an `if` statement by resolving its leading clauses.

        Clauses are examined in order. Constantly false ones are dropped. The first constantly
        true one replaces the whole statement with its block. The first clause whose truth is
        unknown stops the scan and is kept together with all clauses after it. If every clause
        is false, the `else` block (or nothing) replaces the statement. Returns `None` when no
        clause could be removed; otherwise `node.clauses` is updated in place and `[node]` is
        returned.
        """
        kept_clauses: list[tuple] = []
        for index, (condition, block) in enumerate(node.clauses):
            truth = is_truthy(condition)
            if truth is True:
                return list(block.body)
            if truth is False:
                continue
            # Use the position rather than an equality-based `list.index` lookup to find
            # where the surviving clauses start.
            kept_clauses = list(node.clauses[index:])
            break
        else:
            if node.else_block is not None:
                return list(node.else_block.body)
            return []
        if len(kept_clauses) == len(node.clauses):
            return None
        node.clauses[:] = kept_clauses
        return [node]

    @staticmethod
    def _prune_switch(node: Ps1SwitchStatement) -> list[Statement] | None:
        """
        Resolve a `switch` on a constant value to the statements of its matching clauses.

        Switches using the regex, wildcard, or file mode are left alone, as are switches whose
        value or any clause condition examined before the result is decided is not a constant.
        Matching clauses are concatenated in order until one of them ends the switch with a
        trailing `break`. Without any match, the default clause is used if present. Returns
        `None` when the switch cannot be resolved statically.
        """
        if node.regex or node.wildcard or node.file:
            return None
        value = _switch_literal(node.value)
        if value is _NO_LITERAL:
            return None
        default_body: list[Statement] | None = None
        result: list[Statement] = []
        matched = False
        for condition, block in node.clauses:
            # A clause without a condition is the default clause.
            if condition is None:
                default_body = block.body
                continue
            cond_val = _switch_literal(condition)
            if cond_val is _NO_LITERAL:
                # A non-constant clause condition might match at runtime; cannot resolve
                # statically.
                return None
            if switch_matches(value, cond_val, case_sensitive=node.case_sensitive):
                body = _switch_clause_body(block.body)
                if body is None:
                    return None
                stmts, stop = body
                result.extend(stmts)
                matched = True
                if stop:
                    return result
        if matched:
            return result
        if default_body is not None:
            body = _switch_clause_body(default_body)
            if body is None:
                return None
            return body[0]
        return []
Methods
def visit(self, node)-
Expand source code Browse git
def visit(self, node: Node):
    """
    Walk all nodes below `node` and prune the statement list of each one in place, marking the
    transformer as changed whenever a list was rewritten.
    """
    # The traversal is materialized up front because statement lists change during the loop.
    for owner in list(node.walk()):
        if inside_value_producing_context(owner):
            continue
        statements = get_body(owner)
        if statements is None:
            continue
        at_script_level = isinstance(owner, Ps1Script)
        pruned = self._prune_body(statements, at_script_level)
        if pruned is statements:
            # `_prune_body` hands back the same list object when nothing was removed.
            continue
        statements.clear()
        statements.extend(pruned)
        for child in pruned:
            child.parent = owner
        self.mark_changed()
Inherited members