Module refinery.lib.scripts.ps1.deobfuscation.folding

PowerShell constant folding transforms.

Expand source code Browse git
"""
PowerShell constant folding transforms.
"""
from __future__ import annotations

import base64
import codecs
import re

from collections.abc import Iterator

from refinery.lib.scripts.ps1.deobfuscation.constants import PS1_ENV_CONSTANTS
from refinery.lib.scripts.ps1.deobfuscation.data import (
    COMPARISON_OPS,
    ENCODING_MAP,
)
from refinery.lib.scripts.ps1.deobfuscation.helpers import (
    LocalFunctionAwareTransformer,
    StringMethodError,
    apply_format_string,
    apply_string_method,
    collect_byte_array,
    collect_format_arguments,
    collect_int_arguments,
    collect_string_arguments,
    detect_encoding_chain,
    extract_foreach_scriptblock,
    get_body,
    get_member_name,
    is_array_reverse_call,
    is_static_type_call,
    is_truthy,
    make_string_literal,
    string_value,
    unwrap_integer,
    unwrap_parens,
    unwrap_single_paren,
    unwrap_to_array_literal,
)
from refinery.lib.scripts.ps1.deobfuscation.typenames import (
    is_known_member,
    resolve_member_type,
)
from refinery.lib.scripts.ps1.model import (
    Expression,
    Ps1ArrayExpression,
    Ps1ArrayLiteral,
    Ps1AssignmentExpression,
    Ps1BinaryExpression,
    Ps1CommandInvocation,
    Ps1ExpandableString,
    Ps1ExpressionStatement,
    Ps1HashLiteral,
    Ps1IndexExpression,
    Ps1IntegerLiteral,
    Ps1InvokeMember,
    Ps1MemberAccess,
    Ps1Pipeline,
    Ps1RangeExpression,
    Ps1ScriptBlock,
    Ps1StringLiteral,
    Ps1UnaryExpression,
    Ps1Variable,
)

# Maps lower-cased regex option names, as they appear in a comma-separated
# option string, to the corresponding Python `re` flags.
_REGEX_OPTION_FLAGS: dict[str, int] = {
    'ignorecase'              : re.IGNORECASE,
    'multiline'               : re.MULTILINE,
    'singleline'              : re.DOTALL,
    'ignorepatternwhitespace' : re.VERBOSE,
    'none'                    : 0,
}

# Maps numeric option bits to the corresponding Python `re` flags; consulted
# when the options argument is an integer literal. Bits that are not listed
# here are ignored by `_parse_regex_options`.
_REGEX_OPTION_INT: dict[int, int] = {
    1  : re.IGNORECASE,
    2  : re.MULTILINE,
    16 : re.DOTALL,
    32 : re.VERBOSE,
}

# Option bit for right-to-left matching. It has no `re` flag equivalent and is
# reported separately as a boolean.
_RIGHT_TO_LEFT = 64
# Upper bound on the length of a string produced by folding string multiplication.
_MAX_STRING_EXPAND = 0x1000
# Upper bound on the number of elements a folded range expression may expand to.
_MAX_RANGES_EXPAND = 15


def _is_static_regex_call(node: Ps1InvokeMember) -> bool:
    """
    Check whether the given invocation is a static call on the type
    `system.text.regularexpressions.regex`, as decided by `is_static_type_call`.
    """
    return is_static_type_call(node, 'system.text.regularexpressions.regex')


def _parse_regex_options(node: Expression) -> tuple[int, bool] | None:
    """
    Translate a RegexOptions argument into a pair `(flags, right_to_left)`
    where `flags` is a combination of Python `re` flags.

    A string argument is treated as a comma-separated list of option names;
    an unknown name makes the whole argument unparseable and yields `None`.
    An integer literal is decoded bit by bit; bits without an entry in
    `_REGEX_OPTION_INT` are ignored. Any other node yields `None`.
    """
    text = string_value(node)
    if text is None:
        if not isinstance(node, Ps1IntegerLiteral):
            return None
        bits = node.value
        combined = 0
        for bit, flag in _REGEX_OPTION_INT.items():
            if bits & bit:
                combined |= flag
        return combined, bool(bits & _RIGHT_TO_LEFT)
    combined = 0
    reverse = False
    for name in (chunk.strip().lower() for chunk in text.split(',')):
        if not name:
            # Empty entries, e.g. from a trailing comma, are skipped.
            continue
        if name == 'righttoleft':
            reverse = True
        elif (flag := _REGEX_OPTION_FLAGS.get(name)) is None:
            return None
        else:
            combined |= flag
    return combined, reverse


def _iter_regex_matches(node: Ps1InvokeMember) -> Iterator[str] | None:
    """
    Produce the matched substrings for a call of the shape

        [Regex]::Match/Matches(input, pattern[, options])

    The result is `None` when the argument count is not 2 or 3, when input,
    pattern or options cannot be resolved to constants, or when Python's `re`
    module rejects the pattern. For right-to-left matching, the input is
    reversed before the search and every match is reversed back; the pattern
    itself is used unchanged.
    """
    arguments = node.arguments
    if len(arguments) not in (2, 3):
        return None
    subject = string_value(arguments[0])
    pattern = string_value(arguments[1])
    if subject is None or pattern is None:
        return None
    flags, right_to_left = 0, False
    if len(arguments) == 3:
        options = _parse_regex_options(arguments[2])
        if options is None:
            return None
        flags, right_to_left = options

    def orient(text):
        return text[::-1] if right_to_left else text

    try:
        # The pattern is compiled here, so an invalid pattern is caught
        # before the lazy generator below is handed out.
        found = re.finditer(pattern, orient(subject), flags)
    except re.error:
        return None
    return (orient(match[0]) for match in found)


def _compute_regex_matches(node: Ps1InvokeMember) -> list[str] | None:
    """
    Collect all matches of a static regex call into a list, or return `None`
    when `_iter_regex_matches` cannot evaluate the call.
    """
    found = _iter_regex_matches(node)
    if not found:
        return None
    return list(found)


def _compute_regex_match(node: Ps1InvokeMember) -> str | None:
    """
    Return the first match of a static regex call, the empty string when the
    pattern does not match at all, or `None` when `_iter_regex_matches` cannot
    evaluate the call.
    """
    found = _iter_regex_matches(node)
    return next(found, '') if found else None


# Lower-cased type names that count as integer results. A member access whose
# resolved type is in this set is folded to the length of the string or array
# it is applied to (see `Ps1ConstantFolding.visit_Ps1MemberAccess`).
_INTEGER_RESULT_TYPES = frozenset({
    'system.int16',
    'system.int32',
    'system.int64',
    'system.uint16',
    'system.uint32',
    'system.uint64',
    'system.byte',
    'system.sbyte',
})


def _foreach_extracts_value(sb: Ps1ScriptBlock) -> bool:
    """
    Decide whether a ForEach scriptblock consists of exactly one expression
    that reads `.Value` from `$_`, optionally through any chain of `.Groups`
    and `.Captures` accesses, e.g. `$_.Value` or `$_.Groups.Value`. Member
    names are compared case-insensitively.
    """
    statements = sb.body
    if statements is None or len(statements) != 1:
        return False
    only = statements[0]
    if not isinstance(only, Ps1ExpressionStatement):
        return False
    target = only.expression
    if target is None:
        return False
    if isinstance(target, Ps1Pipeline):
        # Only a pipeline consisting of a single expression is acceptable.
        if len(target.elements) != 1:
            return False
        target = target.elements[0].expression
        if target is None:
            return False
    if not isinstance(target, Ps1MemberAccess):
        return False
    if not isinstance(target.member, str) or target.member.lower() != 'value':
        return False
    cursor = target.object
    while isinstance(cursor, Ps1MemberAccess):
        name = cursor.member
        if not isinstance(name, str) or name.lower() not in ('groups', 'captures'):
            return False
        cursor = cursor.object
    return isinstance(cursor, Ps1Variable) and cursor.name == '_'


def _escape_for_expandable(text: str) -> str:
    """
    Escape characters that are special inside double-quoted strings.

    The backtick (escape character), the dollar sign (variable and
    subexpression sigil) and the double quote (string delimiter) are each
    prefixed with a backtick. Without escaping the double quote, text that
    contains one would terminate the surrounding `"..."` literal early.
    """
    # A single translate pass avoids any ordering concerns between the
    # replacements, since the escape character itself is among them.
    return text.translate(str.maketrans({
        '`': '``',
        '$': '`$',
        '"': '`"',
    }))


def _variable_raw(var: Ps1Variable) -> str:
    """
    Render a variable as a braced reference suitable for embedding in an
    expandable string: `${name}` or `${scope:name}`, with `@` in place of `$`
    when the variable is splatted.
    """
    sigil = '@' if var.splatted else '$'
    qualifier = var.scope.value
    name = F'{qualifier}:{var.name}' if qualifier else var.name
    return F'{sigil}{{{name}}}'


def _variable_string_to_expandable(
    var: Ps1Variable,
    text: str,
    *,
    var_first: bool,
) -> Ps1ExpandableString:
    """
    Fold `$var + 'text'` or `'text' + $var` into a
    `refinery.lib.scripts.ps1.model.Ps1ExpandableString`.

    The keyword-only flag `var_first` selects the operand order: when true,
    the variable precedes the text. The string part is given a single-quoted
    `raw` representation in which embedded single quotes are doubled, so that
    it remains a well-formed literal for text such as `it's`.
    """
    escaped = _escape_for_expandable(text)
    var_raw = _variable_raw(var)
    # Inside a single-quoted literal, a single quote is written as two.
    quoted = text.replace("'", "''")
    text_part = Ps1StringLiteral(value=text, raw=F"'{quoted}'")
    if var_first:
        raw = F'"{var_raw}{escaped}"'
        parts = [var, text_part]
    else:
        raw = F'"{escaped}{var_raw}"'
        parts = [text_part, var]
    return Ps1ExpandableString(parts=parts, raw=raw)


def _resolve_index_values(index: Expression) -> int | list[int] | None:
    """
    Resolve an index expression to a single integer, or to a list of integers
    when it is an array literal whose elements are all integers. Returns
    `None` when neither interpretation applies.
    """
    single = unwrap_integer(index)
    if single is not None:
        return single.value
    array = unwrap_to_array_literal(index)
    if array is None:
        return None
    values: list[int] = []
    for element in array.elements:
        literal = unwrap_integer(element)
        if literal is None:
            # One non-constant element makes the entire index unresolvable.
            return None
        values.append(literal.value)
    return values


