Module refinery.lib.scripts.ps1.deobfuscation.folding

PowerShell constant folding transforms.

Expand source code Browse git
"""
PowerShell constant folding transforms.
"""
from __future__ import annotations

import base64
import codecs
import re

from refinery.lib.scripts import Node, Transformer
from refinery.lib.scripts.ps1.deobfuscation._helpers import (
    _ENCODING_MAP,
    _KNOWN_ALIAS,
    SIMPLE_IDENTIFIER,
    _case_normalize_name,
    _collect_int_arguments,
    _collect_string_arguments,
    _make_string_literal,
    _string_value,
    _unwrap_paren_to_array,
)
from refinery.lib.scripts.ps1.model import (
    Expression,
    Ps1AccessKind,
    Ps1ArrayExpression,
    Ps1ArrayLiteral,
    Ps1BinaryExpression,
    Ps1CommandInvocation,
    Ps1ExpandableString,
    Ps1ExpressionStatement,
    Ps1FunctionDefinition,
    Ps1IndexExpression,
    Ps1IntegerLiteral,
    Ps1InvokeMember,
    Ps1MemberAccess,
    Ps1ParenExpression,
    Ps1ScopeModifier,
    Ps1StringLiteral,
    Ps1TypeExpression,
    Ps1UnaryExpression,
    Ps1Variable,
)

_SYSTEM_CONVERT_NAMES = frozenset({
    'system.convert',
})

_SYSTEM_TEXT_ENCODING_NAMES = frozenset({
    'system.text.encoding',
    'text.encoding',
})

_STRING_TYPE_NAMES = frozenset({
    'string',
    'system.string',
})


def _is_static_convert_call(node: Ps1InvokeMember) -> bool:
    if node.access != Ps1AccessKind.STATIC:
        return False
    if not isinstance(node.object, Ps1TypeExpression):
        return False
    return node.object.name.lower().replace(' ', '') in _SYSTEM_CONVERT_NAMES


def _is_static_encoding_chain(node: Ps1InvokeMember) -> tuple[str, bool] | None:
    member_name = node.member if isinstance(node.member, str) else None
    if member_name is None or member_name.lower() != 'getstring':
        return None
    obj = node.object
    if not isinstance(obj, Ps1MemberAccess):
        return None
    if obj.access != Ps1AccessKind.STATIC:
        return None
    if not isinstance(obj.object, Ps1TypeExpression):
        return None
    type_name = obj.object.name.lower().replace(' ', '')
    if type_name not in _SYSTEM_TEXT_ENCODING_NAMES:
        return None
    encoding_name = obj.member if isinstance(obj.member, str) else None
    if encoding_name is None:
        return None
    return encoding_name, True


def _unwrap_to_array_literal(node: Expression) -> Ps1ArrayLiteral | None:
    """
    Unwrap parentheses and array expressions to find an inner
    `Ps1ArrayLiteral`.
    """
    while isinstance(node, Ps1ParenExpression) and node.expression is not None:
        node = node.expression
    if isinstance(node, Ps1ArrayLiteral):
        return node
    if isinstance(node, Ps1ArrayExpression) and len(node.body) == 1:
        stmt = node.body[0]
        if isinstance(stmt, Ps1ExpressionStatement) and isinstance(stmt.expression, Ps1ArrayLiteral):
            return stmt.expression
    return None


def _escape_for_expandable(text: str) -> str:
    """
    Escape characters that are special inside double-quoted strings.
    """
    return text.replace('`', '``').replace('$', '`$')


def _variable_raw(var: Ps1Variable) -> str:
    """
    Produce the braced variable reference for use inside an expandable string.
    """
    prefix = '@' if var.splatted else '$'
    scope = var.scope.value
    if scope:
        return F'{prefix}{{{scope}:{var.name}}}'
    return F'{prefix}{{{var.name}}}'


def _variable_string_to_expandable(
    var: Ps1Variable,
    text: str,
    *,
    var_first: bool,
) -> Ps1ExpandableString:
    """
    Fold `$var + 'text'` or `'text' + $var` into a
    `refinery.lib.scripts.ps1.model.Ps1ExpandableString`.
    """
    escaped = _escape_for_expandable(text)
    var_raw = _variable_raw(var)
    text_part = Ps1StringLiteral(value=text, raw=F"'{text}'")
    if var_first:
        raw = F'"{var_raw}{escaped}"'
        parts = [var, text_part]
    else:
        raw = F'"{escaped}{var_raw}"'
        parts = [text_part, var]
    return Ps1ExpandableString(parts=parts, raw=raw)


