Module refinery.lib.scripts.ps1.deobfuscation.deadcode
Eliminate dead code from PowerShell scripts after constant folding.
Expand source code Browse git
"""
Eliminate dead code from PowerShell scripts after constant folding.
"""
from __future__ import annotations
from refinery.lib.scripts import Block, Expression, Node, Statement, Transformer
from refinery.lib.scripts.ps1.deobfuscation.data import COMPARISON_OPS
from refinery.lib.scripts.ps1.deobfuscation.helpers import (
get_body,
inside_value_producing_context,
is_builtin_variable,
is_truthy,
unwrap_integer,
unwrap_parens,
)
from refinery.lib.scripts.ps1.model import (
Ps1AssignmentExpression,
Ps1BinaryExpression,
Ps1BreakStatement,
Ps1ContinueStatement,
Ps1DoLoop,
Ps1ExpressionStatement,
Ps1ForLoop,
Ps1IfStatement,
Ps1IntegerLiteral,
Ps1ParenExpression,
Ps1RealLiteral,
Ps1ScopeModifier,
Ps1Script,
Ps1StringLiteral,
Ps1SwitchStatement,
Ps1UnaryExpression,
Ps1Variable,
Ps1WhileLoop,
)
def _evaluate_for_condition(node: Ps1ForLoop) -> bool | None:
    """
    Statically decide whether the condition of a for-loop holds on first entry.

    This only succeeds when the initializer is a plain `=` assignment of an integer constant to a
    variable and the condition is a binary comparison whose operator is listed in
    `COMPARISON_OPS`. Each operand of the comparison must be either the loop variable (which is
    replaced by its initial value) or an integer constant. The result of the comparison is
    returned as a boolean; `None` is returned whenever the loop does not have this shape.
    """
    initializer = node.initializer
    condition = node.condition
    is_plain_assignment = (
        isinstance(initializer, Ps1AssignmentExpression)
        and initializer.operator == '='
        and isinstance(initializer.target, Ps1Variable)
    )
    if not is_plain_assignment:
        return None
    start = unwrap_integer(initializer.value)
    if start is None or not isinstance(condition, Ps1BinaryExpression):
        return None
    compare = COMPARISON_OPS.get(condition.operator.lower())
    if compare is None:
        return None
    loop_var = initializer.target
    name = loop_var.name.lower()
    scope = loop_var.scope
    # Resolve both operands before inspecting either result.
    operands = [
        _resolve_side(side, name, scope, start.value)
        for side in (condition.left, condition.right)
    ]
    if any(operand is None for operand in operands):
        return None
    return bool(compare(*operands))
def _resolve_side(
    node, var_name: str, var_scope: Ps1ScopeModifier, init_val: int,
) -> int | None:
    """
    Turn one operand of a for-loop condition into an integer, if possible.

    Expressions are first passed through `unwrap_parens`. A reference to the loop variable
    (matching lower-cased name and matching scope) yields `init_val`; anything that
    `unwrap_integer` recognizes yields that integer's value; everything else yields `None`.
    """
    if isinstance(node, Expression):
        node = unwrap_parens(node)
    if isinstance(node, Ps1Variable):
        if node.name.lower() == var_name and node.scope == var_scope:
            return init_val
    literal = unwrap_integer(node)
    if literal is None:
        return None
    return literal.value
def _body_breaks_unconditionally(body: list[Statement]) -> bool:
    """
    Return `True` if the last statement in the body is an unlabeled break and no statement before
    it contains a break or continue statement at any nesting depth. Such a loop body executes
    exactly once, and the statements before the trailing break can be hoisted out of the loop.

    Earlier break statements are rejected in addition to continue statements: callers remove the
    surrounding loop, so any remaining jump that targeted this loop would afterwards bind to an
    enclosing loop (or to no loop at all) and change the control flow of the script. The check is
    deliberately conservative and also rejects jumps that belong to nested loops or switches.
    """
    if not body:
        return False
    *leading, last = body
    if not isinstance(last, Ps1BreakStatement) or last.label is not None:
        return False
    return not any(
        isinstance(node, (Ps1BreakStatement, Ps1ContinueStatement))
        for stmt in leading
        for node in stmt.walk()
    )
def _is_pure_constant(node) -> bool:
    """
    Decide whether an expression is a constant without side effects, so that a statement that
    consists only of this expression can be dropped.

    Accepted are integer and real literals, anything `is_builtin_variable` accepts, and any of
    these wrapped in parentheses or prefixed by unary `+` / `-`. String literals are not accepted
    because a bare string may be intentional pipeline output.
    """
    # Peel parentheses and signs one layer at a time until a leaf decides the result.
    while True:
        if isinstance(node, (Ps1IntegerLiteral, Ps1RealLiteral)):
            return True
        if is_builtin_variable(node):
            return True
        if isinstance(node, Ps1ParenExpression):
            node = node.expression
        elif isinstance(node, Ps1UnaryExpression) and node.operator in ('+', '-'):
            node = node.operand
        else:
            return False