def _index_into_string(s: str, indices: int | list[int]) -> Expression | None:
    """
    Index a constant string with one index or a list of indices. A single
    index produces a one-character string literal, a list produces an array
    literal of such strings. Negative indices count from the end. If any
    index is out of range, nothing is folded and `None` is returned.
    """
    size = len(s)
    valid = range(-size, size)
    if isinstance(indices, int):
        return make_string_literal(s[indices]) if indices in valid else None
    if any(position not in valid for position in indices):
        return None
    chars: list[Expression] = [make_string_literal(s[position]) for position in indices]
    return Ps1ArrayLiteral(elements=chars)


def _index_into_array(
    array: Ps1ArrayLiteral, indices: int | list[int],
) -> Expression | None:
    """
    Index an array literal with one index or a list of indices. A single
    index returns the selected element itself, a list returns a new array
    literal of the selected elements. Negative indices count from the end.
    If any index is out of range, nothing is folded and `None` is returned.
    """
    items = array.elements
    valid = range(-len(items), len(items))
    if isinstance(indices, int):
        return items[indices] if indices in valid else None
    if any(position not in valid for position in indices):
        return None
    chosen: list[Expression] = [items[position] for position in indices]
    return Ps1ArrayLiteral(elements=chosen)


def _lookup_hashtable(ht: Ps1HashLiteral, index: Expression) -> Expression | None:
    """
    Look up a constant string key in a hashtable literal, comparing keys
    case-insensitively. Returns the value of the first matching pair, or
    `None` when the index is not a constant string or no key matches.
    """
    wanted = string_value(index)
    if wanted is None:
        return None
    wanted = wanted.lower()
    return next(
        (
            value for key, value in ht.pairs
            if (name := string_value(key)) is not None and name.lower() == wanted
        ),
        None,
    )


