Module refinery.lib.scripts.ps1.deobfuscation.typecast

PowerShell type cast simplification transforms.

Expand source code Browse git
"""
PowerShell type cast simplification transforms.
"""
from __future__ import annotations

import string

from refinery.lib.scripts import Transformer
from refinery.lib.scripts.ps1.deobfuscation._helpers import (
    _collect_int_arguments,
    _collect_string_arguments,
    _make_string_literal,
    _string_value,
    _unwrap_paren_to_array,
)
from refinery.lib.scripts.ps1.model import (
    Ps1BinaryExpression,
    Ps1CastExpression,
    Ps1IntegerLiteral,
    Ps1ParenExpression,
    Ps1ScopeModifier,
    Ps1TypeExpression,
    Ps1UnaryExpression,
    Ps1Variable,
)

_INTEGER_TYPE_NAMES = frozenset({
    'byte',
    'int',
    'int16',
    'int32',
    'int64',
    'long',
    'sbyte',
    'short',
    'uint16',
    'uint32',
    'uint64',
    'ushort',
})


def _unwrap_integer(node) -> int | None:
    """
    Resolve a node to a constant integer. Accepted shapes are an integer literal, the
    negation of an integer literal, and the unscoped variable named `null`, which counts
    as zero. Surrounding parentheses are ignored around the node and around the operand
    of a negation. Every other node yields None.
    """
    def strip_parens(expr):
        while isinstance(expr, Ps1ParenExpression):
            expr = expr.expression
        return expr

    node = strip_parens(node)
    if isinstance(node, Ps1IntegerLiteral):
        return node.value
    if isinstance(node, Ps1Variable):
        if node.scope == Ps1ScopeModifier.NONE and node.name.lower() == 'null':
            return 0
    if not isinstance(node, Ps1UnaryExpression) or node.operator != '-':
        return None
    # Only a single level of negation directly over a literal is recognized.
    negated = strip_parens(node.operand)
    return -negated.value if isinstance(negated, Ps1IntegerLiteral) else None


class Ps1TypeCasts(Transformer):
    """
    Simplify PowerShell casts such as `[char]65` and conversions written with `-as` when
    the operand is a constant. Every visit method returns the replacement node, or None
    when the visited node should be kept.
    """

    @staticmethod
    def _integer_range(type_name: str) -> tuple[int, int] | None:
        """
        Return the inclusive `(lowest, highest)` value range of the .NET integer type that
        a lower-case PowerShell type name refers to, or None for any other name.
        """
        if type_name in ('sbyte', 'byte'):
            bits = 8
        elif type_name in ('short', 'ushort', 'int16', 'uint16'):
            bits = 16
        elif type_name in ('int', 'uint', 'int32', 'uint32'):
            bits = 32
        elif type_name in ('long', 'ulong', 'int64', 'uint64'):
            bits = 64
        else:
            return None
        # `byte` is the only unsigned name that does not start with a "u".
        if type_name == 'byte' or type_name.startswith('u'):
            return 0, (1 << bits) - 1
        half = 1 << (bits - 1)
        return -half, half - 1

    def visit_Ps1BinaryExpression(self, node: Ps1BinaryExpression):
        """
        Handle `<expression> -as [type]` by building the equivalent cast expression and
        returning whatever the cast simplification produces for it. Binary expressions
        with another operator, or whose right side is not a type expression, are kept.
        """
        self.generic_visit(node)
        if node.operator.lower() != '-as':
            return None
        if not isinstance(node.right, Ps1TypeExpression):
            return None
        cast = Ps1CastExpression(
            offset=node.offset,
            type_name=node.right.name,
            operand=node.left,
        )
        return self.visit_Ps1CastExpression(cast)

    def visit_Ps1CastExpression(self, node: Ps1CastExpression):
        """
        Replace a cast of a constant operand by its result:

        - `[string]` and `[char[]]` of a string constant become the operand itself;
        - `[string]` of several string elements becomes one literal, joined with spaces;
        - an integer cast of an integer constant becomes an integer literal, provided the
          value fits into the target type;
        - `[char]` of an integer constant becomes a one-character string literal (an empty
          literal for the value zero), provided the value fits into a UTF-16 code unit;
        - `[char[]]` of integer elements becomes a string literal when all of them decode
          to printable ASCII or whitespace;
        - `[type]` of a string constant becomes a type expression.

        Returns None when no rule applies, which keeps the cast.
        """
        self.generic_visit(node)
        tn = node.type_name.lower().replace(' ', '')
        operand = node.operand
        if tn in ('string', 'char[]'):
            if operand and _string_value(operand) is not None:
                return operand
        if tn == 'string':
            if operand is not None:
                parts = _collect_string_arguments(_unwrap_paren_to_array(operand))
                if parts is not None and len(parts) > 1:
                    return _make_string_literal(' '.join(parts))
        elif tn in _INTEGER_TYPE_NAMES:
            value = _unwrap_integer(operand)
            if value is not None:
                bounds = self._integer_range(tn)
                # A value outside the target type's range does not convert at runtime,
                # so such a cast is not a constant and must stay in the script.
                if bounds is None or bounds[0] <= value <= bounds[1]:
                    return Ps1IntegerLiteral(value=value, raw=str(value))
        elif tn == 'char':
            value = _unwrap_integer(operand)
            if value is not None:
                if value == 0:
                    return _make_string_literal('')
                # System.Char holds one UTF-16 code unit; anything else is not a valid
                # character cast and is left alone.
                if 0 < value <= 0xFFFF:
                    return _make_string_literal(chr(value))
        elif tn == 'char[]':
            if operand is not None:
                int_values = _collect_int_arguments(_unwrap_paren_to_array(operand))
                if int_values is not None:
                    try:
                        text = bytes(int_values).decode('ascii')
                    except (ValueError, UnicodeDecodeError, OverflowError):
                        return None
                    if not all(c in string.printable or c.isspace() for c in text):
                        return None
                    return _make_string_literal(text)
        elif tn == 'type':
            sv = _string_value(operand) if operand else None
            if sv is not None:
                return Ps1TypeExpression(offset=node.offset, name=sv)
        return None