class Ps1DeadCodeElimination(Transformer):
    """
    Remove unreachable code guarded by constant boolean conditions and resolve switch statements
    on constant values.
    """

    def visit(self, node: Node):
        """
        Prune the statement list of every node below (and including) `node` that has a body and is
        not inside a value-producing context. Whenever a body is rewritten, the surviving
        statements are re-parented and the transformer is marked as changed.
        """
        # Snapshot the traversal: bodies are rewritten in place while iterating.
        for parent in list(node.walk()):
            if inside_value_producing_context(parent):
                continue
            body = get_body(parent)
            if body is None:
                continue
            new_body = self._prune_body(body, isinstance(parent, Ps1Script))
            if new_body is not body:
                body.clear()
                body.extend(new_body)
                for stmt in new_body:
                    stmt.parent = parent
                self.mark_changed()

    def _prune_body(
        self, body: list[Statement], is_script_level: bool = False,
    ) -> list[Statement]:
        """
        Return a pruned copy of `body`, or `body` itself (same object) when nothing changed.

        Standalone pure-constant expression statements are dropped. At script level this is only
        done when the script contains at least one other kind of statement, so that a script made
        up entirely of constants is left untouched.
        """
        result: list[Statement] = []
        changed = False
        prune_constants = not is_script_level or any(
            not (isinstance(s, Ps1ExpressionStatement) and _is_pure_constant(s.expression))
            for s in body
        )
        for stmt in body:
            if (
                prune_constants
                and isinstance(stmt, Ps1ExpressionStatement)
                and _is_pure_constant(stmt.expression)
            ):
                changed = True
                continue
            replacement = self._try_prune(stmt)
            if replacement is not None:
                result.extend(replacement)
                changed = True
            else:
                result.append(stmt)
        return result if changed else body

    def _try_prune(self, stmt: Statement) -> list[Statement] | None:
        """
        Dispatch to the pruning routine for the statement type. The result is the list of
        statements that replaces `stmt`, or `None` when the statement is to be kept unchanged.
        """
        if isinstance(stmt, Ps1WhileLoop):
            return self._prune_while(stmt)
        if isinstance(stmt, Ps1DoLoop):
            return self._prune_do_loop(stmt)
        if isinstance(stmt, Ps1ForLoop):
            return self._prune_for(stmt)
        if isinstance(stmt, Ps1IfStatement):
            return self._prune_if(stmt)
        if isinstance(stmt, Ps1SwitchStatement):
            return self._prune_switch(stmt)
        return None

    @staticmethod
    def _contains_loop_jump(statements: list[Statement]) -> bool:
        """
        Return `True` if any of the statements is, or contains at any nesting depth, a break or
        continue statement. Conservative: jumps that belong to nested loops are reported as well.
        """
        return any(
            isinstance(inner, (Ps1BreakStatement, Ps1ContinueStatement))
            for stmt in statements
            for inner in stmt.walk()
        )

    @staticmethod
    def _switch_label_matches(value, condition) -> bool | None:
        """
        Compare the constant switch value against one clause label. Returns `True` or `False` when
        the label is an integer or string literal and `None` when the label cannot be evaluated
        statically. Two strings are compared case-insensitively; all other combinations are
        compared as integers, and a string that is not an integer does not match.
        """
        if not isinstance(condition, (Ps1IntegerLiteral, Ps1StringLiteral)):
            return None
        if isinstance(value, Ps1StringLiteral) and isinstance(condition, Ps1StringLiteral):
            return condition.value.lower() == value.value.lower()
        try:
            return int(condition.value) == int(value.value)
        except ValueError:
            return False

    @staticmethod
    def _prune_while(node: Ps1WhileLoop) -> list[Statement] | None:
        """
        A while loop with a constant-false condition is removed. A loop whose body ends in an
        unconditional break runs at most once: it is replaced by its body (without the break),
        wrapped in an if statement on the loop condition unless that condition is constant-true
        or absent.
        """
        truth = is_truthy(node.condition)
        if truth is False:
            return []
        if node.body is not None and _body_breaks_unconditionally(node.body.body):
            body = list(node.body.body[:-1])
            if truth is True or node.condition is None:
                return body
            return [Ps1IfStatement(clauses=[(node.condition, Block(body=body))])]
        return None

    @staticmethod
    def _prune_do_loop(node: Ps1DoLoop) -> list[Statement] | None:
        """
        A do loop whose condition makes it exit after the first iteration (`until` constant-true
        or `while` constant-false) is replaced by its body, and so is a loop whose body ends in
        an unconditional break (without that break).
        """
        if node.body is None:
            return None
        statements = node.body.body
        truth = is_truthy(node.condition)
        trivially_exits = truth is True if node.is_until else truth is False
        # Inlining the body verbatim is only sound if no break/continue inside it targets the
        # loop being removed; otherwise the jump would escape to an enclosing loop. Bodies that
        # merely end in a break are still handled by the check below, which drops that break.
        if trivially_exits and not Ps1DeadCodeElimination._contains_loop_jump(statements):
            return list(statements)
        if _body_breaks_unconditionally(statements):
            return list(statements[:-1])
        return None

    @staticmethod
    def _prune_for(node: Ps1ForLoop) -> list[Statement] | None:
        """
        A for loop whose condition is false on entry is reduced to its initializer. A loop whose
        body ends in an unconditional break is reduced to the initializer followed by the body
        (without the break), wrapped in an if statement on the loop condition unless that
        condition is known to be true on entry or is absent.
        """
        truth = _evaluate_for_condition(node)
        if truth is None:
            truth = is_truthy(node.condition)
        if truth is False:
            result: list[Statement] = []
            if node.initializer is not None:
                result.append(Ps1ExpressionStatement(expression=node.initializer))
            return result
        if node.body is not None and _body_breaks_unconditionally(node.body.body):
            result = []
            if node.initializer is not None:
                result.append(Ps1ExpressionStatement(expression=node.initializer))
            body = list(node.body.body[:-1])
            if truth is True or node.condition is None:
                result.extend(body)
            else:
                result.append(Ps1IfStatement(clauses=[(node.condition, Block(body=body))]))
            return result
        return None

    @staticmethod
    def _prune_if(node: Ps1IfStatement) -> list[Statement] | None:
        """
        Resolve the leading clauses of an if statement whose conditions are constant.

        Clauses are inspected in order. Constant-false clauses are discarded; the first
        constant-true clause replaces the whole statement with its body. At the first clause that
        cannot be decided, the scan stops: if clauses were discarded before it, they are removed
        from the node and `[node]` is returned so that the caller registers the modification;
        otherwise `None` is returned. If every clause is constant-false, the statement is replaced
        by the body of the else block, or removed when there is none.
        """
        for index, (condition, block) in enumerate(node.clauses):
            truth = is_truthy(condition)
            if truth is True:
                return list(block.body)
            if truth is False:
                continue
            if index == 0:
                return None
            # Drop the dead leading clauses in place and hand the node back as its own
            # replacement; returning None here would hide the change from the caller.
            del node.clauses[:index]
            return [node]
        if node.else_block is not None:
            return list(node.else_block.body)
        return []

    @staticmethod
    def _prune_switch(node: Ps1SwitchStatement) -> list[Statement] | None:
        """
        Resolve a switch statement on a constant integer or string value.

        The statement is left alone (`None`) when any of the regex, wildcard, exact or file flags
        is set, when the switch value is not an integer or string literal, when any clause label
        cannot be evaluated statically, or when more than one clause matches (PowerShell would
        execute all of them, and whether a break intervenes is not analyzed here). Otherwise the
        statement is replaced by the body of the single matching clause, by the default body when
        nothing matches, or by nothing at all.
        """
        if node.regex or node.wildcard or node.exact or node.file:
            return None
        value = node.value
        if not isinstance(value, (Ps1IntegerLiteral, Ps1StringLiteral)):
            return None
        default_body: list[Statement] | None = None
        matched: list = []
        for condition, block in node.clauses:
            if condition is None:
                # A clause without a condition is the default clause.
                default_body = list(block.body)
                continue
            verdict = Ps1DeadCodeElimination._switch_label_matches(value, condition)
            if verdict is None:
                # A variable or script block label might match at runtime.
                return None
            if verdict:
                matched.append(block)
        if len(matched) > 1:
            return None
        if matched:
            return list(matched[0].body)
        if default_body is not None:
            return default_body
        return []