class Ps1ConstantFolding(LocalFunctionAwareTransformer):

    def visit_Ps1CommandInvocation(self, node: Ps1CommandInvocation):
        """
        Visit the children of a command invocation. The invocation node itself
        is never replaced, which is signalled by returning `None`.
        """
        self.generic_visit(node)
        return None

    def visit_Ps1Pipeline(self, node: Ps1Pipeline):
        """
        Try to fold a two-element pipeline of the form regex call piped into
        ForEach into a constant. If that does not succeed, the children are
        visited as usual and the pipeline is kept.
        """
        folded = None
        if len(node.elements) == 2:
            folded = self._try_fold_regex_pipeline(node)
        if folded is None:
            self.generic_visit(node)
        return folded

    @staticmethod
    def _fold_regex_call_result(
        invoke: Ps1InvokeMember, member_lower: str,
    ) -> Expression | None:
        """
        Evaluate a static regex call and wrap the outcome in a literal: a
        string literal for `match`, an array literal of strings for `matches`.
        Any other method name, or a call that cannot be evaluated, gives `None`.
        """
        if member_lower == 'match':
            first = _compute_regex_match(invoke)
            return None if first is None else make_string_literal(first)
        if member_lower == 'matches':
            found = _compute_regex_matches(invoke)
            if found is None:
                return None
            elements: list[Expression] = [make_string_literal(s) for s in found]
            return Ps1ArrayLiteral(elements=elements)
        return None

    def _try_fold_regex_pipeline(self, node: Ps1Pipeline) -> Expression | None:
        """
        Fold a pipeline whose first element is a static regex call and whose
        second element is a ForEach scriptblock that only extracts the match
        value. Returns the folded literal, or `None` if the shape does not fit.
        """
        head = node.elements[0].expression
        tail = node.elements[1].expression
        if not (isinstance(head, Ps1InvokeMember) and _is_static_regex_call(head)):
            return None
        name = get_member_name(head.member)
        if name is None:
            return None
        if not tail:
            return None
        block = extract_foreach_scriptblock(tail)
        if block is None or not _foreach_extracts_value(block):
            return None
        return self._fold_regex_call_result(head, name.lower())

    def visit_Ps1MemberAccess(self, node: Ps1MemberAccess):
        """
        Fold member accesses on constants. Three rewrites are attempted in order:

        1. A member whose resolved type is an integer type is replaced by the
           length of the constant string or array literal it is applied to.
        2. A member that is not known for a constant string or integer is
           replaced by `$Null`.
        3. A `.Value` access on a static regex call is replaced by the match.

        NOTE(review): the first rewrite presumably relies on only length-like
        members resolving to an integer type; verify against `resolve_member_type`.
        """
        self.generic_visit(node)
        member = get_member_name(node.member)
        if member is None:
            return None
        target = node.object
        if target is None:
            return None
        if resolve_member_type(target, member) in _INTEGER_RESULT_TYPES:
            text = string_value(target)
            if text is not None:
                size = len(text)
                return Ps1IntegerLiteral(value=size, raw=str(size))
            array = unwrap_to_array_literal(target)
            if array is not None:
                size = len(array.elements)
                return Ps1IntegerLiteral(value=size, raw=str(size))
        is_scalar = (
            string_value(target) is not None
            or isinstance(target, Ps1IntegerLiteral)
        )
        if is_scalar and not is_known_member(target, member):
            return Ps1Variable(name='Null')
        return self._try_fold_regex_member_access(node, member)

    def _try_fold_regex_member_access(
        self, node: Ps1MemberAccess, member: str,
    ) -> Expression | None:
        """
        Fold a property chain on a static regex call, such as
        `[Regex]::Match(...).Value` or `[Regex]::Matches(...).Groups.Value`.
        The outermost property must be `Value` and every property in between
        must be `Groups` or `Captures`, compared case-insensitively.
        """
        # Property names are gathered from the outermost access inwards, so
        # the first entry is the property that is accessed last.
        names: list[str] = [member]
        cursor = node.object
        while isinstance(cursor, Ps1MemberAccess):
            name = get_member_name(cursor.member)
            if name is None:
                return None
            names.append(name)
            cursor = cursor.object
        if not isinstance(cursor, Ps1InvokeMember) or not _is_static_regex_call(cursor):
            return None
        outermost, *between = (name.lower() for name in names)
        if outermost != 'value':
            return None
        if any(name not in ('groups', 'captures') for name in between):
            return None
        method = cursor.member
        if not isinstance(method, str):
            return None
        return self._fold_regex_call_result(cursor, method.lower())

    @staticmethod
    def _try_join_regex_matches(operand: Expression) -> Expression | None:
        """
        Fold `-join [Regex]::Matches(...)` into the concatenation of all
        matches. Returns `None` unless the operand, after removing
        parentheses, is a static regex `Matches` call that can be evaluated.
        """
        call = unwrap_parens(operand)
        if not isinstance(call, Ps1InvokeMember):
            return None
        if not _is_static_regex_call(call):
            return None
        name = call.member
        if not isinstance(name, str) or name.lower() != 'matches':
            return None
        found = _compute_regex_matches(call)
        return None if found is None else make_string_literal(''.join(found))

    def visit_Ps1UnaryExpression(self, node: Ps1UnaryExpression):
        """
        Fold the unary operators `-join`, `-bnot`, `-not` and `!` when their
        operand is constant. Operators are matched case-insensitively.
        """
        self.generic_visit(node)
        operand = node.operand
        if operand is None:
            return None
        op = node.operator.lower()
        if op == '-join':
            return self._handle_unary_join(node)
        if op == '-bnot':
            literal = unwrap_integer(operand)
            if literal is None:
                return None
            flipped = ~literal.value
            return Ps1IntegerLiteral(value=flipped, raw=str(flipped))
        if op == '-not' or op == '!':
            verdict = is_truthy(operand)
            if verdict is not None:
                # The negated truth value is spelled as $True or $False.
                return Ps1Variable(name='False' if verdict else 'True')
        return None

    def _handle_unary_join(self, node: Ps1UnaryExpression) -> Expression | None:
        """
        Fold unary `-join`. A constant string operand is returned as is, a
        static regex `Matches` call is replaced by its concatenated matches,
        an array of constant strings is concatenated, and an array expression
        that wraps one constant string yields that string.
        """
        operand = node.operand
        if operand is None:
            return None
        text = string_value(operand)
        if text is not None:
            return make_string_literal(text)
        joined = self._try_join_regex_matches(operand)
        if joined is not None:
            return joined
        array = unwrap_to_array_literal(operand)
        if array is not None:
            pieces = collect_string_arguments(array)
            return None if pieces is None else make_string_literal(''.join(pieces))
        if not isinstance(operand, Ps1ArrayExpression) or len(operand.body) != 1:
            return None
        only = operand.body[0]
        if not isinstance(only, Ps1ExpressionStatement) or not only.expression:
            return None
        text = string_value(only.expression)
        return None if text is None else make_string_literal(text)

    def visit_Ps1RangeExpression(self, node: Ps1RangeExpression):
        """
        Expand a range with constant integer bounds into an array literal of
        integers. Both bounds are inclusive and a descending range counts
        down. Ranges nested directly in another range and ranges with more
        than `_MAX_RANGES_EXPAND` elements are left alone.
        """
        self.generic_visit(node)
        if isinstance(node.parent, Ps1RangeExpression):
            return None
        first = unwrap_integer(node.start)
        last = unwrap_integer(node.end)
        if first is None or last is None:
            return None
        begin, finish = first.value, last.value
        if abs(finish - begin) + 1 > _MAX_RANGES_EXPAND:
            return None
        if finish >= begin:
            values = range(begin, finish + 1)
        else:
            values = range(begin, finish - 1, -1)
        return Ps1ArrayLiteral(elements=[
            Ps1IntegerLiteral(value=v, raw=str(v)) for v in values])

    def visit_Ps1IndexExpression(self, node: Ps1IndexExpression):
        """
        Fold indexing into a hashtable literal, a constant string, or an
        array literal when the index is constant.
        """
        self.generic_visit(node)
        target, index = node.object, node.index
        if index is None or target is None:
            return None
        if isinstance(target, Ps1HashLiteral):
            return _lookup_hashtable(target, index)
        positions = _resolve_index_values(index)
        if positions is None:
            return None
        text = string_value(target)
        if text is not None:
            return _index_into_string(text, positions)
        array = unwrap_to_array_literal(target)
        return None if array is None else _index_into_array(array, positions)

    def visit_Ps1ExpressionStatement(self, node: Ps1ExpressionStatement):
        """
        Detect a statement that `is_array_reverse_call` recognizes and try to
        apply the reversal directly to the preceding constant assignment of
        the same variable. On success the statement is returned after it has
        been emptied in place by `_try_apply_array_reverse`.
        """
        self.generic_visit(node)
        var = is_array_reverse_call(node)
        if var is not None and self._try_apply_array_reverse(node, var):
            return node
        return None

    def _try_apply_array_reverse(
        self, node: Ps1ExpressionStatement, var: Ps1Variable,
    ) -> bool:
        """
        Apply an in-place reversal of `var` to the closest preceding plain
        assignment (`=`) of that variable within the same body.

        If the assigned value is an array literal, an array expression
        wrapping a single array literal, or a constant string, the value is
        reversed in the assignment, the expression of `node` is cleared, and
        the transformer is marked as changed. Returns whether this happened.
        Variable names are compared case-insensitively.
        """
        body = get_body(node.parent)
        if body is None:
            return False
        try:
            idx = body.index(node)
        except ValueError:
            return False
        var_name = var.name.lower()
        # Walk backwards from the statement just before the reverse call.
        for i in range(idx - 1, -1, -1):
            stmt = body[i]
            if not isinstance(stmt, Ps1ExpressionStatement):
                continue
            expr = stmt.expression
            if not isinstance(expr, Ps1AssignmentExpression):
                continue
            if expr.operator != '=':
                continue
            target = expr.target
            if not isinstance(target, Ps1Variable):
                continue
            if target.name.lower() != var_name:
                continue
            value = expr.value
            if isinstance(value, Ps1ArrayLiteral):
                value.elements.reverse()
                node.expression = None
                self.mark_changed()
                return True
            if isinstance(value, Ps1ArrayExpression) and len(value.body) == 1:
                inner = value.body[0]
                if (
                    isinstance(inner, Ps1ExpressionStatement)
                    and isinstance(inner.expression, Ps1ArrayLiteral)
                ):
                    inner.expression.elements.reverse()
                    node.expression = None
                    self.mark_changed()
                    return True
            sv = string_value(value)
            if sv is not None:
                replacement = make_string_literal(sv[::-1])
                replacement.parent = expr
                expr.value = replacement
                node.expression = None
                self.mark_changed()
                return True
            # The closest assignment is not a foldable constant: give up
            # rather than searching for an even earlier assignment.
            return False
        return False

    def visit_Ps1InvokeMember(self, node: Ps1InvokeMember):
        """
        Fold a method invocation by trying, in order, the `.Invoke` redirect,
        instance methods on constant strings, and known static methods. The
        first strategy that produces a replacement wins.
        """
        self.generic_visit(node)
        name = get_member_name(node.member)
        if name is None:
            return None
        lower = name.lower()
        strategies = (
            self._try_fold_invoke_redirect,
            self._try_fold_instance_method,
            self._try_fold_static_method,
        )
        for attempt in strategies:
            if folded := attempt(node, lower):
                return folded
        return None

    @staticmethod
    def _try_fold_invoke_redirect(
        node: Ps1InvokeMember, lower: str,
    ) -> Expression | None:
        """
        Rewrite `<object>.<member>.Invoke(<arguments>)` into the direct call
        `<object>.<member>(<arguments>)`. The argument `lower` is the
        lower-cased name of the invoked member. Returns `None` when the
        invoked member is not `invoke` or its object is not a member access.
        """
        if lower == 'invoke' and isinstance(node.object, Ps1MemberAccess):
            # The new call reuses the inner member access's object, member
            # and access kind together with the original arguments.
            return Ps1InvokeMember(
                offset=node.offset,
                object=node.object.object,
                member=node.object.member,
                arguments=node.arguments,
                access=node.object.access,
            )
        return None

    @staticmethod
    def _try_fold_instance_method(
        node: Ps1InvokeMember, lower: str,
    ) -> Expression | None:
        """
        Evaluate a method call on a constant string whose arguments are all
        constant strings or integer literals. The result of
        `apply_string_method` is converted to the matching literal node;
        results of an unsupported type and failed calls yield `None`.
        """
        if not node.object:
            return None
        text = string_value(node.object)
        if text is None:
            return None
        arguments: list[str | int] = []
        for arg in node.arguments:
            value = string_value(arg)
            if value is None:
                if not isinstance(arg, Ps1IntegerLiteral):
                    return None
                value = arg.value
            arguments.append(value)
        try:
            outcome = apply_string_method(text, lower, arguments)
        except StringMethodError:
            return None
        if isinstance(outcome, str):
            return make_string_literal(outcome)
        # bool is a subclass of int and therefore has to be tested first.
        if isinstance(outcome, bool):
            return Ps1Variable(name='True' if outcome else 'False')
        if isinstance(outcome, int):
            return Ps1IntegerLiteral(value=outcome, raw=str(outcome))
        if isinstance(outcome, list):
            elements: list[Expression] = [make_string_literal(p) for p in outcome]
            return Ps1ArrayLiteral(elements=elements)
        return None

    def _try_fold_static_method(
        self, node: Ps1InvokeMember, lower: str,
    ) -> Expression | None:
        """
        Fold calls to a fixed set of static methods with constant arguments.
        The argument `lower` is the lower-cased method name. Handled are, in
        this order: `System.Convert` methods, encoding chains that decode a
        constant byte array, `String.Concat` and `String.Join`,
        `Regex.Replace`, `BitConverter.ToString`, and
        `Environment.GetEnvironmentVariable` for names in `PS1_ENV_CONSTANTS`.
        Returns the replacement node, or `None` if nothing could be folded.
        """
        if is_static_type_call(node, 'system.convert'):
            return self._try_fold_convert(node, lower)
        encoding_name = detect_encoding_chain(node)
        if encoding_name is not None:
            if len(node.arguments) == 1:
                arg = unwrap_single_paren(node.arguments[0])
                # Look through an array expression that wraps one statement.
                if isinstance(arg, Ps1ArrayExpression) and len(arg.body) == 1:
                    stmt = arg.body[0]
                    if isinstance(stmt, Ps1ExpressionStatement) and stmt.expression:
                        arg = stmt.expression
                int_values = collect_int_arguments(arg)
                if int_values is not None:
                    try:
                        # Fails for values outside of the byte range.
                        raw_bytes = bytearray(int_values)
                    except (ValueError, OverflowError):
                        return None
                    encoding = ENCODING_MAP.get(
                        encoding_name.lower(), encoding_name)
                    try:
                        codecs.lookup(encoding)
                    except LookupError:
                        # Codec names unknown to Python fall back to UTF-8.
                        encoding = 'utf-8'
                    try:
                        decoded_str = raw_bytes.decode(encoding)
                    except Exception:
                        return None
                    return make_string_literal(decoded_str)
        if is_static_type_call(node, 'system.string'):
            if lower == 'concat' and len(node.arguments) >= 1:
                parts: list[str] = []
                for arg in node.arguments:
                    if (sv := string_value(arg)) is None:
                        break
                    parts.append(sv)
                else:
                    # Reached only if every argument was a constant string.
                    return make_string_literal(''.join(parts))
            if lower == 'join' and len(node.arguments) >= 2:
                separator = string_value(node.arguments[0])
                if separator is not None:
                    joined: list[str] = []
                    for arg in node.arguments[1:]:
                        if (sv := string_value(arg)) is None:
                            break
                        joined.append(sv)
                    else:
                        return make_string_literal(separator.join(joined))
                    # Otherwise, the second argument may be one array of strings.
                    if len(node.arguments) == 2:
                        array = unwrap_to_array_literal(node.arguments[1])
                        if array is not None:
                            args = collect_string_arguments(array)
                            if args is not None:
                                return make_string_literal(separator.join(args))
        if _is_static_regex_call(node) and lower == 'replace':
            return self._handle_regex_replace(node)
        if is_static_type_call(node, 'system.bitconverter') and lower == 'tostring':
            return self._try_fold_bitconverter_tostring(node)
        if (
            is_static_type_call(node, 'system.environment')
            and lower == 'getenvironmentvariable'
            and len(na := node.arguments) == 1
            and (_en := string_value(na[0])) is not None
            and (_ev := PS1_ENV_CONSTANTS.get(_en.lower())) is not None
        ):
            return make_string_literal(_ev)
        return None

    # Maps lower-cased integer conversion method names of System.Convert to
    # the inclusive (minimum, maximum) range of the target type. Conversions
    # are only folded when the value lies within this range.
    _CONVERT_INT_METHODS = {
        'tobyte'  : (0, 0xFF),
        'toint16' : (-0x8000, 0x7FFF),
        'toint32' : (-0x80000000, 0x7FFFFFFF),
        'toint64' : (-0x8000000000000000, 0x7FFFFFFFFFFFFFFF),
        'tosbyte' : (-0x80, 0x7F),
        'touint16': (0, 0xFFFF),
        'touint32': (0, 0xFFFFFFFF),
        'touint64': (0, 0xFFFFFFFFFFFFFFFF),
    }

    def _try_fold_convert(
        self, node: Ps1InvokeMember, lower: str,
    ) -> Expression | None:
        """
        Fold static calls on `System.Convert`: `FromBase64String` on a
        constant string becomes an array expression of byte literals, the
        integer conversions in `_CONVERT_INT_METHODS` are delegated to
        `_fold_convert_int`, and `ToChar` on a constant integer becomes a
        one-character string.
        """
        arguments = node.arguments
        if lower == 'frombase64string' and len(arguments) == 1:
            encoded = string_value(arguments[0])
            if encoded is not None:
                try:
                    payload = base64.b64decode(encoded)
                except Exception:
                    return None
                octets: list[Expression] = [
                    Ps1IntegerLiteral(value=b, raw=F'0x{b:02X}') for b in payload
                ]
                wrapped = Ps1ExpressionStatement(
                    expression=Ps1ArrayLiteral(elements=octets))
                return Ps1ArrayExpression(body=[wrapped])
        bounds = self._CONVERT_INT_METHODS.get(lower)
        if bounds is not None:
            return self._fold_convert_int(node, bounds)
        if lower != 'tochar' or len(arguments) != 1:
            return None
        code = unwrap_integer(arguments[0])
        if code is None:
            return None
        try:
            return make_string_literal(chr(code.value))
        except (ValueError, OverflowError):
            # The value is not a valid code point.
            return None

    def _fold_convert_int(
        self, node: Ps1InvokeMember, bounds: tuple[int, int],
    ) -> Expression | None:
        """
        Fold an integer conversion of `System.Convert`.

        With one argument, a constant integer is passed through and a constant
        string is parsed as a decimal number. With two arguments, the first is
        a constant string and the second a constant base of 2, 8, 10 or 16.
        The argument `bounds` is the inclusive range of the target type; the
        result is `None` if the value falls outside of it or cannot be parsed.

        Strings are parsed more strictly than Python's `int` would by default:
        the single-argument form always uses base 10, so that prefixed input
        such as `0x10` is not folded while input with leading zeros such as
        `010` is, and digit separators (`_`) are rejected in both forms.
        """
        lo, hi = bounds
        arguments = node.arguments
        if len(arguments) == 1:
            n = unwrap_integer(arguments[0])
            if n is not None and lo <= n.value <= hi:
                return Ps1IntegerLiteral(value=n.value, raw=str(n.value))
            sv = string_value(arguments[0])
            if sv is None:
                return None
            base = 10
        elif len(arguments) == 2:
            sv = string_value(arguments[0])
            base_int = unwrap_integer(arguments[1])
            if sv is None or base_int is None or base_int.value not in (2, 8, 10, 16):
                return None
            base = base_int.value
        else:
            return None
        if '_' in sv:
            # Python accepts underscores as digit separators; the emulated
            # conversion does not.
            return None
        try:
            value = int(sv.strip(), base)
        except (ValueError, OverflowError):
            return None
        if lo <= value <= hi:
            return Ps1IntegerLiteral(value=value, raw=str(value))
        return None

    @staticmethod
    def _try_fold_bitconverter_tostring(node: Ps1InvokeMember) -> Expression | None:
        """
        Fold `[BitConverter]::ToString(bytes[, startIndex[, length]])` on a
        constant byte array into the dash-separated upper-case hex string.

        When only a start index is given, the conversion runs from that index
        to the end of the array. Returns `None` when there are no or more
        than three arguments, when an argument is not constant, or when the
        requested segment does not lie within the array.
        """
        arguments = node.arguments
        if not 1 <= len(arguments) <= 3:
            return None
        data = collect_byte_array(arguments[0])
        if data is None:
            return None
        offset = 0
        if len(arguments) >= 2:
            n = unwrap_integer(arguments[1])
            if n is None:
                return None
            offset = n.value
        if len(arguments) == 3:
            n = unwrap_integer(arguments[2])
            if n is None:
                return None
            length = n.value
        else:
            # Without an explicit length, everything from the start index on
            # is converted.
            length = len(data) - offset
        if offset < 0 or length < 0 or offset + length > len(data):
            return None
        segment = data[offset:offset + length]
        return make_string_literal('-'.join(F'{b:02X}' for b in segment))

    def _handle_regex_replace(self, node: Ps1InvokeMember) -> Expression | None:
        """
        Fold `[Regex]::Replace(input, pattern, replacement[, options])` when
        all arguments are constant.

        The replacement is inserted literally. A replacement that contains a
        substitution token (`$1`, `${name}`, `$$`, `$&`, `` $` ``, `$'`, `$+`,
        `$_`) is therefore not folded, because a literal insertion would
        produce a different string than the regex engine. For right-to-left
        matching the input is reversed, the substitution is performed with
        the reversed replacement, and the result is reversed back, so that the
        replacement text appears in its original orientation.

        Returns `None` if the call cannot be folded or the pattern is invalid.
        """
        arguments = node.arguments
        if len(arguments) not in (3, 4):
            return None
        input_str = string_value(arguments[0])
        pattern_str = string_value(arguments[1])
        replacement_str = string_value(arguments[2])
        if input_str is None or pattern_str is None or replacement_str is None:
            return None
        if re.search(r"\$[\d{$&`'+_]", replacement_str):
            return None
        flags = 0
        right_to_left = False
        if len(arguments) == 4:
            opts = _parse_regex_options(arguments[3])
            if opts is None:
                return None
            flags, right_to_left = opts
        try:
            if right_to_left:
                # Reverse the replacement as well: the final reversal below
                # would otherwise flip the inserted text.
                mirrored = replacement_str[::-1]
                result = re.sub(
                    pattern_str, lambda _: mirrored, input_str[::-1], flags=flags)
                result = result[::-1]
            else:
                result = re.sub(
                    pattern_str, lambda _: replacement_str, input_str, flags=flags)
        except re.error:
            return None
        return make_string_literal(result)

    # Maps lower-cased binary operators to the Python integer operation used
    # to evaluate them on constant integer operands.
    # NOTE(review): `/` is mapped to floor division and `%` to Python's
    # modulo, whose results differ from true division and from a remainder
    # that takes the sign of the dividend; confirm that this is intended.
    _ARITHMETIC_OPS = {
        '+'     : int.__add__,
        '-'     : int.__sub__,
        '*'     : int.__mul__,
        '/'     : int.__floordiv__,
        '%'     : int.__mod__,
        '-band' : int.__and__,
        '-bor'  : int.__or__,
        '-bxor' : int.__xor__,
        '-shl'  : int.__lshift__,
        '-shr'  : int.__rshift__,
    }

    def visit_Ps1BinaryExpression(self, node: Ps1BinaryExpression):
        """
        Dispatch a binary expression to the folding handler for its operator.
        For `+` and `*` the string interpretation is tried before integer
        arithmetic; every operator without a dedicated handler is tried as a
        comparison first and as integer arithmetic second.
        """
        self.generic_visit(node)
        op = node.operator.lower()
        if op == '-f':
            return self._handle_format(node)
        if op == '-join':
            return self._handle_binary_join(node)
        if op in ('-replace', '-creplace', '-ireplace'):
            return self._handle_binary_replace(node, op)
        if op in ('-split', '-csplit', '-isplit'):
            return self._handle_binary_split(node, op)
        if op == '+':
            primary = self._handle_concat(node)
        elif op == '*':
            primary = self._handle_string_multiply(node)
        else:
            primary = self._handle_comparison(node, op)
        return primary or self._handle_arithmetic(node, op)

    def _handle_arithmetic(self, node: Ps1BinaryExpression, op: str) -> Expression | None:
        """
        Fold an arithmetic or bitwise operator applied to two constant
        integers into an integer literal. Returns `None` when an operand is
        not a constant integer, the operator is unknown, or the result cannot
        be represented faithfully as an integer literal:

        - `/` is only folded when the division is exact, because an inexact
          quotient of two integers is a fractional number in PowerShell and
          not the floored integer.
        - `%` yields a remainder with the sign of the dividend, unlike
          Python's `%`, which follows the sign of the divisor.
        - `-shl` and `-shr` are only folded for shift counts from 0 to 63.
          PowerShell reduces the count to the operand width while Python does
          not, and a huge count would build an enormous integer.

        NOTE(review): results are not wrapped to a fixed integer width.
        """
        left = unwrap_integer(node.left)
        right = unwrap_integer(node.right)
        if left is None or right is None:
            return None
        a, b = left.value, right.value
        if op == '/':
            if b == 0 or a % b != 0:
                return None
            result = a // b
        elif op == '%':
            if b == 0:
                return None
            result = abs(a) % abs(b)
            if a < 0:
                result = -result
        else:
            fn = self._ARITHMETIC_OPS.get(op)
            if fn is None:
                return None
            if op in ('-shl', '-shr') and not 0 <= b < 64:
                return None
            try:
                result = fn(a, b)
            except (ZeroDivisionError, ValueError, OverflowError):
                return None
        return Ps1IntegerLiteral(value=result, raw=str(result))

    @staticmethod
    def _handle_string_multiply(node: Ps1BinaryExpression) -> Expression | None:
        """
        Fold `'text' * n` into the repeated string.

        Only a string on the left-hand side is repeated: PowerShell selects
        the operation by the type of the left operand, so `3 * '4'` is a
        numeric multiplication and must not be turned into `'444'`. A negative
        count is treated as zero. Returns `None` when the operands are not a
        constant string and a constant integer in this order, or when the
        result would exceed `_MAX_STRING_EXPAND` characters.
        """
        if not node.left:
            return None
        s = string_value(node.left)
        n = unwrap_integer(node.right)
        if s is None or n is None:
            return None
        count = max(n.value, 0)
        if len(s) * count > _MAX_STRING_EXPAND:
            return None
        return make_string_literal(s * count)

    def _handle_comparison(self, node: Ps1BinaryExpression, op: str) -> Expression | None:
        """
        Fold a comparison of two constant integers into `$True` or `$False`,
        using the operator table `COMPARISON_OPS`. Returns `None` for
        non-constant operands and for operators missing from the table.
        """
        lhs = unwrap_integer(node.left)
        rhs = unwrap_integer(node.right)
        if lhs is None or rhs is None:
            return None
        compare = COMPARISON_OPS.get(op)
        if compare is None:
            return None
        outcome = compare(lhs.value, rhs.value)
        return Ps1Variable(name='True' if outcome else 'False')

    def _handle_format(self, node: Ps1BinaryExpression) -> Expression | None:
        """
        Fold the format operator `-f` when the format string and all of its
        arguments are constant. Returns `None` if the arguments cannot be
        collected or `apply_format_string` does not produce a result.
        """
        if not node.left:
            return None
        template = string_value(node.left)
        if template is None or node.right is None:
            return None
        values = collect_format_arguments(node.right)
        if values is None:
            return None
        formatted = apply_format_string(template, values)
        return None if formatted is None else make_string_literal(formatted)

    def _handle_concat(self, node: Ps1BinaryExpression) -> Expression | None:
        """
        Fold the `+` operator for string-like operands. The following
        rewrites are attempted in order:

        1. Two constant strings are merged into one string literal.
        2. For `(x + 'a') + 'b'`, the two trailing constants are merged in
           place, giving `x + 'ab'`; the modified inner expression is returned.
        3. An array literal plus a constant string becomes a new array
           literal with the string appended.
        4. A variable plus a constant string, in either order, becomes an
           expandable string, unless this node is the left operand of an
           enclosing `+` expression.

        Returns `None` if none of these apply.
        """
        left_str = string_value(node.left) if node.left else None
        right_str = string_value(node.right) if node.right else None
        if left_str is not None and right_str is not None:
            return make_string_literal(left_str + right_str)
        if right_str is not None and isinstance(node.left, Ps1BinaryExpression):
            if node.left.operator == '+':
                inner_right_str = string_value(node.left.right) if node.left.right else None
                if inner_right_str is not None:
                    nl = make_string_literal(inner_right_str + right_str)
                    nl.parent = node.left
                    node.left.right = nl
                    return node.left
        if right_str is not None and isinstance(node.left, Ps1ArrayLiteral):
            # Work on a copy so the original element list is not modified.
            elements = list(node.left.elements)
            elements.append(make_string_literal(right_str))
            return Ps1ArrayLiteral(elements=elements)
        # True when this node is the left operand of an enclosing `+`.
        is_inner_concat = (
            isinstance(node.parent, Ps1BinaryExpression)
            and node.parent.operator == '+'
            and node.parent.left is node
        )
        if not is_inner_concat:
            if isinstance(node.left, Ps1Variable) and right_str is not None:
                return _variable_string_to_expandable(node.left, right_str, var_first=True)
            if isinstance(node.right, Ps1Variable) and left_str is not None:
                return _variable_string_to_expandable(node.right, left_str, var_first=False)
        return None

    def _handle_binary_join(self, node: Ps1BinaryExpression) -> Expression | None:
        """
        Fold binary `-join` with a constant separator. A constant string on
        the left is returned unchanged; an array of constant strings is
        joined with the separator.
        """
        if not node.right:
            return None
        separator = string_value(node.right)
        if separator is None or node.left is None:
            return None
        # Joining a scalar string has no effect on it.
        single = string_value(node.left)
        if single is not None:
            return make_string_literal(single)
        array = unwrap_to_array_literal(node.left)
        if array is None:
            return None
        pieces = collect_string_arguments(array)
        return None if pieces is None else make_string_literal(separator.join(pieces))

    def _handle_binary_replace(
        self, node: Ps1BinaryExpression, op: str,
    ) -> Expression | None:
        """
        Fold `-replace`, `-ireplace` and `-creplace` when the input is a
        constant string and the right-hand side is a two-element array of
        constant strings `pattern, replacement`. Only `-creplace` matches
        case-sensitively.

        The replacement is inserted literally. A replacement that contains a
        substitution token (`$1`, `${name}`, `$$`, `$&`, `` $` ``, `$'`, `$+`,
        `$_`) is therefore not folded, because a literal insertion would
        produce a different string than the operator. A `$` that does not
        start such a token is kept as a literal character.

        Returns `None` if the expression cannot be folded or the pattern is
        rejected by Python's `re` module.
        """
        haystack = string_value(node.left) if node.left else None
        if haystack is None or node.right is None:
            return None
        operands = node.right
        if not isinstance(operands, Ps1ArrayLiteral) or len(operands.elements) != 2:
            return None
        needle_str = string_value(operands.elements[0])
        insert_str = string_value(operands.elements[1])
        if needle_str is None or insert_str is None:
            return None
        if re.search(r"\$[\d{$&`'+_]", insert_str):
            return None
        flags = re.IGNORECASE if op != '-creplace' else 0
        try:
            result = re.sub(needle_str, lambda _: insert_str, haystack, flags=flags)
        except re.error:
            return None
        return make_string_literal(result)

    def _handle_binary_split(
        self, node: Ps1BinaryExpression, op: str,
    ) -> Expression | None:
        """
        Fold `-split`, `-isplit` and `-csplit` with a constant pattern. The
        left-hand side is a constant string or an array of constant strings;
        each input is split with `re.split` and all pieces are collected in
        one array literal. Only `-csplit` matches case-sensitively.
        """
        if node.right is None or node.left is None:
            return None
        pattern = string_value(node.right)
        if pattern is None:
            return None
        flags = 0 if op == '-csplit' else re.IGNORECASE
        subject = string_value(node.left)
        if subject is not None:
            inputs = [subject]
        else:
            array = unwrap_to_array_literal(node.left)
            if array is None:
                return None
            inputs = collect_string_arguments(array)
            if inputs is None:
                return None
        try:
            pieces = [
                piece
                for text in inputs
                for piece in re.split(pattern, text, flags=flags)
            ]
        except re.error:
            return None
        elements: list[Expression] = [make_string_literal(p) for p in pieces]
        return Ps1ArrayLiteral(elements=elements)

