Module refinery.lib.scripts.ps1.deobfuscation.deadcode
Eliminate dead code from PowerShell scripts after constant folding.
Expand source code Browse git
"""
Eliminate dead code from PowerShell scripts after constant folding.
"""
from __future__ import annotations
from refinery.lib.scripts import Block, Node, Statement, Transformer
from refinery.lib.scripts.ps1.deobfuscation._helpers import _get_body
from refinery.lib.scripts.ps1.model import (
Ps1AssignmentExpression,
Ps1BinaryExpression,
Ps1BreakStatement,
Ps1ContinueStatement,
Ps1DoUntilLoop,
Ps1DoWhileLoop,
Ps1ExpressionStatement,
Ps1ForLoop,
Ps1IfStatement,
Ps1IntegerLiteral,
Ps1ParenExpression,
Ps1RealLiteral,
Ps1ScopeModifier,
Ps1Script,
Ps1StringLiteral,
Ps1SubExpression,
Ps1SwitchStatement,
Ps1UnaryExpression,
Ps1Variable,
Ps1WhileLoop,
)
def _is_truthy(node) -> bool | None:
    """
    Determine the boolean truth value of a constant expression using PowerShell
    semantics. Returns `None` for non-constant or unrecognized expressions.

    Recognized constants are the unscoped variables `$True`, `$False` and `$Null`, integer and
    real literals (zero is false), string literals (empty is false), and an arithmetic negation
    of any of the non-string constants. Surrounding parentheses are ignored.
    """
    while isinstance(node, Ps1ParenExpression):
        node = node.expression
    if node is None:
        return None
    if isinstance(node, Ps1Variable) and node.scope == Ps1ScopeModifier.NONE:
        lower = node.name.lower()
        if lower == 'true':
            return True
        if lower in ('false', 'null'):
            return False
        return None
    if isinstance(node, Ps1IntegerLiteral):
        return node.value != 0
    if isinstance(node, Ps1RealLiteral):
        return node.value != 0.0
    if isinstance(node, Ps1StringLiteral):
        return len(node.value) > 0
    if isinstance(node, Ps1UnaryExpression) and node.operator == '-':
        operand = node.operand
        while isinstance(operand, Ps1ParenExpression):
            operand = operand.expression
        if isinstance(operand, Ps1StringLiteral):
            # Negation coerces the string to a number first, so the truth value of the string
            # itself is not the truth value of the negation: "0" is true but -"0" is 0, and a
            # non-numeric string fails at runtime. Treat it as unknown.
            return None
        # Negating a number never changes whether it is zero.
        return _is_truthy(operand)
    return None
def _unwrap_integer(node) -> int | None:
    """
    Reduce a constant expression to a Python integer, or return None when that is not possible.

    Handles an integer literal, the unscoped variable `$Null` (read as 0), and a unary minus
    applied directly to an integer literal. Parentheses around the expression, and around the
    operand of the minus, are ignored.
    """
    def _strip_parens(expr):
        # Peel off any number of nested parenthesized expressions.
        while isinstance(expr, Ps1ParenExpression):
            expr = expr.expression
        return expr

    node = _strip_parens(node)
    if isinstance(node, Ps1IntegerLiteral):
        return node.value
    if isinstance(node, Ps1Variable):
        if node.scope == Ps1ScopeModifier.NONE and node.name.lower() == 'null':
            return 0
    if isinstance(node, Ps1UnaryExpression) and node.operator == '-':
        operand = _strip_parens(node.operand)
        # Only a literal operand is accepted; nested negations are left unresolved.
        if isinstance(operand, Ps1IntegerLiteral):
            return -operand.value
    return None
