Module refinery.lib.scripts.ps1.deobfuscation.deadcode

Eliminate dead code from PowerShell scripts after constant folding.

Expand source code Browse git
"""
Eliminate dead code from PowerShell scripts after constant folding.
"""
from __future__ import annotations

from refinery.lib.scripts import Block, Expression, Node, Statement, Transformer
from refinery.lib.scripts.ps1.deobfuscation.data import COMPARISON_OPS
from refinery.lib.scripts.ps1.deobfuscation.helpers import (
    get_body,
    inside_value_producing_context,
    is_builtin_variable,
    is_truthy,
    unwrap_integer,
    unwrap_parens,
)
from refinery.lib.scripts.ps1.model import (
    Ps1AssignmentExpression,
    Ps1BinaryExpression,
    Ps1BreakStatement,
    Ps1ContinueStatement,
    Ps1DoLoop,
    Ps1ExpressionStatement,
    Ps1ForLoop,
    Ps1IfStatement,
    Ps1IntegerLiteral,
    Ps1ParenExpression,
    Ps1RealLiteral,
    Ps1ScopeModifier,
    Ps1Script,
    Ps1StringLiteral,
    Ps1SwitchStatement,
    Ps1UnaryExpression,
    Ps1Variable,
    Ps1WhileLoop,
)


def _evaluate_for_condition(node: Ps1ForLoop) -> bool | None:
    """
    Decide whether a for-loop is entered at all by plugging the loop variable's initial value
    into the loop condition.

    The pattern is only recognized when the initializer is a plain `=` assignment of an integer
    constant to a variable, and the condition is a binary expression whose operator appears in
    `COMPARISON_OPS` and whose operands are each either the loop variable or an integer constant.
    Returns the truth value of the comparison at loop entry, or `None` when the pattern does not
    apply.
    """
    initializer = node.initializer
    if not (isinstance(initializer, Ps1AssignmentExpression) and initializer.operator == '='):
        return None
    loop_var = initializer.target
    if not isinstance(loop_var, Ps1Variable):
        return None
    start = unwrap_integer(initializer.value)
    if start is None:
        return None
    condition = node.condition
    if not isinstance(condition, Ps1BinaryExpression):
        return None
    compare = COMPARISON_OPS.get(condition.operator.lower())
    if compare is None:
        return None
    name = loop_var.name.lower()
    scope = loop_var.scope
    # Resolve the left operand first, then the right one.
    operands = [
        _resolve_side(side, name, scope, start.value)
        for side in (condition.left, condition.right)
    ]
    if any(operand is None for operand in operands):
        return None
    lhs, rhs = operands
    return bool(compare(lhs, rhs))


def _resolve_side(
    node, var_name: str, var_scope: Ps1ScopeModifier, init_val: int,
) -> int | None:
    """
    Turn one operand of a for-loop condition into an integer.

    Parentheses around an expression are stripped first. A reference to the loop variable
    (same lower-cased name and same scope) yields `init_val`; anything that `unwrap_integer`
    accepts yields its value; every other operand yields `None`.
    """
    candidate = node
    if isinstance(candidate, Expression):
        candidate = unwrap_parens(candidate)
    if isinstance(candidate, Ps1Variable):
        same_name = candidate.name.lower() == var_name
        if same_name and candidate.scope == var_scope:
            return init_val
    literal = unwrap_integer(candidate)
    if literal is None:
        return None
    return literal.value


def _body_breaks_unconditionally(body: list[Statement]) -> bool:
    """
    Return `True` if the last statement in the body is an unlabeled break and the body contains no
    continue statements at any nesting depth. Such a loop body executes exactly once.
    """
    if not body:
        return False
    last = body[-1]
    if not isinstance(last, Ps1BreakStatement) or last.label is not None:
        return False
    for stmt in body[:-1]:
        for node in stmt.walk():
            if isinstance(node, Ps1ContinueStatement):
                return False
    return True


def _is_pure_constant(node) -> bool:
    """
    Tell whether an expression is a side-effect-free constant, so that a statement consisting
    only of this expression may be dropped.

    Accepted are integer and real literals and whatever `is_builtin_variable` recognizes,
    optionally wrapped in any number of parentheses and unary `+` / `-` operators. String
    literals are not accepted because they may represent intentional pipeline output.
    """
    current = node
    while True:
        if isinstance(current, (Ps1IntegerLiteral, Ps1RealLiteral)):
            return True
        if is_builtin_variable(current):
            return True
        if isinstance(current, Ps1ParenExpression):
            # Peel one layer of parentheses and look at the inner expression.
            current = current.expression
        elif isinstance(current, Ps1UnaryExpression) and current.operator in ('+', '-'):
            # A sign in front of a constant is still a constant.
            current = current.operand
        else:
            return False