Classes

class Ps1ConstantFolding

In-place tree rewriter. Each visit method may return a replacement node or None to keep the original. Tracks whether any transformation was applied via the changed flag.

Expand source code Browse git
class Ps1ConstantFolding(LocalFunctionAwareTransformer):

    def visit_Ps1CommandInvocation(self, node: Ps1CommandInvocation):
        """
        Visit the children of a command invocation; the invocation itself is always kept.
        """
        self.generic_visit(node)

    def visit_Ps1Pipeline(self, node: Ps1Pipeline):
        """
        Try to fold a two-element regex pipeline; otherwise visit the pipeline's children.
        """
        folded = self._try_fold_regex_pipeline(node) if len(node.elements) == 2 else None
        if folded is None:
            # Children are only visited when the pipeline was not replaced as a whole.
            self.generic_visit(node)
        return folded

    @staticmethod
    def _fold_regex_call_result(
        invoke: Ps1InvokeMember, member_lower: str,
    ) -> Expression | None:
        """
        Turn the computed result of a static regex `match` / `matches` call into a literal.

        `match` yields a string literal, `matches` an array literal of string literals. Any
        other member name, or a result that could not be computed, yields None.
        """
        if member_lower == 'match':
            single = _compute_regex_match(invoke)
            return None if single is None else make_string_literal(single)
        if member_lower == 'matches':
            found = _compute_regex_matches(invoke)
            if found is None:
                return None
            return Ps1ArrayLiteral(elements=[make_string_literal(text) for text in found])
        return None

    def _try_fold_regex_pipeline(self, node: Ps1Pipeline) -> Expression | None:
        """
        Fold a pipeline whose first element is a static regex call and whose second element
        is a foreach script block accepted by `_foreach_extracts_value`.
        """
        source = node.elements[0].expression
        consumer = node.elements[1].expression
        if not (isinstance(source, Ps1InvokeMember) and _is_static_regex_call(source)):
            return None
        name = get_member_name(source.member)
        if name is None or not consumer:
            return None
        block = extract_foreach_scriptblock(consumer)
        if block is None or not _foreach_extracts_value(block):
            return None
        return self._fold_regex_call_result(source, name.lower())

    def visit_Ps1MemberAccess(self, node: Ps1MemberAccess):
        """
        Fold member accesses on constants.

        Members whose resolved type is in `_INTEGER_RESULT_TYPES` become the length of a
        constant string or the element count of an array literal. Unknown members of constant
        strings and integer literals become `$Null`. Everything else is handed to the regex
        member access folding.
        """
        self.generic_visit(node)
        member = get_member_name(node.member)
        if member is None:
            return None
        target = node.object
        if target is None:
            return None
        member_type = resolve_member_type(target, member)
        text = string_value(target)
        if member_type in _INTEGER_RESULT_TYPES:
            if text is not None:
                size = len(text)
                return Ps1IntegerLiteral(value=size, raw=str(size))
            array = unwrap_to_array_literal(target)
            if array is not None:
                size = len(array.elements)
                return Ps1IntegerLiteral(value=size, raw=str(size))
        is_constant = text is not None or isinstance(target, Ps1IntegerLiteral)
        if is_constant and not is_known_member(target, member):
            return Ps1Variable(name='Null')
        return self._try_fold_regex_member_access(node, member)

    def _try_fold_regex_member_access(
        self, node: Ps1MemberAccess, member: str,
    ) -> Expression | None:
        """
        Fold a property chain ending in `.Value` that is rooted at a static regex call.

        Every property between the call and the final `.Value` must be `Groups` or
        `Captures` (case-insensitive); otherwise nothing is folded.
        """
        # Property names are collected from the outermost access inwards.
        names: list[str] = [member]
        cursor = node.object
        while isinstance(cursor, Ps1MemberAccess):
            name = get_member_name(cursor.member)
            if name is None:
                return None
            names.append(name)
            cursor = cursor.object
        if not (isinstance(cursor, Ps1InvokeMember) and _is_static_regex_call(cursor)):
            return None
        outermost, *intermediate = (name.lower() for name in names)
        if outermost != 'value':
            return None
        if any(name not in ('groups', 'captures') for name in intermediate):
            return None
        call_member = cursor.member
        if not isinstance(call_member, str):
            return None
        return self._fold_regex_call_result(cursor, call_member.lower())

    @staticmethod
    def _try_join_regex_matches(operand: Expression) -> Expression | None:
        """
        Fold the concatenation of all matches of a (possibly parenthesized) static regex
        `Matches` call into a single string literal.
        """
        call = unwrap_parens(operand)
        if not isinstance(call, Ps1InvokeMember):
            return None
        if not _is_static_regex_call(call):
            return None
        name = call.member
        if not isinstance(name, str) or name.lower() != 'matches':
            return None
        found = _compute_regex_matches(call)
        return None if found is None else make_string_literal(''.join(found))

    def visit_Ps1UnaryExpression(self, node: Ps1UnaryExpression):
        """
        Fold unary `-join`, `-bnot` and logical negation (`-not`, `!`) on constant operands.
        """
        self.generic_visit(node)
        operand = node.operand
        if operand is None:
            return None
        op = node.operator.lower()
        if op == '-join':
            return self._handle_unary_join(node)
        if op == '-bnot':
            number = unwrap_integer(operand)
            if number is None:
                return None
            inverted = ~number.value
            return Ps1IntegerLiteral(value=inverted, raw=str(inverted))
        if op in ('-not', '!'):
            truth = is_truthy(operand)
            if truth is None:
                return None
            # Negation: a truthy operand folds to $False and vice versa.
            return Ps1Variable(name='False' if truth else 'True')
        return None

    def _handle_unary_join(self, node: Ps1UnaryExpression) -> Expression | None:
        """
        Fold unary `-join` applied to a constant string, a static regex `Matches` call, an
        array literal of constant strings, or an array expression wrapping one constant string.
        """
        operand = node.operand
        if operand is None:
            return None
        scalar = string_value(operand)
        if scalar is not None:
            return make_string_literal(scalar)
        from_regex = self._try_join_regex_matches(operand)
        if from_regex is not None:
            return from_regex
        array = unwrap_to_array_literal(operand)
        if array is not None:
            pieces = collect_string_arguments(array)
            return None if pieces is None else make_string_literal(''.join(pieces))
        # Remaining case: an array expression whose only statement is a constant string.
        if not (isinstance(operand, Ps1ArrayExpression) and len(operand.body) == 1):
            return None
        statement = operand.body[0]
        if not isinstance(statement, Ps1ExpressionStatement) or not statement.expression:
            return None
        inner = string_value(statement.expression)
        return None if inner is None else make_string_literal(inner)

    def visit_Ps1RangeExpression(self, node: Ps1RangeExpression):
        """
        Expand a range with constant integer bounds into an array literal.

        Ranges nested directly inside another range expression are left alone, as are ranges
        with more than `_MAX_RANGES_EXPAND` elements. Descending ranges are supported.
        """
        self.generic_visit(node)
        if isinstance(node.parent, Ps1RangeExpression):
            return None
        lower = unwrap_integer(node.start)
        upper = unwrap_integer(node.end)
        if lower is None or upper is None:
            return None
        first, last = lower.value, upper.value
        step = 1 if last >= first else -1
        if abs(last - first) + 1 > _MAX_RANGES_EXPAND:
            return None
        # Both bounds are inclusive, hence the extra step on the stop value.
        values = range(first, last + step, step)
        return Ps1ArrayLiteral(elements=[
            Ps1IntegerLiteral(value=value, raw=str(value)) for value in values])

    def visit_Ps1IndexExpression(self, node: Ps1IndexExpression):
        """
        Fold index expressions into hash literals, constant strings and array literals.
        """
        self.generic_visit(node)
        container, index = node.object, node.index
        if index is None or container is None:
            return None
        if isinstance(container, Ps1HashLiteral):
            return _lookup_hashtable(container, index)
        positions = _resolve_index_values(index)
        if positions is None:
            return None
        text = string_value(container)
        if text is not None:
            return _index_into_string(text, positions)
        array = unwrap_to_array_literal(container)
        return None if array is None else _index_into_array(array, positions)

    def visit_Ps1ExpressionStatement(self, node: Ps1ExpressionStatement):
        """
        Apply a statement recognised by `is_array_reverse_call` to the preceding constant
        assignment of the same variable; returns the (emptied) statement on success.
        """
        self.generic_visit(node)
        reversed_var = is_array_reverse_call(node)
        if reversed_var is None:
            return None
        return node if self._try_apply_array_reverse(node, reversed_var) else None

    def _try_apply_array_reverse(
        self, node: Ps1ExpressionStatement, var: Ps1Variable,
    ) -> bool:
        """
        Reverse the constant most recently assigned to `var` before `node` in the same body.

        Scans backwards for the closest plain `=` assignment to the variable. If its value is
        an array literal (bare or wrapped in a single-statement array expression) the elements
        are reversed in place; if it is a constant string it is replaced by the reversed
        string. On success the expression of `node` is cleared and the transformer is marked
        as changed. Returns whether the reversal was applied.
        """
        body = get_body(node.parent)
        if body is None:
            return False
        try:
            position = body.index(node)
        except ValueError:
            return False
        wanted = var.name.lower()
        for i in reversed(range(position)):
            candidate = body[i]
            if not isinstance(candidate, Ps1ExpressionStatement):
                continue
            expr = candidate.expression
            if not isinstance(expr, Ps1AssignmentExpression) or expr.operator != '=':
                continue
            target = expr.target
            if not isinstance(target, Ps1Variable) or target.name.lower() != wanted:
                continue
            # Only the closest assignment counts; if it cannot be reversed, give up.
            value = expr.value
            if isinstance(value, Ps1ArrayLiteral):
                value.elements.reverse()
            elif (
                isinstance(value, Ps1ArrayExpression)
                and len(value.body) == 1
                and isinstance(value.body[0], Ps1ExpressionStatement)
                and isinstance(value.body[0].expression, Ps1ArrayLiteral)
            ):
                value.body[0].expression.elements.reverse()
            else:
                text = string_value(value)
                if text is None:
                    return False
                replacement = make_string_literal(text[::-1])
                replacement.parent = expr
                expr.value = replacement
            node.expression = None
            self.mark_changed()
            return True
        return False

    def visit_Ps1InvokeMember(self, node: Ps1InvokeMember):
        """
        Fold a method invocation by trying, in order, the `.Invoke` redirect, constant string
        instance methods and known static methods.
        """
        self.generic_visit(node)
        member_name = get_member_name(node.member)
        if member_name is None:
            return None
        lower = member_name.lower()
        strategies = (
            self._try_fold_invoke_redirect,
            self._try_fold_instance_method,
            self._try_fold_static_method,
        )
        for strategy in strategies:
            folded = strategy(node, lower)
            if folded:
                return folded
        return None

    @staticmethod
    def _try_fold_invoke_redirect(
        node: Ps1InvokeMember, lower: str,
    ) -> Expression | None:
        """
        Rewrite `obj.Member.Invoke(args)` into the direct call `obj.Member(args)`.
        """
        accessor = node.object
        if lower != 'invoke' or not isinstance(accessor, Ps1MemberAccess):
            return None
        return Ps1InvokeMember(
            offset=node.offset,
            object=accessor.object,
            member=accessor.member,
            arguments=node.arguments,
            access=accessor.access,
        )

    @staticmethod
    def _try_fold_instance_method(
        node: Ps1InvokeMember, lower: str,
    ) -> Expression | None:
        """
        Evaluate a string method called on a constant string with constant arguments.

        Arguments must be constant strings or integer literals. The result of
        `apply_string_method` is converted to the matching literal node; a
        `StringMethodError` or an unsupported result type yields None.
        """
        if not node.object:
            return None
        receiver = string_value(node.object)
        if receiver is None:
            return None
        coerced: list[str | int] = []
        for arg in node.arguments:
            text = string_value(arg)
            if text is not None:
                coerced.append(text)
            elif isinstance(arg, Ps1IntegerLiteral):
                coerced.append(arg.value)
            else:
                return None
        try:
            outcome = apply_string_method(receiver, lower, coerced)
        except StringMethodError:
            return None
        if isinstance(outcome, str):
            return make_string_literal(outcome)
        # bool is a subclass of int and therefore has to be tested first.
        if isinstance(outcome, bool):
            return Ps1Variable(name='True' if outcome else 'False')
        if isinstance(outcome, int):
            return Ps1IntegerLiteral(value=outcome, raw=str(outcome))
        if isinstance(outcome, list):
            return Ps1ArrayLiteral(elements=[make_string_literal(p) for p in outcome])
        return None

    def _try_fold_static_method(
        self, node: Ps1InvokeMember, lower: str,
    ) -> Expression | None:
        """
        Fold calls to known static .NET methods with constant arguments.

        Handles `System.Convert`, encoding chains applied to constant byte values,
        `String.Concat` / `String.Join`, static regex `Replace`,
        `BitConverter.ToString` and `Environment.GetEnvironmentVariable` for names present
        in `PS1_ENV_CONSTANTS`. Returns None when nothing could be folded.
        """
        def constant_strings(nodes) -> list[str] | None:
            # All string values in order, or None as soon as one item is not constant.
            texts: list[str] = []
            for item in nodes:
                text = string_value(item)
                if text is None:
                    return None
                texts.append(text)
            return texts

        if is_static_type_call(node, 'system.convert'):
            return self._try_fold_convert(node, lower)
        arguments = node.arguments
        encoding_name = detect_encoding_chain(node)
        if encoding_name is not None and len(arguments) == 1:
            payload = unwrap_single_paren(arguments[0])
            if isinstance(payload, Ps1ArrayExpression) and len(payload.body) == 1:
                only = payload.body[0]
                if isinstance(only, Ps1ExpressionStatement) and only.expression:
                    payload = only.expression
            int_values = collect_int_arguments(payload)
            if int_values is not None:
                try:
                    raw_bytes = bytearray(int_values)
                except (ValueError, OverflowError):
                    return None
                encoding = ENCODING_MAP.get(encoding_name.lower(), encoding_name)
                try:
                    codecs.lookup(encoding)
                except LookupError:
                    # Codecs unknown to Python are decoded as UTF-8 instead.
                    encoding = 'utf-8'
                try:
                    decoded_str = raw_bytes.decode(encoding)
                except Exception:
                    return None
                return make_string_literal(decoded_str)
        if is_static_type_call(node, 'system.string'):
            if lower == 'concat' and len(arguments) >= 1:
                parts = constant_strings(arguments)
                if parts is not None:
                    return make_string_literal(''.join(parts))
            if lower == 'join' and len(arguments) >= 2:
                separator = string_value(arguments[0])
                if separator is not None:
                    joined = constant_strings(arguments[1:])
                    if joined is not None:
                        return make_string_literal(separator.join(joined))
                    if len(arguments) == 2:
                        # A single non-string argument may still be an array of strings.
                        array = unwrap_to_array_literal(arguments[1])
                        if array is not None:
                            items = collect_string_arguments(array)
                            if items is not None:
                                return make_string_literal(separator.join(items))
        if _is_static_regex_call(node) and lower == 'replace':
            return self._handle_regex_replace(node)
        if is_static_type_call(node, 'system.bitconverter') and lower == 'tostring':
            return self._try_fold_bitconverter_tostring(node)
        if (
            is_static_type_call(node, 'system.environment')
            and lower == 'getenvironmentvariable'
            and len(arguments) == 1
        ):
            variable = string_value(arguments[0])
            if variable is not None:
                known = PS1_ENV_CONSTANTS.get(variable.lower())
                if known is not None:
                    return make_string_literal(known)
        return None

    # Inclusive (minimum, maximum) value range for each integer conversion method that
    # _try_fold_convert dispatches to _fold_convert_int, keyed by the lower-cased method name.
    _CONVERT_INT_METHODS = {
        'tobyte'  : (0, 0xFF),
        'toint16' : (-0x8000, 0x7FFF),
        'toint32' : (-0x80000000, 0x7FFFFFFF),
        'toint64' : (-0x8000000000000000, 0x7FFFFFFFFFFFFFFF),
        'tosbyte' : (-0x80, 0x7F),
        'touint16': (0, 0xFFFF),
        'touint32': (0, 0xFFFFFFFF),
        'touint64': (0, 0xFFFFFFFFFFFFFFFF),
    }

    def _try_fold_convert(
        self, node: Ps1InvokeMember, lower: str,
    ) -> Expression | None:
        """
        Fold static `System.Convert` calls with constant arguments.

        `FromBase64String` becomes an array expression of byte literals, the integer
        conversions listed in `_CONVERT_INT_METHODS` become integer literals, and `ToChar`
        becomes a one-character string literal. Returns None when nothing could be folded.
        """
        arguments = node.arguments
        if lower == 'frombase64string' and len(arguments) == 1:
            b64_str = string_value(arguments[0])
            if b64_str is not None:
                try:
                    decoded = base64.b64decode(b64_str)
                except Exception:
                    return None
                byte_literals: list[Expression] = [
                    Ps1IntegerLiteral(value=b, raw=F'0x{b:02X}') for b in decoded
                ]
                statement = Ps1ExpressionStatement(
                    expression=Ps1ArrayLiteral(elements=byte_literals))
                return Ps1ArrayExpression(body=[statement])
        bounds = self._CONVERT_INT_METHODS.get(lower)
        if bounds is not None:
            return self._fold_convert_int(node, bounds)
        if lower != 'tochar' or len(arguments) != 1:
            return None
        code_point = unwrap_integer(arguments[0])
        if code_point is None:
            return None
        try:
            return make_string_literal(chr(code_point.value))
        except (ValueError, OverflowError):
            return None

    def _fold_convert_int(
        self, node: Ps1InvokeMember, bounds: tuple[int, int],
    ) -> Expression | None:
        """
        Fold an integer conversion whose result lies within the inclusive `bounds`.

        With one argument, an integer literal or a string parsed with `int(text, 0)` is
        accepted. With two arguments, the first must be a constant string and the second a
        base of 2, 8, 10 or 16. Out-of-range or unparsable values yield None.
        """
        lo, hi = bounds

        def bounded(value: int) -> Expression | None:
            # Integer literal for value when it fits the target type, otherwise None.
            if lo <= value <= hi:
                return Ps1IntegerLiteral(value=value, raw=str(value))
            return None

        argc = len(node.arguments)
        if argc == 1:
            operand = node.arguments[0]
            number = unwrap_integer(operand)
            if number is not None:
                literal = bounded(number.value)
                if literal is not None:
                    return literal
            text = string_value(operand)
            if text is None:
                return None
            try:
                parsed = int(text.strip(), 0)
            except (ValueError, OverflowError):
                return None
            return bounded(parsed)
        if argc == 2:
            text = string_value(node.arguments[0])
            radix = unwrap_integer(node.arguments[1])
            if text is None or radix is None or radix.value not in (2, 8, 10, 16):
                return None
            try:
                parsed = int(text, radix.value)
            except (ValueError, OverflowError):
                return None
            return bounded(parsed)
        return None

    @staticmethod
    def _try_fold_bitconverter_tostring(node: Ps1InvokeMember) -> Expression | None:
        """
        Fold `[BitConverter]::ToString(bytes[, startIndex[, length]])` on constant bytes.

        The result is the dash-separated upper-case hex representation of the selected bytes.
        When only a start index is given, all bytes from that index to the end are converted.
        Returns None when an argument is not constant or the requested range is one that .NET
        would reject.
        """
        if not node.arguments:
            return None
        data = collect_byte_array(node.arguments[0])
        if data is None:
            return None
        total = len(data)
        offset = 0
        if len(node.arguments) >= 2:
            n = unwrap_integer(node.arguments[1])
            if n is None:
                return None
            offset = n.value
        if len(node.arguments) >= 3:
            n = unwrap_integer(node.arguments[2])
            if n is None:
                return None
            length = n.value
        else:
            # Without an explicit length the conversion runs to the end of the array.
            length = total - offset
        if offset < 0 or length < 0 or offset + length > total:
            return None
        # .NET rejects a non-zero start index that points at or past the end of the array.
        if offset > 0 and offset >= total:
            return None
        segment = data[offset:offset + length]
        return make_string_literal('-'.join(F'{b:02X}' for b in segment))

    def _handle_regex_replace(self, node: Ps1InvokeMember) -> Expression | None:
        """
        Fold `[regex]::Replace(input, pattern, replacement[, options])` on constant strings.

        Returns the resulting string literal, or None when an argument is not constant, the
        options cannot be parsed, the pattern does not compile, or the call cannot be
        reproduced faithfully (right-to-left matching, substitution tokens in the replacement).
        """
        if len(node.arguments) not in (3, 4):
            return None
        input_str = string_value(node.arguments[0])
        pattern_str = string_value(node.arguments[1])
        replacement_str = string_value(node.arguments[2])
        if input_str is None or pattern_str is None or replacement_str is None:
            return None
        flags = 0
        if len(node.arguments) == 4:
            opts = _parse_regex_options(node.arguments[3])
            if opts is None:
                return None
            flags, right_to_left = opts
            if right_to_left:
                # Right-to-left matching cannot be emulated by reversing the input: the
                # pattern would run against reversed text and the inserted replacement would
                # come out reversed as well. Do not fold rather than emit a wrong constant.
                return None
        # The replacement is inserted verbatim below, but .NET would expand substitution
        # tokens such as $1, ${name}, $$, $&, $`, $', $+ and $_.
        if re.search(r"\$[\d{$&`'+_]", replacement_str):
            return None
        try:
            # A callable replacement keeps Python from interpreting backslash escapes.
            result = re.sub(pattern_str, lambda _: replacement_str, input_str, flags=flags)
        except re.error:
            return None
        return make_string_literal(result)

    # Binary operators that _handle_arithmetic can evaluate on two integer literals, keyed by
    # the lower-cased operator token and mapped to the int method computing the result.
    _ARITHMETIC_OPS = {
        '+'     : int.__add__,
        '-'     : int.__sub__,
        '*'     : int.__mul__,
        '/'     : int.__floordiv__,
        '%'     : int.__mod__,
        '-band' : int.__and__,
        '-bor'  : int.__or__,
        '-bxor' : int.__xor__,
        '-shl'  : int.__lshift__,
        '-shr'  : int.__rshift__,
    }

    def visit_Ps1BinaryExpression(self, node: Ps1BinaryExpression):
        """
        Dispatch a binary expression to the folding handler for its (lower-cased) operator.

        `+`, `*` and all remaining operators fall back to integer arithmetic when their
        primary handler does not produce a result.
        """
        self.generic_visit(node)
        op = node.operator.lower()
        if op == '-f':
            return self._handle_format(node)
        if op == '-join':
            return self._handle_binary_join(node)
        if op in ('-replace', '-creplace', '-ireplace'):
            return self._handle_binary_replace(node, op)
        if op in ('-split', '-csplit', '-isplit'):
            return self._handle_binary_split(node, op)
        if op == '+':
            primary = self._handle_concat(node)
        elif op == '*':
            primary = self._handle_string_multiply(node)
        else:
            primary = self._handle_comparison(node, op)
        return primary or self._handle_arithmetic(node, op)

    def _handle_arithmetic(self, node: Ps1BinaryExpression, op: str) -> Expression | None:
        """
        Fold an arithmetic or bitwise operator applied to two integer literals.

        `op` is the lower-cased operator token. Division is folded only when it is exact,
        because PowerShell produces a floating point result otherwise, and the remainder takes
        the sign of the dividend as it does in .NET. Returns None for unknown operators,
        non-integer operands, division by zero and other failed evaluations.
        """
        left = unwrap_integer(node.left)
        right = unwrap_integer(node.right)
        if left is None or right is None:
            return None
        a, b = left.value, right.value
        if op in ('/', '%'):
            if op not in self._ARITHMETIC_OPS or b == 0:
                return None
            # Work on magnitudes: Python floors towards negative infinity, whereas
            # PowerShell truncates towards zero.
            quotient, remainder = divmod(abs(a), abs(b))
            if op == '/':
                if remainder:
                    # Inexact integer division yields a double in PowerShell.
                    return None
                result = quotient if (a < 0) == (b < 0) else -quotient
            else:
                result = -remainder if a < 0 else remainder
        else:
            fn = self._ARITHMETIC_OPS.get(op)
            if fn is None:
                return None
            try:
                result = fn(a, b)
            except (ZeroDivisionError, ValueError, OverflowError):
                return None
        return Ps1IntegerLiteral(value=result, raw=str(result))

    @staticmethod
    def _handle_string_multiply(node: Ps1BinaryExpression) -> Expression | None:
        """
        Fold string repetition, i.e. a constant string multiplied by an integer literal.

        PowerShell selects the operation from the left operand, so only `'text' * n` is
        string repetition; `n * 'text'` converts the string to a number instead and is
        therefore not handled here. Negative counts are treated as zero, and results longer
        than `_MAX_STRING_EXPAND` are not folded.
        """
        s = string_value(node.left) if node.left else None
        n = unwrap_integer(node.right)
        if s is None or n is None:
            return None
        count = max(n.value, 0)
        if len(s) * count > _MAX_STRING_EXPAND:
            return None
        return make_string_literal(s * count)

    def _handle_comparison(self, node: Ps1BinaryExpression, op: str) -> Expression | None:
        """
        Fold a comparison of two integer literals into `$True` or `$False`.

        Returns None when an operand is not an integer literal or `op` is not a key of
        `COMPARISON_OPS`.
        """
        lhs = unwrap_integer(node.left)
        rhs = unwrap_integer(node.right)
        if lhs is None or rhs is None:
            return None
        predicate = COMPARISON_OPS.get(op)
        if predicate is None:
            return None
        name = 'True' if predicate(lhs.value, rhs.value) else 'False'
        return Ps1Variable(name=name)

    def _handle_format(self, node: Ps1BinaryExpression) -> Expression | None:
        """
        Fold the `-f` format operator when the format string and all arguments are constant.
        """
        if not node.left or node.right is None:
            return None
        template = string_value(node.left)
        if template is None:
            return None
        values = collect_format_arguments(node.right)
        if values is None:
            return None
        formatted = apply_format_string(template, values)
        return None if formatted is None else make_string_literal(formatted)

    def _handle_concat(self, node: Ps1BinaryExpression) -> Expression | None:
        """
        Fold `+` when it concatenates constant strings.

        Handles, in order: two constant strings; `(x + 'a') + 'b'`, where the two trailing
        strings are merged inside the left expression; an array literal plus a string, which
        appends an element; and a variable next to a constant string, which becomes an
        expandable string unless this node is the left operand of an enclosing `+`.
        """
        left, right = node.left, node.right
        left_str = string_value(left) if left else None
        right_str = string_value(right) if right else None
        if right_str is not None:
            if left_str is not None:
                return make_string_literal(left_str + right_str)
            if isinstance(left, Ps1BinaryExpression) and left.operator == '+':
                tail = string_value(left.right) if left.right else None
                if tail is not None:
                    # Merge into the inner expression and let it take this node's place.
                    merged = make_string_literal(tail + right_str)
                    merged.parent = left
                    left.right = merged
                    return left
            if isinstance(left, Ps1ArrayLiteral):
                return Ps1ArrayLiteral(
                    elements=[*left.elements, make_string_literal(right_str)])
        parent = node.parent
        is_inner_concat = (
            isinstance(parent, Ps1BinaryExpression)
            and parent.operator == '+'
            and parent.left is node
        )
        if is_inner_concat:
            return None
        if isinstance(left, Ps1Variable) and right_str is not None:
            return _variable_string_to_expandable(left, right_str, var_first=True)
        if isinstance(right, Ps1Variable) and left_str is not None:
            return _variable_string_to_expandable(right, left_str, var_first=False)
        return None

    def _handle_binary_join(self, node: Ps1BinaryExpression) -> Expression | None:
        """
        Fold binary `-join` with a constant separator applied to a constant string or an
        array literal of constant strings.
        """
        if not node.right:
            return None
        separator = string_value(node.right)
        if separator is None or node.left is None:
            return None
        # Joining a scalar string leaves it unchanged.
        scalar = string_value(node.left)
        if scalar is not None:
            return make_string_literal(scalar)
        array = unwrap_to_array_literal(node.left)
        if array is None:
            return None
        pieces = collect_string_arguments(array)
        return None if pieces is None else make_string_literal(separator.join(pieces))

    def _handle_binary_replace(
        self, node: Ps1BinaryExpression, op: str,
    ) -> Expression | None:
        """
        Fold `-replace`, `-ireplace` and `-creplace` when all operands are constant strings.

        The right operand is either a two-element array literal `pattern, replacement` or a
        single constant pattern; in the latter case every match is removed. Matching is case
        sensitive only for `-creplace`. Returns the folded string literal, or None when the
        expression cannot be folded safely.
        """
        haystack = string_value(node.left) if node.left else None
        if haystack is None or node.right is None:
            return None
        operand = node.right
        if isinstance(operand, Ps1ArrayLiteral):
            if len(operand.elements) != 2:
                return None
            needle_str = string_value(operand.elements[0])
            insert_str = string_value(operand.elements[1])
        else:
            # PowerShell allows the replacement to be omitted, which deletes every match.
            needle_str = string_value(operand)
            insert_str = ''
        if needle_str is None or insert_str is None:
            return None
        # The replacement is inserted verbatim below, but .NET would expand substitution
        # tokens such as $1, ${name}, $$, $&, $`, $', $+ and $_. Folding those literally
        # would produce a wrong constant, so leave such expressions untouched.
        if re.search(r"\$[\d{$&`'+_]", insert_str):
            return None
        flags = re.IGNORECASE if op != '-creplace' else 0
        try:
            # A callable replacement keeps Python from interpreting backslash escapes.
            result = re.sub(needle_str, lambda _: insert_str, haystack, flags=flags)
        except re.error:
            return None
        return make_string_literal(result)

    def _handle_binary_split(
        self, node: Ps1BinaryExpression, op: str,
    ) -> Expression | None:
        """
        Fold `-split`, `-isplit` and `-csplit` with a constant pattern.

        The left operand may be a constant string or an array literal of constant strings; the
        pieces of all inputs are concatenated into one array literal. Matching is case sensitive
        only for `-csplit`. Returns None when an operand is not constant or the pattern does
        not compile.
        """
        subject, delimiter = node.left, node.right
        if delimiter is None or subject is None:
            return None
        pattern_str = string_value(delimiter)
        if pattern_str is None:
            return None
        re_flags = 0 if op == '-csplit' else re.IGNORECASE
        scalar = string_value(subject)
        if scalar is None:
            array = unwrap_to_array_literal(subject)
            if array is None:
                return None
            sources = collect_string_arguments(array)
            if sources is None:
                return None
        else:
            sources = [scalar]
        pieces: list[str] = []
        try:
            for source in sources:
                pieces += re.split(pattern_str, source, flags=re_flags)
        except re.error:
            return None
        return Ps1ArrayLiteral(elements=[make_string_literal(piece) for piece in pieces])