_COMPARISON_OPS = {
'-eq': int.__eq__,
'-ne': int.__ne__,
'-lt': int.__lt__,
'-le': int.__le__,
'-gt': int.__gt__,
'-ge': int.__ge__,
}
def _evaluate_for_condition(node: Ps1ForLoop) -> bool | None:
    """
    Try to evaluate a for-loop condition at loop entry by substituting the initial value of the
    loop variable into the comparison. Returns the boolean result, or None if the pattern does not
    match.

    The pattern is an initializer of the form `$var = <integer constant>` together with a
    condition that is a single comparison (`-eq`, `-ne`, `-lt`, `-le`, `-gt`, `-ge`) whose sides
    are each either the loop variable or an integer constant.
    """
    init = node.initializer
    cond = node.condition
    if not isinstance(init, Ps1AssignmentExpression) or init.operator != '=':
        return None
    if not isinstance(init.target, Ps1Variable):
        return None
    init_val = _unwrap_integer(init.value)
    if init_val is None:
        return None
    if not isinstance(cond, Ps1BinaryExpression):
        return None
    op_fn = _COMPARISON_OPS.get(cond.operator.lower())
    if op_fn is None:
        return None
    var_name = init.target.name.lower()
    var_scope = init.target.scope

    def _is_null(expr) -> bool:
        # True when this side of the comparison is `$Null`, either written literally or
        # because it is the loop variable and the loop variable was initialized with `$Null`.
        while isinstance(expr, Ps1ParenExpression):
            expr = expr.expression
        if (
            isinstance(expr, Ps1Variable)
            and expr.name.lower() == var_name
            and expr.scope == var_scope
        ):
            expr = init.value
            while isinstance(expr, Ps1ParenExpression):
                expr = expr.expression
        return (
            isinstance(expr, Ps1Variable)
            and expr.scope == Ps1ScopeModifier.NONE
            and expr.name.lower() == 'null'
        )

    left_val = _resolve_side(cond.left, var_name, var_scope, init_val)
    right_val = _resolve_side(cond.right, var_name, var_scope, init_val)
    if left_val is None or right_val is None:
        return None
    if left_val == 0 and right_val == 0 and _is_null(cond.left) != _is_null(cond.right):
        # `$Null` is read as 0 by `_unwrap_integer`, which gives the right answer against every
        # integer except 0 itself: PowerShell orders `$Null` strictly below 0 and never equal
        # to it (`$Null -eq 0` is false, `$Null -lt 0` is true). Leave that case undecided.
        return None
    return bool(op_fn(left_val, right_val))
def _resolve_side(
    node, var_name: str, var_scope: Ps1ScopeModifier, init_val: int,
) -> int | None:
    """
    Turn one operand of a for-loop condition into an integer. A reference to the loop variable
    (matched by lowercased name and by scope) yields `init_val`; anything else is handed to
    `_unwrap_integer`, which yields None for non-constant operands.
    """
    operand = node
    while isinstance(operand, Ps1ParenExpression):
        operand = operand.expression
    if not isinstance(operand, Ps1Variable):
        return _unwrap_integer(operand)
    is_loop_variable = operand.name.lower() == var_name and operand.scope == var_scope
    return init_val if is_loop_variable else _unwrap_integer(operand)
def _body_breaks_unconditionally(body: list[Statement]) -> bool:
"""
Return True if the last statement in the body is an unlabeled break and the body contains no
continue statements at any nesting depth. Such a loop body executes exactly once.
"""
if not body:
return False
last = body[-1]
if not isinstance(last, Ps1BreakStatement) or last.label is not None:
return False
for stmt in body[:-1]:
for node in stmt.walk():
if isinstance(node, Ps1ContinueStatement):
return False
return True
def _is_pure_constant(node) -> bool:
    """
    Decide whether an expression is a constant without side effects, so that a statement made
    up of nothing else can be dropped. Accepted are integer and real literals and the unscoped
    built-ins `$Null`, `$True`, `$False`, optionally wrapped in parentheses and unary `+`/`-`.
    String literals are deliberately rejected because they may represent intentional pipeline
    output.
    """
    expr = node
    while True:
        if isinstance(expr, (Ps1IntegerLiteral, Ps1RealLiteral)):
            return True
        if isinstance(expr, Ps1Variable) and expr.scope == Ps1ScopeModifier.NONE:
            return expr.name.lower() in ('null', 'true', 'false')
        # Peel one wrapper per iteration and look at what is inside.
        if isinstance(expr, Ps1ParenExpression):
            expr = expr.expression
        elif isinstance(expr, Ps1UnaryExpression) and expr.operator in ('+', '-'):
            expr = expr.operand
        else:
            return False