class Ps1ConstantFolding(Transformer):

    def visit_Ps1CommandInvocation(self, node: Ps1CommandInvocation):
        self.generic_visit(node)
        if isinstance(node.name, Ps1StringLiteral):
            name_lower = node.name.value.lower()
            target = _KNOWN_ALIAS.get(name_lower)
            if target is not None and target != node.name.value:
                if not self._has_function_definition(node, name_lower):
                    node.name = Ps1StringLiteral(
                        offset=node.name.offset,
                        value=target,
                        raw=target,
                    )
                    node.name.parent = node
                    self.mark_changed()
        return None

    @staticmethod
    def _has_function_definition(node: Node, name_lower: str) -> bool:
        cursor = node.parent
        while cursor is not None:
            if isinstance(cursor, Ps1FunctionDefinition):
                if cursor.name and cursor.name.lower() == name_lower:
                    return True
            cursor = cursor.parent
        root = node
        while root.parent is not None:
            root = root.parent
        for n in root.walk():
            if isinstance(n, Ps1FunctionDefinition) and n.name:
                if n.name.lower() == name_lower:
                    return True
        return False

    def visit_Ps1UnaryExpression(self, node: Ps1UnaryExpression):
        self.generic_visit(node)
        if node.operator.lower() != '-join' or node.operand is None:
            return None
        # -Join on a scalar string is a no-op in PowerShell.
        scalar = _string_value(node.operand)
        if scalar is not None:
            return _make_string_literal(scalar)
        array = _unwrap_to_array_literal(node.operand)
        if array is None:
            return None
        args = _collect_string_arguments(array)
        if args is None:
            return None
        return _make_string_literal(''.join(args))

    def visit_Ps1IndexExpression(self, node: Ps1IndexExpression):
        self.generic_visit(node)
        obj_str = _string_value(node.object) if node.object else None
        if obj_str is None or node.index is None:
            return None
        if isinstance(node.index, Ps1IntegerLiteral):
            idx = node.index.value
            if 0 <= idx < len(obj_str):
                return _make_string_literal(obj_str[idx])
            return None
        array = _unwrap_to_array_literal(node.index)
        if array is None and isinstance(node.index, Ps1ArrayLiteral):
            array = node.index
        if array is not None:
            chars: list[Expression] = []
            for elem in array.elements:
                if not isinstance(elem, Ps1IntegerLiteral):
                    return None
                idx = elem.value
                if idx < 0 or idx >= len(obj_str):
                    return None
                chars.append(_make_string_literal(obj_str[idx]))
            return Ps1ArrayLiteral(elements=chars)
        return None

    def visit_Ps1InvokeMember(self, node: Ps1InvokeMember):
        self.generic_visit(node)
        if isinstance(node.member, Ps1StringLiteral):
            name = node.member.value
            if SIMPLE_IDENTIFIER.match(name):
                node.member = name
                self.mark_changed()
        member_name = node.member if isinstance(node.member, str) else None
        if member_name is not None:
            normalized = _case_normalize_name(member_name)
            if normalized != member_name:
                node.member = normalized
                self.mark_changed()
            member_name = node.member
        if member_name is not None and member_name.lower() == 'tostring':
            if len(node.arguments) == 0:
                obj_str = _string_value(node.object) if node.object else None
                if obj_str is not None:
                    return _make_string_literal(obj_str)
        if member_name is not None and member_name.lower() == 'replace':
            if len(node.arguments) == 2:
                obj_str = _string_value(node.object) if node.object else None
                needle_str = _string_value(node.arguments[0])
                insert_str = _string_value(node.arguments[1])
                if obj_str is not None and needle_str is not None and insert_str is not None:
                    result = obj_str.replace(needle_str, insert_str)
                    return _make_string_literal(result)
        if member_name is not None and member_name.lower() == 'split':
            if len(node.arguments) == 1:
                obj_str = _string_value(node.object) if node.object else None
                sep_str = _string_value(node.arguments[0])
                if obj_str is not None and sep_str is not None and sep_str:
                    pattern = '[' + re.escape(sep_str) + ']'
                    parts = re.split(pattern, obj_str)
                    elements: list[Expression] = [_make_string_literal(p) for p in parts]
                    return Ps1ArrayLiteral(elements=elements)
        if member_name is not None and member_name.lower() == 'invoke':
            if isinstance(node.object, Ps1MemberAccess):
                return Ps1InvokeMember(
                    offset=node.offset,
                    object=node.object.object,
                    member=node.object.member,
                    arguments=node.arguments,
                    access=node.object.access,
                )
        if _is_static_convert_call(node):
            if member_name is not None and member_name.lower() == 'frombase64string':
                if len(node.arguments) == 1:
                    b64_str = _string_value(node.arguments[0])
                    if b64_str is not None:
                        try:
                            decoded = base64.b64decode(b64_str)
                        except Exception:
                            return None
                        elements = [
                            Ps1IntegerLiteral(value=b, raw=F'0x{b:02X}')
                            for b in decoded
                        ]
                        array = Ps1ArrayLiteral(elements=elements)
                        return Ps1ArrayExpression(
                            body=[Ps1ExpressionStatement(expression=array)])
        enc_info = _is_static_encoding_chain(node)
        if enc_info is not None:
            encoding_name, _ = enc_info
            if len(node.arguments) == 1:
                arg = _unwrap_paren_to_array(node.arguments[0])
                if isinstance(arg, Ps1ArrayExpression) and len(arg.body) == 1:
                    stmt = arg.body[0]
                    if isinstance(stmt, Ps1ExpressionStatement) and stmt.expression:
                        arg = stmt.expression
                int_values = _collect_int_arguments(arg)
                if int_values is not None:
                    try:
                        raw_bytes = bytearray(int_values)
                    except (ValueError, OverflowError):
                        return None
                    encoding = _ENCODING_MAP.get(
                        encoding_name.lower(), encoding_name)
                    try:
                        codecs.lookup(encoding)
                    except LookupError:
                        encoding = 'utf-8'
                    try:
                        decoded = raw_bytes.decode(encoding)
                    except Exception:
                        return None
                    return _make_string_literal(decoded)
        if (
            node.access == Ps1AccessKind.STATIC
            and isinstance(node.object, Ps1TypeExpression)
            and node.object.name.lower().replace(' ', '') in _STRING_TYPE_NAMES
            and member_name is not None
            and member_name.lower() == 'join'
            and len(node.arguments) == 2
        ):
            separator = _string_value(node.arguments[0])
            if separator is not None:
                second = node.arguments[1]
                scalar = _string_value(second)
                if scalar is not None:
                    return _make_string_literal(scalar)
                array = _unwrap_to_array_literal(second)
                if array is not None:
                    args = _collect_string_arguments(array)
                    if args is not None:
                        return _make_string_literal(separator.join(args))
        return None

    _ARITHMETIC_OPS = {
        '+' : int.__add__,
        '-' : int.__sub__,
        '*' : int.__mul__,
        '/' : int.__floordiv__,
        '%' : int.__mod__,
        '-band': int.__and__,
        '-bor' : int.__or__,
        '-bxor': int.__xor__,
        '-shl' : int.__lshift__,
        '-shr' : int.__rshift__,
    }

    _COMPARISON_OPS = {
        '-eq': int.__eq__,
        '-ne': int.__ne__,
        '-lt': int.__lt__,
        '-le': int.__le__,
        '-gt': int.__gt__,
        '-ge': int.__ge__,
    }

    def visit_Ps1BinaryExpression(self, node: Ps1BinaryExpression):
        self.generic_visit(node)
        op = node.operator.lower()
        if op == '-f':
            return self._handle_format(node)
        if op == '+':
            return self._handle_concat(node) or self._handle_arithmetic(node, op)
        if op == '-join':
            return self._handle_binary_join(node)
        if op in ('-replace', '-creplace', '-ireplace'):
            return self._handle_binary_replace(node, op)
        if op in ('-split', '-csplit', '-isplit'):
            return self._handle_binary_split(node, op)
        return self._handle_comparison(node, op) or self._handle_arithmetic(node, op)

    @staticmethod
    def _unwrap_integer(node: Expression | None) -> Ps1IntegerLiteral | None:
        while isinstance(node, Ps1ParenExpression):
            node = node.expression
        if isinstance(node, Ps1IntegerLiteral):
            return node
        if (
            isinstance(node, Ps1Variable)
            and node.scope == Ps1ScopeModifier.NONE
            and node.name.lower() == 'null'
        ):
            return Ps1IntegerLiteral(value=0, raw='0')
        if isinstance(node, Ps1UnaryExpression) and node.operator == '-':
            inner = node.operand
            while isinstance(inner, Ps1ParenExpression):
                inner = inner.expression
            if isinstance(inner, Ps1IntegerLiteral):
                return Ps1IntegerLiteral(value=-inner.value, raw=str(-inner.value))
        return None

    def _handle_arithmetic(self, node: Ps1BinaryExpression, op: str) -> Expression | None:
        left = self._unwrap_integer(node.left)
        right = self._unwrap_integer(node.right)
        if left is None or right is None:
            return None
        fn = self._ARITHMETIC_OPS.get(op)
        if fn is None:
            return None
        try:
            result = fn(left.value, right.value)
        except (ZeroDivisionError, ValueError, OverflowError):
            return None
        return Ps1IntegerLiteral(value=result, raw=str(result))

    def _handle_comparison(self, node: Ps1BinaryExpression, op: str) -> Expression | None:
        left = self._unwrap_integer(node.left)
        right = self._unwrap_integer(node.right)
        if left is None or right is None:
            return None
        fn = self._COMPARISON_OPS.get(op)
        if fn is None:
            return None
        result = fn(left.value, right.value)
        return Ps1Variable(name='True' if result else 'False')

    def _handle_format(self, node: Ps1BinaryExpression) -> Expression | None:
        fmt_str = _string_value(node.left) if node.left else None
        if fmt_str is None or node.right is None:
            return None
        args = _collect_string_arguments(node.right)
        if args is None:
            return None
        try:
            def replacer(m: re.Match) -> str:
                full = m.group(0)
                if full == '{{':
                    return '{'
                if full == '}}':
                    return '}'
                idx = int(m.group(1))
                return args[idx]
            result = re.sub(r'\{\{|\}\}|\{(\d+)\}', replacer, fmt_str)
        except (IndexError, ValueError):
            return None
        return _make_string_literal(result)

    def _handle_concat(self, node: Ps1BinaryExpression) -> Expression | None:
        left_str = _string_value(node.left) if node.left else None
        right_str = _string_value(node.right) if node.right else None
        if left_str is not None and right_str is not None:
            return _make_string_literal(left_str + right_str)
        if right_str is not None and isinstance(node.left, Ps1BinaryExpression):
            if node.left.operator == '+':
                inner_right_str = _string_value(node.left.right) if node.left.right else None
                if inner_right_str is not None:
                    node.left.right = _make_string_literal(inner_right_str + right_str)
                    return node.left
        if right_str is not None and isinstance(node.left, Ps1ArrayLiteral):
            elements = list(node.left.elements)
            elements.append(_make_string_literal(right_str))
            return Ps1ArrayLiteral(elements=elements)
        is_inner_concat = (
            isinstance(node.parent, Ps1BinaryExpression)
            and node.parent.operator == '+'
            and node.parent.left is node
        )
        if not is_inner_concat:
            if isinstance(node.left, Ps1Variable) and right_str is not None:
                return _variable_string_to_expandable(node.left, right_str, var_first=True)
            if isinstance(node.right, Ps1Variable) and left_str is not None:
                return _variable_string_to_expandable(node.right, left_str, var_first=False)
        return None

    def _handle_binary_join(self, node: Ps1BinaryExpression) -> Expression | None:
        separator = _string_value(node.right) if node.right else None
        if separator is None or node.left is None:
            return None
        # Binary -Join on a scalar string is a no-op.
        scalar = _string_value(node.left)
        if scalar is not None:
            return _make_string_literal(scalar)
        array = _unwrap_to_array_literal(node.left)
        if array is None and isinstance(node.left, Ps1ArrayLiteral):
            array = node.left
        if array is None:
            return None
        args = _collect_string_arguments(array)
        if args is None:
            return None
        return _make_string_literal(separator.join(args))

    def _handle_binary_replace(
        self, node: Ps1BinaryExpression, op: str,
    ) -> Expression | None:
        haystack = _string_value(node.left) if node.left else None
        if haystack is None or node.right is None:
            return None
        if isinstance(node.right, Ps1ArrayLiteral) and len(node.right.elements) == 2:
            needle_str = _string_value(node.right.elements[0])
            insert_str = _string_value(node.right.elements[1])
        else:
            return None
        if needle_str is None or insert_str is None:
            return None
        flags = re.IGNORECASE if op != '-creplace' else 0
        try:
            result = re.sub(needle_str, lambda _: insert_str, haystack, flags=flags)
        except re.error:
            return None
        return _make_string_literal(result)

    def _handle_binary_split(
        self, node: Ps1BinaryExpression, op: str,
    ) -> Expression | None:
        if node.right is None or node.left is None:
            return None
        pattern_str = _string_value(node.right)
        if pattern_str is None:
            return None
        flags = re.IGNORECASE if op != '-csplit' else 0
        # Collect input strings: either a single string or an array of strings.
        left_str = _string_value(node.left)
        if left_str is not None:
            inputs = [left_str]
        else:
            array = _unwrap_to_array_literal(node.left)
            if array is None and isinstance(node.left, Ps1ArrayLiteral):
                array = node.left
            if array is None:
                return None
            inputs_opt = _collect_string_arguments(array)
            if inputs_opt is None:
                return None
            inputs = inputs_opt
        try:
            parts: list[str] = []
            for s in inputs:
                parts.extend(re.split(pattern_str, s, flags=flags))
        except re.error:
            return None
        elements: list[Expression] = [_make_string_literal(p) for p in parts]
        return Ps1ArrayLiteral(elements=elements)