class Ps1DeadCodeElimination(Transformer):
    """
    Remove unreachable code guarded by constant boolean conditions and resolve switch statements
    on constant values.
    """

    def visit(self, node: Node):
        """
        Prune every statement list found below `node`. Nodes inside a value-producing context
        are skipped. A body that was rewritten is updated in place, its statements are
        re-parented, and the transformer is marked as changed.
        """
        # Snapshot the traversal because bodies are rewritten while iterating.
        for parent in list(node.walk()):
            if inside_value_producing_context(parent):
                continue
            body = get_body(parent)
            if body is None:
                continue
            new_body = self._prune_body(body, isinstance(parent, Ps1Script))
            if new_body is body:
                continue
            body.clear()
            body.extend(new_body)
            for stmt in new_body:
                stmt.parent = parent
            self.mark_changed()

    def _prune_body(
        self, body: list[Statement], is_script_level: bool = False,
    ) -> list[Statement]:
        """
        Return a pruned copy of `body`, or `body` itself (same object) if nothing changed.

        Standalone pure-constant expression statements are dropped, except at script level when
        the script consists of nothing but such statements. Every other statement is offered to
        `_try_prune` and replaced by the returned statements if there are any.
        """
        result: list[Statement] = []
        changed = False
        prune_constants = not is_script_level or any(
            not (isinstance(s, Ps1ExpressionStatement) and _is_pure_constant(s.expression))
            for s in body
        )
        for stmt in body:
            if (
                prune_constants
                and isinstance(stmt, Ps1ExpressionStatement)
                and _is_pure_constant(stmt.expression)
            ):
                changed = True
                continue
            replacement = self._try_prune(stmt)
            if replacement is not None:
                result.extend(replacement)
                changed = True
            else:
                result.append(stmt)
        return result if changed else body

    def _try_prune(self, stmt: Statement) -> list[Statement] | None:
        """
        Dispatch to the pruning routine for the statement type. Returns the list of statements
        that replaces `stmt`, or `None` if `stmt` is to be kept as it is.
        """
        if isinstance(stmt, Ps1WhileLoop):
            return self._prune_while(stmt)
        if isinstance(stmt, Ps1DoLoop):
            return self._prune_do_loop(stmt)
        if isinstance(stmt, Ps1ForLoop):
            return self._prune_for(stmt)
        if isinstance(stmt, Ps1IfStatement):
            return self._prune_if(stmt)
        if isinstance(stmt, Ps1SwitchStatement):
            return self._prune_switch(stmt)
        return None

    @staticmethod
    def _contains_loop_jump(body: list[Statement]) -> bool:
        """
        Return `True` if any break or continue statement occurs in `body` at any nesting depth.
        """
        return any(
            isinstance(inner, (Ps1BreakStatement, Ps1ContinueStatement))
            for stmt in body
            for inner in stmt.walk()
        )

    @staticmethod
    def _prune_while(node: Ps1WhileLoop) -> list[Statement] | None:
        """
        Remove a while-loop whose condition is constantly false. A loop whose body breaks
        unconditionally runs at most once: it is replaced by its body without the trailing
        break, wrapped in an if-statement on the loop condition unless that condition is
        constantly true or absent.
        """
        truth = is_truthy(node.condition)
        if truth is False:
            return []
        if node.body is not None and _body_breaks_unconditionally(node.body.body):
            body = list(node.body.body[:-1])
            if truth is True or node.condition is None:
                return body
            return [Ps1IfStatement(clauses=[(node.condition, Block(body=body))])]
        return None

    @staticmethod
    def _prune_do_loop(node: Ps1DoLoop) -> list[Statement] | None:
        """
        Unwrap a do-loop that runs its body exactly once: either the exit condition is constant
        (`until` with a true condition, `while` with a false one), or the body breaks
        unconditionally.

        In the constant-condition case the body is only unwrapped when it contains no break or
        continue statement. Such a jump refers to the loop being removed and would otherwise end
        up targeting an enclosing construct.
        """
        if node.body is None:
            return None
        statements = node.body.body
        truth = is_truthy(node.condition)
        trivially_exits = truth is True if node.is_until else truth is False
        if trivially_exits and not Ps1DeadCodeElimination._contains_loop_jump(statements):
            return list(statements)
        if _body_breaks_unconditionally(statements):
            return list(statements[:-1])
        return None

    @staticmethod
    def _prune_for(node: Ps1ForLoop) -> list[Statement] | None:
        """
        Remove a for-loop that is never entered, or unwrap one whose body breaks
        unconditionally. The initializer is kept as an expression statement in both cases. The
        body is wrapped in an if-statement on the loop condition unless the condition is known
        to hold at loop entry or is absent.
        """
        truth = _evaluate_for_condition(node)
        if truth is None:
            truth = is_truthy(node.condition)
        if truth is False:
            result: list[Statement] = []
            if node.initializer is not None:
                result.append(Ps1ExpressionStatement(expression=node.initializer))
            return result
        if node.body is not None and _body_breaks_unconditionally(node.body.body):
            result = []
            if node.initializer is not None:
                result.append(Ps1ExpressionStatement(expression=node.initializer))
            body = list(node.body.body[:-1])
            if truth is True or node.condition is None:
                result.extend(body)
            else:
                result.append(Ps1IfStatement(clauses=[(node.condition, Block(body=body))]))
            return result
        return None

    @staticmethod
    def _prune_if(node: Ps1IfStatement) -> list[Statement] | None:
        """
        Resolve an if-statement as far as its leading conditions are constant.

        Clauses are inspected in order. A constantly true clause replaces the whole statement
        by its block, and constantly false clauses are skipped. If every clause is constantly
        false, the else-block (or nothing) replaces the statement. When a non-constant clause
        is reached after skipping at least one clause, the skipped clauses are removed from
        `node` and `[node]` is returned so that the modification is registered as a change.
        Returns `None` when the very first clause is already non-constant.
        """
        for index, (condition, block) in enumerate(node.clauses):
            truth = is_truthy(condition)
            if truth is True:
                return list(block.body)
            if truth is False:
                continue
            if index == 0:
                return None
            # Drop the dead leading clauses; everything from here on is kept untouched.
            del node.clauses[:index]
            return [node]
        if node.else_block is not None:
            return list(node.else_block.body)
        return []

    @staticmethod
    def _switch_clause_matches(condition, target_int, target_str) -> bool | None:
        """
        Compare a switch clause condition against the constant switch value. Exactly one of
        `target_int` and `target_str` (already lower-cased) is expected to be set. Returns
        `True` or `False` when the condition is an integer or string literal, and `None` when
        the outcome cannot be decided statically.
        """
        if isinstance(condition, Ps1IntegerLiteral):
            if target_int is not None:
                return condition.value == target_int
            try:
                return int(target_str) == condition.value
            except ValueError:
                return False
        if isinstance(condition, Ps1StringLiteral):
            if target_str is not None:
                return condition.value.lower() == target_str
            try:
                return int(condition.value) == target_int
            except ValueError:
                return False
        return None

    @staticmethod
    def _prune_switch(node: Ps1SwitchStatement) -> list[Statement] | None:
        """
        Resolve a switch statement on an integer or string literal to the block of the first
        matching clause, to the default block if no clause matches, or to nothing.

        Switches using the regex, wildcard, exact or file options are left alone. The statement
        is also left alone when a clause whose condition is not an integer or string literal is
        reached before a match: such a clause may match at runtime, so neither a later clause
        nor the default block can be selected safely.
        """
        if node.regex or node.wildcard or node.exact or node.file:
            return None
        value = node.value
        if isinstance(value, Ps1IntegerLiteral):
            target_int = value.value
            target_str = None
        elif isinstance(value, Ps1StringLiteral):
            target_str = value.value.lower()
            target_int = None
        else:
            return None
        default_body: list[Statement] | None = None
        for condition, block in node.clauses:
            if condition is None:
                default_body = list(block.body)
                continue
            matches = Ps1DeadCodeElimination._switch_clause_matches(
                condition, target_int, target_str)
            if matches is None:
                return None
            if matches:
                # NOTE(review): only the first matching clause is used; clauses after it are
                # not inspected.
                return list(block.body)
        if default_body is not None:
            return default_body
        return []