class Ps1DeadCodeElimination(Transformer):
    """
    Remove unreachable code guarded by constant boolean conditions and resolve switch statements
    on constant values.

    Loop and switch bodies are only spliced into the surrounding statement list when they contain
    no `break` or `continue` other than one trailing unlabeled `break`, which is dropped. Outside
    of its loop or switch, any other jump would bind to an enclosing construct and change what
    the script does.
    """

    def visit(self, node: Node):
        """
        Prune every statement list found below `node` in place and mark the transformer as
        changed whenever a list was rewritten.
        """
        # Iterate over a snapshot because the tree is modified while it is being processed.
        for parent in list(node.walk()):
            # NOTE(review): sub-expression bodies are skipped, presumably because the values of
            # their statements are the result of the sub-expression; confirm.
            if isinstance(parent, Ps1SubExpression):
                continue
            body = _get_body(parent)
            if body is None:
                continue
            new_body = self._prune_body(body, isinstance(parent, Ps1Script))
            if new_body is not body:
                # Replace the contents of the existing list object rather than rebinding it.
                body.clear()
                body.extend(new_body)
                for stmt in new_body:
                    stmt.parent = parent
                self.mark_changed()

    def _prune_body(
        self, body: list[Statement], is_script_level: bool = False,
    ) -> list[Statement]:
        """
        Return a pruned version of `body`, or `body` itself (the same object) when nothing was
        changed. Statements that consist of a pure constant are dropped; at script level this
        only happens when the script contains at least one statement that is not a pure constant.
        """
        result: list[Statement] = []
        changed = False
        prune_constants = not is_script_level or any(
            not (isinstance(s, Ps1ExpressionStatement) and _is_pure_constant(s.expression))
            for s in body
        )
        for stmt in body:
            if (
                prune_constants
                and isinstance(stmt, Ps1ExpressionStatement)
                and _is_pure_constant(stmt.expression)
            ):
                changed = True
                continue
            replacement = self._try_prune(stmt)
            if replacement is not None:
                result.extend(replacement)
                changed = True
            else:
                result.append(stmt)
        return result if changed else body

    def _try_prune(self, stmt: Statement) -> list[Statement] | None:
        """
        Dispatch to the pruning routine for the type of `stmt`. The result is the list of
        statements that replaces `stmt`, or None when `stmt` is to be kept as it is.
        """
        if isinstance(stmt, Ps1WhileLoop):
            return self._prune_while(stmt)
        if isinstance(stmt, Ps1DoWhileLoop):
            return self._prune_do_while(stmt)
        if isinstance(stmt, Ps1DoUntilLoop):
            return self._prune_do_until(stmt)
        if isinstance(stmt, Ps1ForLoop):
            return self._prune_for(stmt)
        if isinstance(stmt, Ps1IfStatement):
            return self._prune_if(stmt)
        if isinstance(stmt, Ps1SwitchStatement):
            return self._prune_switch(stmt)
        return None

    @staticmethod
    def _has_jump(statements: list[Statement]) -> bool:
        """
        Return True if any `break` or `continue` occurs in the given statements at any nesting
        depth. This is deliberately conservative: jumps that belong to a nested loop are counted
        as well, because their target cannot be told apart here.
        """
        return any(
            isinstance(descendant, (Ps1BreakStatement, Ps1ContinueStatement))
            for stmt in statements
            for descendant in stmt.walk()
        )

    @staticmethod
    def _single_pass_body(statements: list[Statement]) -> list[Statement] | None:
        """
        If the statements end in an unconditional unlabeled `break` and contain no other jump,
        return them without that `break`; they can then be executed in place of the loop or
        switch clause they came from. Return None otherwise.
        """
        if not _body_breaks_unconditionally(statements):
            return None
        head = list(statements[:-1])
        if Ps1DeadCodeElimination._has_jump(head):
            return None
        return head

    @staticmethod
    def _prune_while(node: Ps1WhileLoop) -> list[Statement] | None:
        """
        Remove a while loop whose condition is constantly false, and unroll one whose body always
        breaks at the end of the first iteration (guarded by the condition if it is not constant).
        """
        truth = _is_truthy(node.condition)
        if truth is False:
            return []
        if node.body is None:
            return None
        body = Ps1DeadCodeElimination._single_pass_body(node.body.body)
        if body is None:
            return None
        if truth is True or node.condition is None:
            return body
        return [Ps1IfStatement(clauses=[(node.condition, Block(body=body))])]

    @staticmethod
    def _prune_do_while(node: Ps1DoWhileLoop) -> list[Statement] | None:
        """
        Replace a do-while loop by its body when the body runs exactly once: either because it
        always breaks at its end, or because the condition is constantly false.
        """
        if node.body is None:
            return None
        statements = node.body.body
        body = Ps1DeadCodeElimination._single_pass_body(statements)
        if body is not None:
            return body
        if _is_truthy(node.condition) is False:
            if not Ps1DeadCodeElimination._has_jump(statements):
                return list(statements)
        return None

    @staticmethod
    def _prune_do_until(node: Ps1DoUntilLoop) -> list[Statement] | None:
        """
        Replace a do-until loop by its body when the body runs exactly once: either because it
        always breaks at its end, or because the condition is constantly true.
        """
        if node.body is None:
            return None
        statements = node.body.body
        body = Ps1DeadCodeElimination._single_pass_body(statements)
        if body is not None:
            return body
        if _is_truthy(node.condition) is True:
            if not Ps1DeadCodeElimination._has_jump(statements):
                return list(statements)
        return None

    @staticmethod
    def _prune_for(node: Ps1ForLoop) -> list[Statement] | None:
        """
        Reduce a for loop that never runs to its initializer, and unroll one whose body always
        breaks at the end of the first iteration. The condition is evaluated at loop entry with
        the initial value of the loop variable where possible.
        """
        truth = _evaluate_for_condition(node)
        if truth is None:
            truth = _is_truthy(node.condition)
        result: list[Statement] = []
        if truth is False:
            # The initializer still executes once even though the body never does.
            if node.initializer is not None:
                result.append(Ps1ExpressionStatement(expression=node.initializer))
            return result
        if node.body is None:
            return None
        body = Ps1DeadCodeElimination._single_pass_body(node.body.body)
        if body is None:
            return None
        if node.initializer is not None:
            result.append(Ps1ExpressionStatement(expression=node.initializer))
        if truth is True or node.condition is None:
            result.extend(body)
        else:
            result.append(Ps1IfStatement(clauses=[(node.condition, Block(body=body))]))
        return result

    @staticmethod
    def _prune_if(node: Ps1IfStatement) -> list[Statement] | None:
        """
        Resolve the leading clauses of an if statement whose conditions are constant.

        A constantly true clause replaces the whole statement by its block. Constantly false
        clauses are removed; if none remain, the else block (or nothing) replaces the statement.
        When clauses were removed but a non-constant one remains, the statement is shortened in
        place and returned as its own replacement so that the caller registers the change.
        Returns None when the first clause is already non-constant.
        """
        for index, (condition, block) in enumerate(node.clauses):
            truth = _is_truthy(condition)
            if truth is True:
                return list(block.body)
            if truth is False:
                continue
            # First clause that cannot be decided: everything from here on is kept.
            if index == 0:
                return None
            del node.clauses[:index]
            return [node]
        if node.else_block is not None:
            return list(node.else_block.body)
        return []

    @staticmethod
    def _switch_clause_matches(condition, target_int, target_str) -> bool | None:
        """
        Compare one switch clause condition with the switch value, of which exactly one of
        `target_int` and `target_str` (already lowercased) is not None. Returns None when the
        condition is neither an integer nor a string literal and can therefore not be decided.
        """
        if isinstance(condition, Ps1IntegerLiteral):
            if target_int is not None:
                return condition.value == target_int
            try:
                return int(target_str) == condition.value
            except ValueError:
                return False
        if isinstance(condition, Ps1StringLiteral):
            if target_str is not None:
                # NOTE(review): strings are compared case-insensitively; a case-sensitive
                # switch is not distinguished here. Confirm whether the model exposes that.
                return condition.value.lower() == target_str
            try:
                return int(condition.value) == target_int
            except ValueError:
                return False
        return None

    @staticmethod
    def _prune_switch(node: Ps1SwitchStatement) -> list[Statement] | None:
        """
        Replace a switch on an integer or string literal by the statements it executes.

        Every matching clause runs, in order, until one ends in `break`; the default clause runs
        only when no other clause matched. Returns None (keep the switch) when it uses the
        regex, wildcard, exact or file option, when the value is not a literal, when a clause
        condition that may have to be evaluated is not a literal, or when a selected clause
        contains a jump that cannot be moved out of the switch.
        """
        if node.regex or node.wildcard or node.exact or node.file:
            return None
        value = node.value
        if isinstance(value, Ps1IntegerLiteral):
            target_int = value.value
            target_str = None
        elif isinstance(value, Ps1StringLiteral):
            target_str = value.value.lower()
            target_int = None
        else:
            return None
        cls = Ps1DeadCodeElimination
        selected: list[Statement] = []
        matched = False
        default_body: list[Statement] | None = None
        for condition, block in node.clauses:
            if condition is None:
                default_body = block.body
                continue
            verdict = cls._switch_clause_matches(condition, target_int, target_str)
            if verdict is None:
                return None
            if not verdict:
                continue
            matched = True
            single = cls._single_pass_body(block.body)
            if single is not None:
                # The trailing break ends the switch: no later clause is evaluated.
                selected.extend(single)
                break
            if cls._has_jump(block.body):
                return None
            selected.extend(block.body)
        if matched or default_body is None:
            return selected
        single = cls._single_pass_body(default_body)
        if single is not None:
            return single
        if cls._has_jump(default_body):
            return None
        return list(default_body)