Classes

class Ps1ConstantFolding

In-place tree rewriter. Each visit method may return a replacement node or None to keep the original. Tracks whether any transformation was applied via the changed flag.

Expand source code Browse git
class Ps1ConstantFolding(Transformer):

    def visit_Ps1CommandInvocation(self, node: Ps1CommandInvocation):
        self.generic_visit(node)
        if isinstance(node.name, Ps1StringLiteral):
            name_lower = node.name.value.lower()
            target = _KNOWN_ALIAS.get(name_lower)
            if target is not None and target != node.name.value:
                if not self._has_function_definition(node, name_lower):
                    node.name = Ps1StringLiteral(
                        offset=node.name.offset,
                        value=target,
                        raw=target,
                    )
                    node.name.parent = node
                    self.mark_changed()
        return None

    @staticmethod
    def _has_function_definition(node: Node, name_lower: str) -> bool:
        cursor = node.parent
        while cursor is not None:
            if isinstance(cursor, Ps1FunctionDefinition):
                if cursor.name and cursor.name.lower() == name_lower:
                    return True
            cursor = cursor.parent
        root = node
        while root.parent is not None:
            root = root.parent
        for n in root.walk():
            if isinstance(n, Ps1FunctionDefinition) and n.name:
                if n.name.lower() == name_lower:
                    return True
        return False

    def visit_Ps1UnaryExpression(self, node: Ps1UnaryExpression):
        self.generic_visit(node)
        if node.operator.lower() != '-join' or node.operand is None:
            return None
        # -Join on a scalar string is a no-op in PowerShell.
        scalar = _string_value(node.operand)
        if scalar is not None:
            return _make_string_literal(scalar)
        array = _unwrap_to_array_literal(node.operand)
        if array is None:
            return None
        args = _collect_string_arguments(array)
        if args is None:
            return None
        return _make_string_literal(''.join(args))

    def visit_Ps1IndexExpression(self, node: Ps1IndexExpression):
        self.generic_visit(node)
        obj_str = _string_value(node.object) if node.object else None
        if obj_str is None or node.index is None:
            return None
        if isinstance(node.index, Ps1IntegerLiteral):
            idx = node.index.value
            if 0 <= idx < len(obj_str):
                return _make_string_literal(obj_str[idx])
            return None
        array = _unwrap_to_array_literal(node.index)
        if array is None and isinstance(node.index, Ps1ArrayLiteral):
            array = node.index
        if array is not None:
            chars: list[Expression] = []
            for elem in array.elements:
                if not isinstance(elem, Ps1IntegerLiteral):
                    return None
                idx = elem.value
                if idx < 0 or idx >= len(obj_str):
                    return None
                chars.append(_make_string_literal(obj_str[idx]))
            return Ps1ArrayLiteral(elements=chars)
        return None

    def visit_Ps1InvokeMember(self, node: Ps1InvokeMember):
        self.generic_visit(node)
        if isinstance(node.member, Ps1StringLiteral):
            name = node.member.value
            if SIMPLE_IDENTIFIER.match(name):
                node.member = name
                self.mark_changed()
        member_name = node.member if isinstance(node.member, str) else None
        if member_name is not None:
            normalized = _case_normalize_name(member_name)
            if normalized != member_name:
                node.member = normalized
                self.mark_changed()
            member_name = node.member
        if member_name is not None and member_name.lower() == 'tostring':
            if len(node.arguments) == 0:
                obj_str = _string_value(node.object) if node.object else None
                if obj_str is not None:
                    return _make_string_literal(obj_str)
        if member_name is not None and member_name.lower() == 'replace':
            if len(node.arguments) == 2:
                obj_str = _string_value(node.object) if node.object else None
                needle_str = _string_value(node.arguments[0])
                insert_str = _string_value(node.arguments[1])
                if obj_str is not None and needle_str is not None and insert_str is not None:
                    result = obj_str.replace(needle_str, insert_str)
                    return _make_string_literal(result)
        if member_name is not None and member_name.lower() == 'split':
            if len(node.arguments) == 1:
                obj_str = _string_value(node.object) if node.object else None
                sep_str = _string_value(node.arguments[0])
                if obj_str is not None and sep_str is not None and sep_str:
                    pattern = '[' + re.escape(sep_str) + ']'
                    parts = re.split(pattern, obj_str)
                    elements: list[Expression] = [_make_string_literal(p) for p in parts]
                    return Ps1ArrayLiteral(elements=elements)
        if member_name is not None and member_name.lower() == 'invoke':
            if isinstance(node.object, Ps1MemberAccess):
                return Ps1InvokeMember(
                    offset=node.offset,
                    object=node.object.object,
                    member=node.object.member,
                    arguments=node.arguments,
                    access=node.object.access,
                )
        if _is_static_convert_call(node):
            if member_name is not None and member_name.lower() == 'frombase64string':
                if len(node.arguments) == 1:
                    b64_str = _string_value(node.arguments[0])
                    if b64_str is not None:
                        try:
                            decoded = base64.b64decode(b64_str)
                        except Exception:
                            return None
                        elements = [
                            Ps1IntegerLiteral(value=b, raw=F'0x{b:02X}')
                            for b in decoded
                        ]
                        array = Ps1ArrayLiteral(elements=elements)
                        return Ps1ArrayExpression(
                            body=[Ps1ExpressionStatement(expression=array)])
        enc_info = _is_static_encoding_chain(node)
        if enc_info is not None:
            encoding_name, _ = enc_info
            if len(node.arguments) == 1:
                arg = _unwrap_paren_to_array(node.arguments[0])
                if isinstance(arg, Ps1ArrayExpression) and len(arg.body) == 1:
                    stmt = arg.body[0]
                    if isinstance(stmt, Ps1ExpressionStatement) and stmt.expression:
                        arg = stmt.expression
                int_values = _collect_int_arguments(arg)
                if int_values is not None:
                    try:
                        raw_bytes = bytearray(int_values)
                    except (ValueError, OverflowError):
                        return None
                    encoding = _ENCODING_MAP.get(
                        encoding_name.lower(), encoding_name)
                    try:
                        codecs.lookup(encoding)
                    except LookupError:
                        encoding = 'utf-8'
                    try:
                        decoded = raw_bytes.decode(encoding)
                    except Exception:
                        return None
                    return _make_string_literal(decoded)
        if (
            node.access == Ps1AccessKind.STATIC
            and isinstance(node.object, Ps1TypeExpression)
            and node.object.name.lower().replace(' ', '') in _STRING_TYPE_NAMES
            and member_name is not None
            and member_name.lower() == 'join'
            and len(node.arguments) == 2
        ):
            separator = _string_value(node.arguments[0])
            if separator is not None:
                second = node.arguments[1]
                scalar = _string_value(second)
                if scalar is not None:
                    return _make_string_literal(scalar)
                array = _unwrap_to_array_literal(second)
                if array is not None:
                    args = _collect_string_arguments(array)
                    if args is not None:
                        return _make_string_literal(separator.join(args))
        return None

    _ARITHMETIC_OPS = {
        '+' : int.__add__,
        '-' : int.__sub__,
        '*' : int.__mul__,
        '/' : int.__floordiv__,
        '%' : int.__mod__,
        '-band': int.__and__,
        '-bor' : int.__or__,
        '-bxor': int.__xor__,
        '-shl' : int.__lshift__,
        '-shr' : int.__rshift__,
    }

    _COMPARISON_OPS = {
        '-eq': int.__eq__,
        '-ne': int.__ne__,
        '-lt': int.__lt__,
        '-le': int.__le__,
        '-gt': int.__gt__,
        '-ge': int.__ge__,
    }

    def visit_Ps1BinaryExpression(self, node: Ps1BinaryExpression):
        self.generic_visit(node)
        op = node.operator.lower()
        if op == '-f':
            return self._handle_format(node)
        if op == '+':
            return self._handle_concat(node) or self._handle_arithmetic(node, op)
        if op == '-join':
            return self._handle_binary_join(node)
        if op in ('-replace', '-creplace', '-ireplace'):
            return self._handle_binary_replace(node, op)
        if op in ('-split', '-csplit', '-isplit'):
            return self._handle_binary_split(node, op)
        return self._handle_comparison(node, op) or self._handle_arithmetic(node, op)

    @staticmethod
    def _unwrap_integer(node: Expression | None) -> Ps1IntegerLiteral | None:
        while isinstance(node, Ps1ParenExpression):
            node = node.expression
        if isinstance(node, Ps1IntegerLiteral):
            return node
        if (
            isinstance(node, Ps1Variable)
            and node.scope == Ps1ScopeModifier.NONE
            and node.name.lower() == 'null'
        ):
            return Ps1IntegerLiteral(value=0, raw='0')
        if isinstance(node, Ps1UnaryExpression) and node.operator == '-':
            inner = node.operand
            while isinstance(inner, Ps1ParenExpression):
                inner = inner.expression
            if isinstance(inner, Ps1IntegerLiteral):
                return Ps1IntegerLiteral(value=-inner.value, raw=str(-inner.value))
        return None

    def _handle_arithmetic(self, node: Ps1BinaryExpression, op: str) -> Expression | None:
        left = self._unwrap_integer(node.left)
        right = self._unwrap_integer(node.right)
        if left is None or right is None:
            return None
        fn = self._ARITHMETIC_OPS.get(op)
        if fn is None:
            return None
        try:
            result = fn(left.value, right.value)
        except (ZeroDivisionError, ValueError, OverflowError):
            return None
        return Ps1IntegerLiteral(value=result, raw=str(result))

    def _handle_comparison(self, node: Ps1BinaryExpression, op: str) -> Expression | None:
        left = self._unwrap_integer(node.left)
        right = self._unwrap_integer(node.right)
        if left is None or right is None:
            return None
        fn = self._COMPARISON_OPS.get(op)
        if fn is None:
            return None
        result = fn(left.value, right.value)
        return Ps1Variable(name='True' if result else 'False')

    def _handle_format(self, node: Ps1BinaryExpression) -> Expression | None:
        fmt_str = _string_value(node.left) if node.left else None
        if fmt_str is None or node.right is None:
            return None
        args = _collect_string_arguments(node.right)
        if args is None:
            return None
        try:
            def replacer(m: re.Match) -> str:
                full = m.group(0)
                if full == '{{':
                    return '{'
                if full == '}}':
                    return '}'
                idx = int(m.group(1))
                return args[idx]
            result = re.sub(r'\{\{|\}\}|\{(\d+)\}', replacer, fmt_str)
        except (IndexError, ValueError):
            return None
        return _make_string_literal(result)

    def _handle_concat(self, node: Ps1BinaryExpression) -> Expression | None:
        left_str = _string_value(node.left) if node.left else None
        right_str = _string_value(node.right) if node.right else None
        if left_str is not None and right_str is not None:
            return _make_string_literal(left_str + right_str)
        if right_str is not None and isinstance(node.left, Ps1BinaryExpression):
            if node.left.operator == '+':
                inner_right_str = _string_value(node.left.right) if node.left.right else None
                if inner_right_str is not None:
                    node.left.right = _make_string_literal(inner_right_str + right_str)
                    return node.left
        if right_str is not None and isinstance(node.left, Ps1ArrayLiteral):
            elements = list(node.left.elements)
            elements.append(_make_string_literal(right_str))
            return Ps1ArrayLiteral(elements=elements)
        is_inner_concat = (
            isinstance(node.parent, Ps1BinaryExpression)
            and node.parent.operator == '+'
            and node.parent.left is node
        )
        if not is_inner_concat:
            if isinstance(node.left, Ps1Variable) and right_str is not None:
                return _variable_string_to_expandable(node.left, right_str, var_first=True)
            if isinstance(node.right, Ps1Variable) and left_str is not None:
                return _variable_string_to_expandable(node.right, left_str, var_first=False)
        return None

    def _handle_binary_join(self, node: Ps1BinaryExpression) -> Expression | None:
        separator = _string_value(node.right) if node.right else None
        if separator is None or node.left is None:
            return None
        # Binary -Join on a scalar string is a no-op.
        scalar = _string_value(node.left)
        if scalar is not None:
            return _make_string_literal(scalar)
        array = _unwrap_to_array_literal(node.left)
        if array is None and isinstance(node.left, Ps1ArrayLiteral):
            array = node.left
        if array is None:
            return None
        args = _collect_string_arguments(array)
        if args is None:
            return None
        return _make_string_literal(separator.join(args))

    def _handle_binary_replace(
        self, node: Ps1BinaryExpression, op: str,
    ) -> Expression | None:
        haystack = _string_value(node.left) if node.left else None
        if haystack is None or node.right is None:
            return None
        if isinstance(node.right, Ps1ArrayLiteral) and len(node.right.elements) == 2:
            needle_str = _string_value(node.right.elements[0])
            insert_str = _string_value(node.right.elements[1])
        else:
            return None
        if needle_str is None or insert_str is None:
            return None
        flags = re.IGNORECASE if op != '-creplace' else 0
        try:
            result = re.sub(needle_str, lambda _: insert_str, haystack, flags=flags)
        except re.error:
            return None
        return _make_string_literal(result)

    def _handle_binary_split(
        self, node: Ps1BinaryExpression, op: str,
    ) -> Expression | None:
        if node.right is None or node.left is None:
            return None
        pattern_str = _string_value(node.right)
        if pattern_str is None:
            return None
        flags = re.IGNORECASE if op != '-csplit' else 0
        # Collect input strings: either a single string or an array of strings.
        left_str = _string_value(node.left)
        if left_str is not None:
            inputs = [left_str]
        else:
            array = _unwrap_to_array_literal(node.left)
            if array is None and isinstance(node.left, Ps1ArrayLiteral):
                array = node.left
            if array is None:
                return None
            inputs_opt = _collect_string_arguments(array)
            if inputs_opt is None:
                return None
            inputs = inputs_opt
        try:
            parts: list[str] = []
            for s in inputs:
                parts.extend(re.split(pattern_str, s, flags=flags))
        except re.error:
            return None
        elements: list[Expression] = [_make_string_literal(p) for p in parts]
        return Ps1ArrayLiteral(elements=elements)