Classes

class Ps1TypeCasts

In-place tree rewriter. Each visit method may return a replacement node or None to keep the original. Tracks whether any transformation was applied via the changed flag.

Expand source code Browse git
class Ps1TypeCasts(Transformer):
    """
    Simplify PowerShell casts such as `[char]65` and conversions written with `-as` when
    the operand is a constant. Every visit method returns the replacement node, or None
    when the visited node should be kept.
    """

    @staticmethod
    def _integer_range(type_name: str) -> tuple[int, int] | None:
        """
        Return the inclusive `(lowest, highest)` value range of the .NET integer type that
        a lower-case PowerShell type name refers to, or None for any other name.
        """
        if type_name in ('sbyte', 'byte'):
            bits = 8
        elif type_name in ('short', 'ushort', 'int16', 'uint16'):
            bits = 16
        elif type_name in ('int', 'uint', 'int32', 'uint32'):
            bits = 32
        elif type_name in ('long', 'ulong', 'int64', 'uint64'):
            bits = 64
        else:
            return None
        # `byte` is the only unsigned name that does not start with a "u".
        if type_name == 'byte' or type_name.startswith('u'):
            return 0, (1 << bits) - 1
        half = 1 << (bits - 1)
        return -half, half - 1

    def visit_Ps1BinaryExpression(self, node: Ps1BinaryExpression):
        """
        Handle `<expression> -as [type]` by building the equivalent cast expression and
        returning whatever the cast simplification produces for it. Binary expressions
        with another operator, or whose right side is not a type expression, are kept.
        """
        self.generic_visit(node)
        if node.operator.lower() != '-as':
            return None
        if not isinstance(node.right, Ps1TypeExpression):
            return None
        cast = Ps1CastExpression(
            offset=node.offset,
            type_name=node.right.name,
            operand=node.left,
        )
        return self.visit_Ps1CastExpression(cast)

    def visit_Ps1CastExpression(self, node: Ps1CastExpression):
        """
        Replace a cast of a constant operand by its result:

        - `[string]` and `[char[]]` of a string constant become the operand itself;
        - `[string]` of several string elements becomes one literal, joined with spaces;
        - an integer cast of an integer constant becomes an integer literal, provided the
          value fits into the target type;
        - `[char]` of an integer constant becomes a one-character string literal (an empty
          literal for the value zero), provided the value fits into a UTF-16 code unit;
        - `[char[]]` of integer elements becomes a string literal when all of them decode
          to printable ASCII or whitespace;
        - `[type]` of a string constant becomes a type expression.

        Returns None when no rule applies, which keeps the cast.
        """
        self.generic_visit(node)
        tn = node.type_name.lower().replace(' ', '')
        operand = node.operand
        if tn in ('string', 'char[]'):
            if operand and _string_value(operand) is not None:
                return operand
        if tn == 'string':
            if operand is not None:
                parts = _collect_string_arguments(_unwrap_paren_to_array(operand))
                if parts is not None and len(parts) > 1:
                    return _make_string_literal(' '.join(parts))
        elif tn in _INTEGER_TYPE_NAMES:
            value = _unwrap_integer(operand)
            if value is not None:
                bounds = self._integer_range(tn)
                # A value outside the target type's range does not convert at runtime,
                # so such a cast is not a constant and must stay in the script.
                if bounds is None or bounds[0] <= value <= bounds[1]:
                    return Ps1IntegerLiteral(value=value, raw=str(value))
        elif tn == 'char':
            value = _unwrap_integer(operand)
            if value is not None:
                if value == 0:
                    return _make_string_literal('')
                # System.Char holds one UTF-16 code unit; anything else is not a valid
                # character cast and is left alone.
                if 0 < value <= 0xFFFF:
                    return _make_string_literal(chr(value))
        elif tn == 'char[]':
            if operand is not None:
                int_values = _collect_int_arguments(_unwrap_paren_to_array(operand))
                if int_values is not None:
                    try:
                        text = bytes(int_values).decode('ascii')
                    except (ValueError, UnicodeDecodeError, OverflowError):
                        return None
                    if not all(c in string.printable or c.isspace() for c in text):
                        return None
                    return _make_string_literal(text)
        elif tn == 'type':
            sv = _string_value(operand) if operand else None
            if sv is not None:
                return Ps1TypeExpression(offset=node.offset, name=sv)
        return None