Classes
class Ps1DeadCodeElimination-
Remove unreachable code guarded by constant boolean conditions and resolve switch statements on constant values.
Expand source code Browse git
class Ps1DeadCodeElimination(Transformer):
    """
    Remove unreachable code guarded by constant boolean conditions and resolve switch statements
    on constant values.

    Loop and switch bodies are only spliced into the surrounding statement list when they contain
    no `break` or `continue` other than one trailing unlabeled `break`, which is dropped. Outside
    of its loop or switch, any other jump would bind to an enclosing construct and change what
    the script does.
    """

    def visit(self, node: Node):
        """
        Prune every statement list found below `node` in place and mark the transformer as
        changed whenever a list was rewritten.
        """
        # Iterate over a snapshot because the tree is modified while it is being processed.
        for parent in list(node.walk()):
            # NOTE(review): sub-expression bodies are skipped, presumably because the values of
            # their statements are the result of the sub-expression; confirm.
            if isinstance(parent, Ps1SubExpression):
                continue
            body = _get_body(parent)
            if body is None:
                continue
            new_body = self._prune_body(body, isinstance(parent, Ps1Script))
            if new_body is not body:
                # Replace the contents of the existing list object rather than rebinding it.
                body.clear()
                body.extend(new_body)
                for stmt in new_body:
                    stmt.parent = parent
                self.mark_changed()

    def _prune_body(
        self, body: list[Statement], is_script_level: bool = False,
    ) -> list[Statement]:
        """
        Return a pruned version of `body`, or `body` itself (the same object) when nothing was
        changed. Statements that consist of a pure constant are dropped; at script level this
        only happens when the script contains at least one statement that is not a pure constant.
        """
        result: list[Statement] = []
        changed = False
        prune_constants = not is_script_level or any(
            not (isinstance(s, Ps1ExpressionStatement) and _is_pure_constant(s.expression))
            for s in body
        )
        for stmt in body:
            if (
                prune_constants
                and isinstance(stmt, Ps1ExpressionStatement)
                and _is_pure_constant(stmt.expression)
            ):
                changed = True
                continue
            replacement = self._try_prune(stmt)
            if replacement is not None:
                result.extend(replacement)
                changed = True
            else:
                result.append(stmt)
        return result if changed else body

    def _try_prune(self, stmt: Statement) -> list[Statement] | None:
        """
        Dispatch to the pruning routine for the type of `stmt`. The result is the list of
        statements that replaces `stmt`, or None when `stmt` is to be kept as it is.
        """
        if isinstance(stmt, Ps1WhileLoop):
            return self._prune_while(stmt)
        if isinstance(stmt, Ps1DoWhileLoop):
            return self._prune_do_while(stmt)
        if isinstance(stmt, Ps1DoUntilLoop):
            return self._prune_do_until(stmt)
        if isinstance(stmt, Ps1ForLoop):
            return self._prune_for(stmt)
        if isinstance(stmt, Ps1IfStatement):
            return self._prune_if(stmt)
        if isinstance(stmt, Ps1SwitchStatement):
            return self._prune_switch(stmt)
        return None

    @staticmethod
    def _has_jump(statements: list[Statement]) -> bool:
        """
        Return True if any `break` or `continue` occurs in the given statements at any nesting
        depth. This is deliberately conservative: jumps that belong to a nested loop are counted
        as well, because their target cannot be told apart here.
        """
        return any(
            isinstance(descendant, (Ps1BreakStatement, Ps1ContinueStatement))
            for stmt in statements
            for descendant in stmt.walk()
        )

    @staticmethod
    def _single_pass_body(statements: list[Statement]) -> list[Statement] | None:
        """
        If the statements end in an unconditional unlabeled `break` and contain no other jump,
        return them without that `break`; they can then be executed in place of the loop or
        switch clause they came from. Return None otherwise.
        """
        if not _body_breaks_unconditionally(statements):
            return None
        head = list(statements[:-1])
        if Ps1DeadCodeElimination._has_jump(head):
            return None
        return head

    @staticmethod
    def _prune_while(node: Ps1WhileLoop) -> list[Statement] | None:
        """
        Remove a while loop whose condition is constantly false, and unroll one whose body always
        breaks at the end of the first iteration (guarded by the condition if it is not constant).
        """
        truth = _is_truthy(node.condition)
        if truth is False:
            return []
        if node.body is None:
            return None
        body = Ps1DeadCodeElimination._single_pass_body(node.body.body)
        if body is None:
            return None
        if truth is True or node.condition is None:
            return body
        return [Ps1IfStatement(clauses=[(node.condition, Block(body=body))])]

    @staticmethod
    def _prune_do_while(node: Ps1DoWhileLoop) -> list[Statement] | None:
        """
        Replace a do-while loop by its body when the body runs exactly once: either because it
        always breaks at its end, or because the condition is constantly false.
        """
        if node.body is None:
            return None
        statements = node.body.body
        body = Ps1DeadCodeElimination._single_pass_body(statements)
        if body is not None:
            return body
        if _is_truthy(node.condition) is False:
            if not Ps1DeadCodeElimination._has_jump(statements):
                return list(statements)
        return None

    @staticmethod
    def _prune_do_until(node: Ps1DoUntilLoop) -> list[Statement] | None:
        """
        Replace a do-until loop by its body when the body runs exactly once: either because it
        always breaks at its end, or because the condition is constantly true.
        """
        if node.body is None:
            return None
        statements = node.body.body
        body = Ps1DeadCodeElimination._single_pass_body(statements)
        if body is not None:
            return body
        if _is_truthy(node.condition) is True:
            if not Ps1DeadCodeElimination._has_jump(statements):
                return list(statements)
        return None

    @staticmethod
    def _prune_for(node: Ps1ForLoop) -> list[Statement] | None:
        """
        Reduce a for loop that never runs to its initializer, and unroll one whose body always
        breaks at the end of the first iteration. The condition is evaluated at loop entry with
        the initial value of the loop variable where possible.
        """
        truth = _evaluate_for_condition(node)
        if truth is None:
            truth = _is_truthy(node.condition)
        result: list[Statement] = []
        if truth is False:
            # The initializer still executes once even though the body never does.
            if node.initializer is not None:
                result.append(Ps1ExpressionStatement(expression=node.initializer))
            return result
        if node.body is None:
            return None
        body = Ps1DeadCodeElimination._single_pass_body(node.body.body)
        if body is None:
            return None
        if node.initializer is not None:
            result.append(Ps1ExpressionStatement(expression=node.initializer))
        if truth is True or node.condition is None:
            result.extend(body)
        else:
            result.append(Ps1IfStatement(clauses=[(node.condition, Block(body=body))]))
        return result

    @staticmethod
    def _prune_if(node: Ps1IfStatement) -> list[Statement] | None:
        """
        Resolve the leading clauses of an if statement whose conditions are constant.

        A constantly true clause replaces the whole statement by its block. Constantly false
        clauses are removed; if none remain, the else block (or nothing) replaces the statement.
        When clauses were removed but a non-constant one remains, the statement is shortened in
        place and returned as its own replacement so that the caller registers the change.
        Returns None when the first clause is already non-constant.
        """
        for index, (condition, block) in enumerate(node.clauses):
            truth = _is_truthy(condition)
            if truth is True:
                return list(block.body)
            if truth is False:
                continue
            # First clause that cannot be decided: everything from here on is kept.
            if index == 0:
                return None
            del node.clauses[:index]
            return [node]
        if node.else_block is not None:
            return list(node.else_block.body)
        return []

    @staticmethod
    def _switch_clause_matches(condition, target_int, target_str) -> bool | None:
        """
        Compare one switch clause condition with the switch value, of which exactly one of
        `target_int` and `target_str` (already lowercased) is not None. Returns None when the
        condition is neither an integer nor a string literal and can therefore not be decided.
        """
        if isinstance(condition, Ps1IntegerLiteral):
            if target_int is not None:
                return condition.value == target_int
            try:
                return int(target_str) == condition.value
            except ValueError:
                return False
        if isinstance(condition, Ps1StringLiteral):
            if target_str is not None:
                # NOTE(review): strings are compared case-insensitively; a case-sensitive
                # switch is not distinguished here. Confirm whether the model exposes that.
                return condition.value.lower() == target_str
            try:
                return int(condition.value) == target_int
            except ValueError:
                return False
        return None

    @staticmethod
    def _prune_switch(node: Ps1SwitchStatement) -> list[Statement] | None:
        """
        Replace a switch on an integer or string literal by the statements it executes.

        Every matching clause runs, in order, until one ends in `break`; the default clause runs
        only when no other clause matched. Returns None (keep the switch) when it uses the
        regex, wildcard, exact or file option, when the value is not a literal, when a clause
        condition that may have to be evaluated is not a literal, or when a selected clause
        contains a jump that cannot be moved out of the switch.
        """
        if node.regex or node.wildcard or node.exact or node.file:
            return None
        value = node.value
        if isinstance(value, Ps1IntegerLiteral):
            target_int = value.value
            target_str = None
        elif isinstance(value, Ps1StringLiteral):
            target_str = value.value.lower()
            target_int = None
        else:
            return None
        cls = Ps1DeadCodeElimination
        selected: list[Statement] = []
        matched = False
        default_body: list[Statement] | None = None
        for condition, block in node.clauses:
            if condition is None:
                default_body = block.body
                continue
            verdict = cls._switch_clause_matches(condition, target_int, target_str)
            if verdict is None:
                return None
            if not verdict:
                continue
            matched = True
            single = cls._single_pass_body(block.body)
            if single is not None:
                # The trailing break ends the switch: no later clause is evaluated.
                selected.extend(single)
                break
            if cls._has_jump(block.body):
                return None
            selected.extend(block.body)
        if matched or default_body is None:
            return selected
        single = cls._single_pass_body(default_body)
        if single is not None:
            return single
        if cls._has_jump(default_body):
            return None
        return list(default_body)
Ancestors
Methods
def visit(self, node)-
Expand source code Browse git
def visit(self, node: Node):
    """
    Walk the tree below `node`, prune each statement list that is found, and mark the
    transformer as changed for every list that was rewritten.
    """
    # Work on a snapshot of the walk because statement lists are modified along the way.
    for owner in list(node.walk()):
        if isinstance(owner, Ps1SubExpression):
            continue
        statements = _get_body(owner)
        if statements is None:
            continue
        pruned = self._prune_body(statements, isinstance(owner, Ps1Script))
        if pruned is statements:
            # The very same list coming back means that nothing was pruned.
            continue
        statements.clear()
        statements.extend(pruned)
        for stmt in pruned:
            stmt.parent = owner
        self.mark_changed()