Ancestors

Methods

def visit_Ps1CommandInvocation(self, node)
Expand source code Browse git
def visit_Ps1CommandInvocation(self, node: Ps1CommandInvocation):
    self.generic_visit(node)
    if isinstance(node.name, Ps1StringLiteral):
        name_lower = node.name.value.lower()
        target = _KNOWN_ALIAS.get(name_lower)
        if target is not None and target != node.name.value:
            if not self._has_function_definition(node, name_lower):
                node.name = Ps1StringLiteral(
                    offset=node.name.offset,
                    value=target,
                    raw=target,
                )
                node.name.parent = node
                self.mark_changed()
    return None
def visit_Ps1UnaryExpression(self, node)
Expand source code Browse git
def visit_Ps1UnaryExpression(self, node: Ps1UnaryExpression):
    self.generic_visit(node)
    if node.operator.lower() != '-join' or node.operand is None:
        return None
    # -Join on a scalar string is a no-op in PowerShell.
    scalar = _string_value(node.operand)
    if scalar is not None:
        return _make_string_literal(scalar)
    array = _unwrap_to_array_literal(node.operand)
    if array is None:
        return None
    args = _collect_string_arguments(array)
    if args is None:
        return None
    return _make_string_literal(''.join(args))
def visit_Ps1IndexExpression(self, node)
Expand source code Browse git
def visit_Ps1IndexExpression(self, node: Ps1IndexExpression):
    self.generic_visit(node)
    obj_str = _string_value(node.object) if node.object else None
    if obj_str is None or node.index is None:
        return None
    if isinstance(node.index, Ps1IntegerLiteral):
        idx = node.index.value
        if 0 <= idx < len(obj_str):
            return _make_string_literal(obj_str[idx])
        return None
    array = _unwrap_to_array_literal(node.index)
    if array is None and isinstance(node.index, Ps1ArrayLiteral):
        array = node.index
    if array is not None:
        chars: list[Expression] = []
        for elem in array.elements:
            if not isinstance(elem, Ps1IntegerLiteral):
                return None
            idx = elem.value
            if idx < 0 or idx >= len(obj_str):
                return None
            chars.append(_make_string_literal(obj_str[idx]))
        return Ps1ArrayLiteral(elements=chars)
    return None