Ancestors

Methods

def visit_Ps1BinaryExpression(self, node)
Expand source code Browse git
def visit_Ps1BinaryExpression(self, node: Ps1BinaryExpression):
    """
    Treat `<expression> -as [type]` as the cast `[type]<expression>`: build the cast
    node and return the outcome of simplifying it. Any other binary expression, and a
    cast that cannot be simplified, yields None so that the node is kept.
    """
    self.generic_visit(node)
    if node.operator.lower() == '-as' and isinstance(node.right, Ps1TypeExpression):
        return self.visit_Ps1CastExpression(
            Ps1CastExpression(
                offset=node.offset,
                type_name=node.right.name,
                operand=node.left,
            )
        )
    return None
def visit_Ps1CastExpression(self, node)
Expand source code Browse git
def visit_Ps1CastExpression(self, node: Ps1CastExpression):
    """
    Replace a cast of a constant operand by its result:

    - `[string]` and `[char[]]` of a string constant become the operand itself;
    - `[string]` of several string elements becomes one literal, joined with spaces;
    - an integer cast of an integer constant becomes an integer literal, provided the
      value fits into the target type;
    - `[char]` of an integer constant becomes a one-character string literal (an empty
      literal for the value zero), provided the value fits into a UTF-16 code unit;
    - `[char[]]` of integer elements becomes a string literal when all of them decode
      to printable ASCII or whitespace;
    - `[type]` of a string constant becomes a type expression.

    Returns None when no rule applies, which keeps the cast.
    """
    self.generic_visit(node)
    tn = node.type_name.lower().replace(' ', '')
    operand = node.operand
    if tn in ('string', 'char[]'):
        if operand and _string_value(operand) is not None:
            return operand
    if tn == 'string':
        if operand is not None:
            parts = _collect_string_arguments(_unwrap_paren_to_array(operand))
            if parts is not None and len(parts) > 1:
                return _make_string_literal(' '.join(parts))
    elif tn in _INTEGER_TYPE_NAMES:
        value = _unwrap_integer(operand)
        if value is not None:
            bounds = self._integer_range(tn)
            # A value outside the target type's range does not convert at runtime,
            # so such a cast is not a constant and must stay in the script.
            if bounds is None or bounds[0] <= value <= bounds[1]:
                return Ps1IntegerLiteral(value=value, raw=str(value))
    elif tn == 'char':
        value = _unwrap_integer(operand)
        if value is not None:
            if value == 0:
                return _make_string_literal('')
            # System.Char holds one UTF-16 code unit; anything else is not a valid
            # character cast and is left alone.
            if 0 < value <= 0xFFFF:
                return _make_string_literal(chr(value))
    elif tn == 'char[]':
        if operand is not None:
            int_values = _collect_int_arguments(_unwrap_paren_to_array(operand))
            if int_values is not None:
                try:
                    text = bytes(int_values).decode('ascii')
                except (ValueError, UnicodeDecodeError, OverflowError):
                    return None
                if not all(c in string.printable or c.isspace() for c in text):
                    return None
                return _make_string_literal(text)
    elif tn == 'type':
        sv = _string_value(operand) if operand else None
        if sv is not None:
            return Ps1TypeExpression(offset=node.offset, name=sv)
    return None

@staticmethod
def _integer_range(type_name: str) -> tuple[int, int] | None:
    """
    Return the inclusive `(lowest, highest)` value range of the .NET integer type that
    a lower-case PowerShell type name refers to, or None for any other name.
    """
    if type_name in ('sbyte', 'byte'):
        bits = 8
    elif type_name in ('short', 'ushort', 'int16', 'uint16'):
        bits = 16
    elif type_name in ('int', 'uint', 'int32', 'uint32'):
        bits = 32
    elif type_name in ('long', 'ulong', 'int64', 'uint64'):
        bits = 64
    else:
        return None
    # `byte` is the only unsigned name that does not start with a "u".
    if type_name == 'byte' or type_name.startswith('u'):
        return 0, (1 << bits) - 1
    half = 1 << (bits - 1)
    return -half, half - 1