Classes
class Ps1DeadCodeElimination-
Remove unreachable code guarded by constant boolean conditions and resolve switch statements on constant values.
Expand source code Browse git
class Ps1DeadCodeElimination(Transformer): """ Remove unreachable code guarded by constant boolean conditions and resolve switch statements on constant values. """ def visit(self, node: Node): for parent in list(node.walk()): if inside_value_producing_context(parent): continue body = get_body(parent) if body is None: continue new_body = self._prune_body(body, isinstance(parent, Ps1Script)) if new_body is not body: body.clear() body.extend(new_body) for stmt in new_body: stmt.parent = parent self.mark_changed() def _prune_body( self, body: list[Statement], is_script_level: bool = False, ) -> list[Statement]: result: list[Statement] = [] changed = False prune_constants = not is_script_level or any( not (isinstance(s, Ps1ExpressionStatement) and _is_pure_constant(s.expression)) for s in body ) for stmt in body: if ( prune_constants and isinstance(stmt, Ps1ExpressionStatement) and _is_pure_constant(stmt.expression) ): changed = True continue replacement = self._try_prune(stmt) if replacement is not None: result.extend(replacement) changed = True else: result.append(stmt) return result if changed else body def _try_prune(self, stmt: Statement) -> list[Statement] | None: if isinstance(stmt, Ps1WhileLoop): return self._prune_while(stmt) if isinstance(stmt, Ps1DoLoop): return self._prune_do_loop(stmt) if isinstance(stmt, Ps1ForLoop): return self._prune_for(stmt) if isinstance(stmt, Ps1IfStatement): return self._prune_if(stmt) if isinstance(stmt, Ps1SwitchStatement): return self._prune_switch(stmt) return None @staticmethod def _prune_while(node: Ps1WhileLoop) -> list[Statement] | None: truth = is_truthy(node.condition) if truth is False: return [] if node.body is not None and _body_breaks_unconditionally(node.body.body): body = list(node.body.body[:-1]) if truth is True or node.condition is None: return body return [Ps1IfStatement(clauses=[(node.condition, Block(body=body))])] return None @staticmethod def _prune_do_loop(node: Ps1DoLoop) -> list[Statement] | None: if 
node.body is not None: trivially_exits = ( is_truthy(node.condition) is True if node.is_until else is_truthy(node.condition) is False ) if trivially_exits: return list(node.body.body) if _body_breaks_unconditionally(node.body.body): return list(node.body.body[:-1]) return None @staticmethod def _prune_for(node: Ps1ForLoop) -> list[Statement] | None: truth = _evaluate_for_condition(node) if truth is None: truth = is_truthy(node.condition) if truth is False: result: list[Statement] = [] if node.initializer is not None: result.append(Ps1ExpressionStatement(expression=node.initializer)) return result if node.body is not None and _body_breaks_unconditionally(node.body.body): result = [] if node.initializer is not None: result.append(Ps1ExpressionStatement(expression=node.initializer)) body = list(node.body.body[:-1]) if truth is True or node.condition is None: result.extend(body) else: result.append(Ps1IfStatement(clauses=[(node.condition, Block(body=body))])) return result return None @staticmethod def _prune_if(node: Ps1IfStatement) -> list[Statement] | None: kept_clauses: list[tuple] = [] for condition, block in node.clauses: truth = is_truthy(condition) if truth is True: return list(block.body) if truth is False: continue kept_clauses.append((condition, block)) kept_clauses.extend(node.clauses[node.clauses.index((condition, block)) + 1:]) break else: if node.else_block is not None: return list(node.else_block.body) return [] if len(kept_clauses) == len(node.clauses): return None node.clauses[:] = kept_clauses return None @staticmethod def _prune_switch(node: Ps1SwitchStatement) -> list[Statement] | None: if node.regex or node.wildcard or node.exact or node.file: return None value = node.value if isinstance(value, Ps1IntegerLiteral): target_int = value.value target_str = None elif isinstance(value, Ps1StringLiteral): target_str = value.value.lower() target_int = None else: return None default_body: list[Statement] | None = None for condition, block in node.clauses: 
if condition is None: default_body = list(block.body) continue if target_int is not None and isinstance(condition, Ps1IntegerLiteral): if condition.value == target_int: return list(block.body) elif target_str is not None and isinstance(condition, Ps1StringLiteral): if condition.value.lower() == target_str: return list(block.body) elif target_int is not None and isinstance(condition, Ps1StringLiteral): try: if int(condition.value) == target_int: return list(block.body) except ValueError: pass elif target_str is not None and isinstance(condition, Ps1IntegerLiteral): try: if int(target_str) == condition.value: return list(block.body) except ValueError: pass if default_body is not None: return default_body return []Ancestors
Methods
def visit(self, node)-
Expand source code Browse git
def visit(self, node: Node): for parent in list(node.walk()): if inside_value_producing_context(parent): continue body = get_body(parent) if body is None: continue new_body = self._prune_body(body, isinstance(parent, Ps1Script)) if new_body is not body: body.clear() body.extend(new_body) for stmt in new_body: stmt.parent = parent self.mark_changed()