Classes

class Ps1DeadCodeElimination

Remove unreachable code guarded by constant boolean conditions and resolve switch statements on constant values.

Expand source code Browse git
class Ps1DeadCodeElimination(Transformer):
    """
    Remove unreachable code guarded by constant boolean conditions and resolve switch statements
    on constant values.
    """

    def visit(self, node: Node):
        """
        Prune every statement list found below `node`. Nodes inside a value-producing context
        are skipped. A body that was rewritten is updated in place, its statements are
        re-parented, and the transformer is marked as changed.
        """
        # Snapshot the traversal because bodies are rewritten while iterating.
        for parent in list(node.walk()):
            if inside_value_producing_context(parent):
                continue
            body = get_body(parent)
            if body is None:
                continue
            new_body = self._prune_body(body, isinstance(parent, Ps1Script))
            if new_body is body:
                continue
            body.clear()
            body.extend(new_body)
            for stmt in new_body:
                stmt.parent = parent
            self.mark_changed()

    def _prune_body(
        self, body: list[Statement], is_script_level: bool = False,
    ) -> list[Statement]:
        """
        Return a pruned copy of `body`, or `body` itself (same object) if nothing changed.

        Standalone pure-constant expression statements are dropped, except at script level when
        the script consists of nothing but such statements. Every other statement is offered to
        `_try_prune` and replaced by the returned statements if there are any.
        """
        result: list[Statement] = []
        changed = False
        prune_constants = not is_script_level or any(
            not (isinstance(s, Ps1ExpressionStatement) and _is_pure_constant(s.expression))
            for s in body
        )
        for stmt in body:
            if (
                prune_constants
                and isinstance(stmt, Ps1ExpressionStatement)
                and _is_pure_constant(stmt.expression)
            ):
                changed = True
                continue
            replacement = self._try_prune(stmt)
            if replacement is not None:
                result.extend(replacement)
                changed = True
            else:
                result.append(stmt)
        return result if changed else body

    def _try_prune(self, stmt: Statement) -> list[Statement] | None:
        """
        Dispatch to the pruning routine for the statement type. Returns the list of statements
        that replaces `stmt`, or `None` if `stmt` is to be kept as it is.
        """
        if isinstance(stmt, Ps1WhileLoop):
            return self._prune_while(stmt)
        if isinstance(stmt, Ps1DoLoop):
            return self._prune_do_loop(stmt)
        if isinstance(stmt, Ps1ForLoop):
            return self._prune_for(stmt)
        if isinstance(stmt, Ps1IfStatement):
            return self._prune_if(stmt)
        if isinstance(stmt, Ps1SwitchStatement):
            return self._prune_switch(stmt)
        return None

    @staticmethod
    def _contains_loop_jump(body: list[Statement]) -> bool:
        """
        Return `True` if any break or continue statement occurs in `body` at any nesting depth.
        """
        return any(
            isinstance(inner, (Ps1BreakStatement, Ps1ContinueStatement))
            for stmt in body
            for inner in stmt.walk()
        )

    @staticmethod
    def _prune_while(node: Ps1WhileLoop) -> list[Statement] | None:
        """
        Remove a while-loop whose condition is constantly false. A loop whose body breaks
        unconditionally runs at most once: it is replaced by its body without the trailing
        break, wrapped in an if-statement on the loop condition unless that condition is
        constantly true or absent.
        """
        truth = is_truthy(node.condition)
        if truth is False:
            return []
        if node.body is not None and _body_breaks_unconditionally(node.body.body):
            body = list(node.body.body[:-1])
            if truth is True or node.condition is None:
                return body
            return [Ps1IfStatement(clauses=[(node.condition, Block(body=body))])]
        return None

    @staticmethod
    def _prune_do_loop(node: Ps1DoLoop) -> list[Statement] | None:
        """
        Unwrap a do-loop that runs its body exactly once: either the exit condition is constant
        (`until` with a true condition, `while` with a false one), or the body breaks
        unconditionally.

        In the constant-condition case the body is only unwrapped when it contains no break or
        continue statement. Such a jump refers to the loop being removed and would otherwise end
        up targeting an enclosing construct.
        """
        if node.body is None:
            return None
        statements = node.body.body
        truth = is_truthy(node.condition)
        trivially_exits = truth is True if node.is_until else truth is False
        if trivially_exits and not Ps1DeadCodeElimination._contains_loop_jump(statements):
            return list(statements)
        if _body_breaks_unconditionally(statements):
            return list(statements[:-1])
        return None

    @staticmethod
    def _prune_for(node: Ps1ForLoop) -> list[Statement] | None:
        """
        Remove a for-loop that is never entered, or unwrap one whose body breaks
        unconditionally. The initializer is kept as an expression statement in both cases. The
        body is wrapped in an if-statement on the loop condition unless the condition is known
        to hold at loop entry or is absent.
        """
        truth = _evaluate_for_condition(node)
        if truth is None:
            truth = is_truthy(node.condition)
        if truth is False:
            result: list[Statement] = []
            if node.initializer is not None:
                result.append(Ps1ExpressionStatement(expression=node.initializer))
            return result
        if node.body is not None and _body_breaks_unconditionally(node.body.body):
            result = []
            if node.initializer is not None:
                result.append(Ps1ExpressionStatement(expression=node.initializer))
            body = list(node.body.body[:-1])
            if truth is True or node.condition is None:
                result.extend(body)
            else:
                result.append(Ps1IfStatement(clauses=[(node.condition, Block(body=body))]))
            return result
        return None

    @staticmethod
    def _prune_if(node: Ps1IfStatement) -> list[Statement] | None:
        """
        Resolve an if-statement as far as its leading conditions are constant.

        Clauses are inspected in order. A constantly true clause replaces the whole statement
        by its block, and constantly false clauses are skipped. If every clause is constantly
        false, the else-block (or nothing) replaces the statement. When a non-constant clause
        is reached after skipping at least one clause, the skipped clauses are removed from
        `node` and `[node]` is returned so that the modification is registered as a change.
        Returns `None` when the very first clause is already non-constant.
        """
        for index, (condition, block) in enumerate(node.clauses):
            truth = is_truthy(condition)
            if truth is True:
                return list(block.body)
            if truth is False:
                continue
            if index == 0:
                return None
            # Drop the dead leading clauses; everything from here on is kept untouched.
            del node.clauses[:index]
            return [node]
        if node.else_block is not None:
            return list(node.else_block.body)
        return []

    @staticmethod
    def _switch_clause_matches(condition, target_int, target_str) -> bool | None:
        """
        Compare a switch clause condition against the constant switch value. Exactly one of
        `target_int` and `target_str` (already lower-cased) is expected to be set. Returns
        `True` or `False` when the condition is an integer or string literal, and `None` when
        the outcome cannot be decided statically.
        """
        if isinstance(condition, Ps1IntegerLiteral):
            if target_int is not None:
                return condition.value == target_int
            try:
                return int(target_str) == condition.value
            except ValueError:
                return False
        if isinstance(condition, Ps1StringLiteral):
            if target_str is not None:
                return condition.value.lower() == target_str
            try:
                return int(condition.value) == target_int
            except ValueError:
                return False
        return None

    @staticmethod
    def _prune_switch(node: Ps1SwitchStatement) -> list[Statement] | None:
        """
        Resolve a switch statement on an integer or string literal to the block of the first
        matching clause, to the default block if no clause matches, or to nothing.

        Switches using the regex, wildcard, exact or file options are left alone. The statement
        is also left alone when a clause whose condition is not an integer or string literal is
        reached before a match: such a clause may match at runtime, so neither a later clause
        nor the default block can be selected safely.
        """
        if node.regex or node.wildcard or node.exact or node.file:
            return None
        value = node.value
        if isinstance(value, Ps1IntegerLiteral):
            target_int = value.value
            target_str = None
        elif isinstance(value, Ps1StringLiteral):
            target_str = value.value.lower()
            target_int = None
        else:
            return None
        default_body: list[Statement] | None = None
        for condition, block in node.clauses:
            if condition is None:
                default_body = list(block.body)
                continue
            matches = Ps1DeadCodeElimination._switch_clause_matches(
                condition, target_int, target_str)
            if matches is None:
                return None
            if matches:
                # NOTE(review): only the first matching clause is used; clauses after it are
                # not inspected.
                return list(block.body)
        if default_body is not None:
            return default_body
        return []

Ancestors

Methods

def visit(self, node)
Expand source code Browse git
def visit(self, node: Node):
    """
    Prune every statement list found below `node`. Nodes inside a value-producing context are
    skipped. A body that was rewritten is updated in place, its statements are re-parented, and
    the transformer is marked as changed.
    """
    # Snapshot the traversal because bodies are rewritten while iterating.
    candidates = list(node.walk())
    for owner in candidates:
        if inside_value_producing_context(owner):
            continue
        statements = get_body(owner)
        if statements is None:
            continue
        pruned = self._prune_body(statements, isinstance(owner, Ps1Script))
        if pruned is statements:
            # Identity signals that nothing was pruned.
            continue
        statements.clear()
        statements.extend(pruned)
        for statement in pruned:
            statement.parent = owner
        self.mark_changed()