def visit_Ps1InvokeMember(self, node)
Expand source code Browse git
def visit_Ps1InvokeMember(self, node: Ps1InvokeMember):
    self.generic_visit(node)
    if isinstance(node.member, Ps1StringLiteral):
        name = node.member.value
        if SIMPLE_IDENTIFIER.match(name):
            node.member = name
            self.mark_changed()
    member_name = node.member if isinstance(node.member, str) else None
    if member_name is not None:
        normalized = _case_normalize_name(member_name)
        if normalized != member_name:
            node.member = normalized
            self.mark_changed()
        member_name = node.member
    if member_name is not None and member_name.lower() == 'tostring':
        if len(node.arguments) == 0:
            obj_str = _string_value(node.object) if node.object else None
            if obj_str is not None:
                return _make_string_literal(obj_str)
    if member_name is not None and member_name.lower() == 'replace':
        if len(node.arguments) == 2:
            obj_str = _string_value(node.object) if node.object else None
            needle_str = _string_value(node.arguments[0])
            insert_str = _string_value(node.arguments[1])
            if obj_str is not None and needle_str is not None and insert_str is not None:
                result = obj_str.replace(needle_str, insert_str)
                return _make_string_literal(result)
    if member_name is not None and member_name.lower() == 'split':
        if len(node.arguments) == 1:
            obj_str = _string_value(node.object) if node.object else None
            sep_str = _string_value(node.arguments[0])
            if obj_str is not None and sep_str is not None and sep_str:
                pattern = '[' + re.escape(sep_str) + ']'
                parts = re.split(pattern, obj_str)
                elements: list[Expression] = [_make_string_literal(p) for p in parts]
                return Ps1ArrayLiteral(elements=elements)
    if member_name is not None and member_name.lower() == 'invoke':
        if isinstance(node.object, Ps1MemberAccess):
            return Ps1InvokeMember(
                offset=node.offset,
                object=node.object.object,
                member=node.object.member,
                arguments=node.arguments,
                access=node.object.access,
            )
    if _is_static_convert_call(node):
        if member_name is not None and member_name.lower() == 'frombase64string':
            if len(node.arguments) == 1:
                b64_str = _string_value(node.arguments[0])
                if b64_str is not None:
                    try:
                        decoded = base64.b64decode(b64_str)
                    except Exception:
                        return None
                    elements = [
                        Ps1IntegerLiteral(value=b, raw=F'0x{b:02X}')
                        for b in decoded
                    ]
                    array = Ps1ArrayLiteral(elements=elements)
                    return Ps1ArrayExpression(
                        body=[Ps1ExpressionStatement(expression=array)])
    enc_info = _is_static_encoding_chain(node)
    if enc_info is not None:
        encoding_name, _ = enc_info
        if len(node.arguments) == 1:
            arg = _unwrap_paren_to_array(node.arguments[0])
            if isinstance(arg, Ps1ArrayExpression) and len(arg.body) == 1:
                stmt = arg.body[0]
                if isinstance(stmt, Ps1ExpressionStatement) and stmt.expression:
                    arg = stmt.expression
            int_values = _collect_int_arguments(arg)
            if int_values is not None:
                try:
                    raw_bytes = bytearray(int_values)
                except (ValueError, OverflowError):
                    return None
                encoding = _ENCODING_MAP.get(
                    encoding_name.lower(), encoding_name)
                try:
                    codecs.lookup(encoding)
                except LookupError:
                    encoding = 'utf-8'
                try:
                    decoded = raw_bytes.decode(encoding)
                except Exception:
                    return None
                return _make_string_literal(decoded)
    if (
        node.access == Ps1AccessKind.STATIC
        and isinstance(node.object, Ps1TypeExpression)
        and node.object.name.lower().replace(' ', '') in _STRING_TYPE_NAMES
        and member_name is not None
        and member_name.lower() == 'join'
        and len(node.arguments) == 2
    ):
        separator = _string_value(node.arguments[0])
        if separator is not None:
            second = node.arguments[1]
            scalar = _string_value(second)
            if scalar is not None:
                return _make_string_literal(scalar)
            array = _unwrap_to_array_literal(second)
            if array is not None:
                args = _collect_string_arguments(array)
                if args is not None:
                    return _make_string_literal(separator.join(args))
    return None
def visit_Ps1BinaryExpression(self, node)
Expand source code Browse git
def visit_Ps1BinaryExpression(self, node: Ps1BinaryExpression):
    self.generic_visit(node)
    op = node.operator.lower()
    if op == '-f':
        return self._handle_format(node)
    if op == '+':
        return self._handle_concat(node) or self._handle_arithmetic(node, op)
    if op == '-join':
        return self._handle_binary_join(node)
    if op in ('-replace', '-creplace', '-ireplace'):
        return self._handle_binary_replace(node, op)
    if op in ('-split', '-csplit', '-isplit'):
        return self._handle_binary_split(node, op)
    return self._handle_comparison(node, op) or self._handle_arithmetic(node, op)