Ancestors

Methods

def visit_Ps1CommandInvocation(self, node)
Expand source code Browse git
def visit_Ps1CommandInvocation(self, node: Ps1CommandInvocation):
    """
    Visit the children of a command invocation; the invocation itself is always kept.
    """
    self.generic_visit(node)
def visit_Ps1Pipeline(self, node)
Expand source code Browse git
def visit_Ps1Pipeline(self, node: Ps1Pipeline):
    """
    Try to fold a two-element regex pipeline; otherwise visit the pipeline's children.
    """
    folded = self._try_fold_regex_pipeline(node) if len(node.elements) == 2 else None
    if folded is None:
        # Children are only visited when the pipeline was not replaced as a whole.
        self.generic_visit(node)
    return folded
def visit_Ps1MemberAccess(self, node)
Expand source code Browse git
def visit_Ps1MemberAccess(self, node: Ps1MemberAccess):
    """
    Fold member accesses on constants.

    Members whose resolved type is in `_INTEGER_RESULT_TYPES` become the length of a
    constant string or the element count of an array literal. Unknown members of constant
    strings and integer literals become `$Null`. Everything else is handed to the regex
    member access folding.
    """
    self.generic_visit(node)
    member = get_member_name(node.member)
    if member is None:
        return None
    target = node.object
    if target is None:
        return None
    member_type = resolve_member_type(target, member)
    text = string_value(target)
    if member_type in _INTEGER_RESULT_TYPES:
        if text is not None:
            size = len(text)
            return Ps1IntegerLiteral(value=size, raw=str(size))
        array = unwrap_to_array_literal(target)
        if array is not None:
            size = len(array.elements)
            return Ps1IntegerLiteral(value=size, raw=str(size))
    is_constant = text is not None or isinstance(target, Ps1IntegerLiteral)
    if is_constant and not is_known_member(target, member):
        return Ps1Variable(name='Null')
    return self._try_fold_regex_member_access(node, member)
def visit_Ps1UnaryExpression(self, node)
Expand source code Browse git
def visit_Ps1UnaryExpression(self, node: Ps1UnaryExpression):
    """
    Fold unary `-join`, `-bnot` and logical negation (`-not`, `!`) on constant operands.
    """
    self.generic_visit(node)
    operand = node.operand
    if operand is None:
        return None
    op = node.operator.lower()
    if op == '-join':
        return self._handle_unary_join(node)
    if op == '-bnot':
        number = unwrap_integer(operand)
        if number is None:
            return None
        inverted = ~number.value
        return Ps1IntegerLiteral(value=inverted, raw=str(inverted))
    if op in ('-not', '!'):
        truth = is_truthy(operand)
        if truth is None:
            return None
        # Negation: a truthy operand folds to $False and vice versa.
        return Ps1Variable(name='False' if truth else 'True')
    return None
def visit_Ps1RangeExpression(self, node)
Expand source code Browse git
def visit_Ps1RangeExpression(self, node: Ps1RangeExpression):
    """
    Expand a range with constant integer bounds into an array literal.

    Ranges nested directly inside another range expression are left alone, as are ranges
    with more than `_MAX_RANGES_EXPAND` elements. Descending ranges are supported.
    """
    self.generic_visit(node)
    if isinstance(node.parent, Ps1RangeExpression):
        return None
    lower = unwrap_integer(node.start)
    upper = unwrap_integer(node.end)
    if lower is None or upper is None:
        return None
    first, last = lower.value, upper.value
    step = 1 if last >= first else -1
    if abs(last - first) + 1 > _MAX_RANGES_EXPAND:
        return None
    # Both bounds are inclusive, hence the extra step on the stop value.
    values = range(first, last + step, step)
    return Ps1ArrayLiteral(elements=[
        Ps1IntegerLiteral(value=value, raw=str(value)) for value in values])
def visit_Ps1IndexExpression(self, node)
Expand source code Browse git
def visit_Ps1IndexExpression(self, node: Ps1IndexExpression):
    """
    Fold index expressions into hash literals, constant strings and array literals.
    """
    self.generic_visit(node)
    container, index = node.object, node.index
    if index is None or container is None:
        return None
    if isinstance(container, Ps1HashLiteral):
        return _lookup_hashtable(container, index)
    positions = _resolve_index_values(index)
    if positions is None:
        return None
    text = string_value(container)
    if text is not None:
        return _index_into_string(text, positions)
    array = unwrap_to_array_literal(container)
    return None if array is None else _index_into_array(array, positions)
def visit_Ps1ExpressionStatement(self, node)
Expand source code Browse git
def visit_Ps1ExpressionStatement(self, node: Ps1ExpressionStatement):
    """
    Apply a statement recognised by `is_array_reverse_call` to the preceding constant
    assignment of the same variable; returns the (emptied) statement on success.
    """
    self.generic_visit(node)
    reversed_var = is_array_reverse_call(node)
    if reversed_var is None:
        return None
    return node if self._try_apply_array_reverse(node, reversed_var) else None
def visit_Ps1InvokeMember(self, node)
Expand source code Browse git
def visit_Ps1InvokeMember(self, node: Ps1InvokeMember):
    """
    Fold a method invocation by trying, in order, the `.Invoke` redirect, constant string
    instance methods and known static methods.
    """
    self.generic_visit(node)
    member_name = get_member_name(node.member)
    if member_name is None:
        return None
    lower = member_name.lower()
    strategies = (
        self._try_fold_invoke_redirect,
        self._try_fold_instance_method,
        self._try_fold_static_method,
    )
    for strategy in strategies:
        folded = strategy(node, lower)
        if folded:
            return folded
    return None
def visit_Ps1BinaryExpression(self, node)
Expand source code Browse git
def visit_Ps1BinaryExpression(self, node: Ps1BinaryExpression):
    """
    Dispatch a binary expression to the folding handler for its (lower-cased) operator.

    `+`, `*` and all remaining operators fall back to integer arithmetic when their
    primary handler does not produce a result.
    """
    self.generic_visit(node)
    op = node.operator.lower()
    if op == '-f':
        return self._handle_format(node)
    if op == '-join':
        return self._handle_binary_join(node)
    if op in ('-replace', '-creplace', '-ireplace'):
        return self._handle_binary_replace(node, op)
    if op in ('-split', '-csplit', '-isplit'):
        return self._handle_binary_split(node, op)
    if op == '+':
        primary = self._handle_concat(node)
    elif op == '*':
        primary = self._handle_string_multiply(node)
    else:
        primary = self._handle_comparison(node, op)
    return primary or self._handle_arithmetic(node, op)