Module refinery.lib.scripts

Minimal unified AST base for script parsers. Provides abstract node types shared across language-specific parsers.

Expand source code Browse git
"""
Minimal unified AST base for script parsers. Provides abstract node types shared
across language-specific parsers.
"""
from __future__ import annotations

import dataclasses
import enum
import sys
import typing

from dataclasses import dataclass, field
from typing import Generator, Callable


class Kind(enum.IntEnum):
    """
    Describes how a dataclass field of a node type holds child nodes; the
    classification is computed by `_classify_fields` and consumed by
    `Transformer.generic_visit`.
    """
    # field annotated with a node type (possibly as one member of a union)
    ChildNode = enum.auto()
    # field annotated as a list whose element type is a node type
    ChildList = enum.auto()
    # field annotated as a list of tuples where some tuple member is a node type
    TupleList = enum.auto()


_SKIP_FIELDS = frozenset(('offset', 'parent', 'leading_comments'))

_child_fields_cache: dict[type, list[tuple[str, Kind]]] = {}


def _has_node_type(hint) -> bool:
    """
    Return whether the resolved type hint `hint` refers to `Node` (or a
    subclass of it), either directly or as one member of a union. Both PEP 604
    unions (`X | None`) and `typing.Union` / `typing.Optional` spellings are
    recognized. Parameterized generics such as `list[Node]` are not matched
    here; the caller unpacks container hints itself.
    """
    origin = typing.get_origin(hint)
    if origin is None:
        # Only plain classes may be handed to issubclass. Checking the origin
        # first keeps generic aliases such as list[Node] away from issubclass:
        # on some Python versions they pass isinstance(..., type) and would
        # make issubclass raise TypeError.
        return isinstance(hint, type) and issubclass(hint, Node)
    if origin is typing.Union or origin is type(int | str):
        # typing.Union covers Optional[...] / Union[...]; type(int | str) is
        # types.UnionType, produced by the X | Y syntax.
        return any(_has_node_type(arg) for arg in typing.get_args(hint))
    return False


def _resolve_hints(node_type: type) -> dict[str, typing.Any] | None:
    """
    Resolve the type hints of `node_type`, including those inherited from its
    bases. Returns None when the annotations cannot be evaluated.
    """
    try:
        # Without an explicit namespace, get_type_hints evaluates each class's
        # annotations in the module that defines that class. This matters for
        # fields inherited from bases that live in a different module than
        # node_type (e.g. Node's own `parent: Node | None`).
        return typing.get_type_hints(node_type)
    except Exception:
        pass
    mod = sys.modules.get(node_type.__module__)
    globalns = vars(mod) if mod is not None else {}
    try:
        # Fallback: evaluate every annotation, inherited ones included, in the
        # namespace of the module that defines node_type itself.
        return typing.get_type_hints(node_type, globalns=globalns)
    except Exception:
        return None


def _classify_fields(node_type: type) -> list[tuple[str, Kind]]:
    """
    Determine which dataclass fields of `node_type` can hold child nodes, and
    how (see `Kind`):

    - `list[tuple[...]]` where some tuple member is a node type -> TupleList
    - `list[X]` where X is a node type (or a union containing one) -> ChildList
    - a node type, or a union containing one -> ChildNode

    Fields named in `_SKIP_FIELDS` are ignored. Results are memoized per type
    in `_child_fields_cache`. If the annotations cannot be resolved at all, an
    empty list is returned and cached, i.e. the type is treated as having no
    child fields.
    """
    try:
        return _child_fields_cache[node_type]
    except KeyError:
        pass
    result: list[tuple[str, Kind]] = []
    hints = _resolve_hints(node_type)
    if hints is None:
        _child_fields_cache[node_type] = result
        return result
    for f in dataclasses.fields(node_type):
        if f.name in _SKIP_FIELDS:
            continue
        hint = hints.get(f.name)
        if hint is None:
            continue
        if typing.get_origin(hint) is list:
            args = typing.get_args(hint)
            if not args:
                # bare `typing.List` carries no element type to inspect
                continue
            inner = args[0]
            if typing.get_origin(inner) is tuple:
                if any(_has_node_type(a) for a in typing.get_args(inner)):
                    result.append((f.name, Kind.TupleList))
            elif _has_node_type(inner):
                result.append((f.name, Kind.ChildList))
        elif _has_node_type(hint):
            result.append((f.name, Kind.ChildNode))
    _child_fields_cache[node_type] = result
    return result


@dataclass(repr=False)
class Node:
    """
    Base class for all AST nodes. Equality compares `offset` and any fields
    added by subclasses; `parent` and `leading_comments` are excluded from
    comparison.
    """
    # position of the node; -1 when not set (presumably a source offset)
    offset: int = -1
    # enclosing node, assigned by `_adopt`; None for a node without a parent
    parent: Node | None = field(default=None, compare=False)
    # comment strings associated with this node
    leading_comments: list[str] = field(default_factory=list, compare=False)

    def children(self) -> Generator[Node, None, None]:
        """
        Yield the direct child nodes. The base class has none; subclasses that
        hold children override this.
        """
        yield from ()

    def walk(self) -> Generator[Node, None, None]:
        """
        Yield this node and every descendant, depth-first in pre-order. A LIFO
        stack is used, so the children of a node are yielded last-to-first.
        """
        pending: list[Node] = [self]
        while pending:
            current = pending.pop()
            yield current
            pending.extend(current.children())

    def _adopt(self, *nodes: Node | None):
        """
        Set `parent` of every given node to this node; None entries are skipped.
        """
        for child in nodes:
            if child is None:
                continue
            child.parent = self

    def __repr__(self):
        # Prefer the synthesized script text; fall back to a short identifier
        # whenever synthesis fails for any reason.
        try:
            return self.synthesize()
        except Exception:
            return F'{type(self).__name__}@{self.offset}'

    def synthesize(self) -> str:
        """
        Render this node as script text using the PowerShell synthesizer. The
        import is done inside the method, presumably to avoid a circular import.
        """
        from refinery.lib.scripts.ps1.synth import Ps1Synthesizer
        return Ps1Synthesizer().convert(self)


class Expression(Node):
    """
    Common base class for expression nodes. It declares no fields of its own;
    all dataclass fields and methods are inherited from `Node`.
    """


class Statement(Node):
    """
    Common base class for statement nodes. It declares no fields of its own;
    all dataclass fields and methods are inherited from `Node`.
    """


@dataclass(repr=False)
class Block(Node):
    """
    An ordered sequence of statements.
    """
    body: list[Statement] = field(default_factory=list)

    def __post_init__(self):
        # Every statement passed in gets this block as its parent.
        self._adopt(*self.body)

    def children(self) -> Generator[Node, None, None]:
        """
        Yield the statements of the body in order.
        """
        for statement in self.body:
            yield statement


@dataclass(repr=False)
class Script(Node):
    """
    Root node that represents a complete script as a list of statements.
    """
    body: list[Statement] = field(default_factory=list)

    def __post_init__(self):
        # Every top-level statement gets the script as its parent.
        self._adopt(*self.body)

    def children(self) -> Generator[Node, None, None]:
        """
        Yield the top-level statements in order.
        """
        for statement in self.body:
            yield statement


class Visitor:
    """
    Tree walker that dispatches on the exact class name of each node: a node
    of class `Foo` is handled by `visit_Foo` if the subclass defines it, and by
    `generic_visit` otherwise.
    """

    def __init__(self):
        # Per-instance cache mapping a node class to its resolved handler.
        self._dispatch: dict[type[Node], Callable[[Node], Node | None]] = {}

    def visit(self, node: Node) -> Node | None:
        """
        Call the handler for `node` and return whatever it returns.
        """
        node_type = type(node)
        handler = self._dispatch.get(node_type)
        if handler is None:
            handler = getattr(self, F'visit_{node_type.__name__}', self.generic_visit)
            self._dispatch[node_type] = handler
        return handler(node)

    def generic_visit(self, node: Node) -> Node | None:
        """
        Fallback handler: visit each child from `node.children()`; returns None.
        """
        for child in node.children():
            self.visit(child)


class Transformer(Visitor):
    """
    Visitor that rewrites the tree in place. A visit method returns either a
    replacement node, which is stored in the parent in place of the visited
    node, or None to keep the original. The `changed` attribute becomes True
    once any replacement has been made.
    """

    def __init__(self):
        super().__init__()
        self.changed = False

    def mark_changed(self):
        """
        Record that the tree has been modified.
        """
        self.changed = True

    def generic_visit(self, node: Node):
        """
        Visit every child reachable through the annotated dataclass fields of
        `node` (see `_classify_fields`) and store any replacement returned by
        the visit methods, re-parenting it to `node`. `mark_changed` is called
        at most once per modified field. Always returns None.
        """
        def substitute(child):
            # Returns the value to keep in place of `child` and whether it was
            # replaced; values that are not nodes pass through unvisited.
            if not isinstance(child, Node):
                return child, False
            replacement = self.visit(child)
            if replacement is None:
                return child, False
            replacement.parent = node
            return replacement, True

        for field_name, kind in _classify_fields(type(node)):
            current = getattr(node, field_name)
            if kind == Kind.ChildNode:
                value, replaced = substitute(current)
                if replaced:
                    setattr(node, field_name, value)
                    self.mark_changed()
            elif kind == Kind.ChildList:
                outcomes = [substitute(item) for item in current]
                if any(replaced for _, replaced in outcomes):
                    # a fresh list is stored; the original list is not mutated
                    setattr(node, field_name, [value for value, _ in outcomes])
                    self.mark_changed()
            elif kind == Kind.TupleList:
                rebuilt = []
                field_changed = False
                for entry in current:
                    outcomes = [substitute(elem) for elem in entry]
                    if any(replaced for _, replaced in outcomes):
                        rebuilt.append(tuple(value for value, _ in outcomes))
                        field_changed = True
                    else:
                        # untouched tuples are kept by identity
                        rebuilt.append(entry)
                if field_changed:
                    setattr(node, field_name, rebuilt)
                    self.mark_changed()
        return None

Sub-modules

refinery.lib.scripts.bat

Set Statement …

refinery.lib.scripts.guess
refinery.lib.scripts.js
refinery.lib.scripts.pipeline

Dependency-tree-based deobfuscation scheduler …

refinery.lib.scripts.ps1

PowerShell script parser for Binary Refinery.

refinery.lib.scripts.vba

VBA script parser for Binary Refinery.

refinery.lib.scripts.win32const

Default Windows environment variable definitions for script emulation.

Classes

class Kind (*args, **kwds)

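Enum where members are also (and must be) ints. Each member describes how a dataclass field of a node type stores child nodes: as a single node, as a list of nodes, or as a list of tuples that contain nodes.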

Expand source code Browse git
class Kind(enum.IntEnum):
    ChildNode = 1
    ChildList = 2
    TupleList = 3

Ancestors

  • enum.IntEnum
  • builtins.int
  • enum.ReprEnum
  • enum.Enum

Class variables

var ChildNode

A field whose type hint is a node type, or a union containing one; it holds a single child node.

var ChildList

A field annotated as a list whose element type is a node type; it holds a list of child nodes.

var TupleList

A field annotated as a list of tuples in which at least one tuple member is a node type.

class Node (offset=-1, parent=None, leading_comments=<factory>)

Base class for all AST nodes.

Expand source code Browse git
@dataclass(repr=False)
class Node:
    """
    Base class for all AST nodes.
    """
    offset: int = -1
    parent: Node | None = field(default=None, compare=False)
    leading_comments: list[str] = field(default_factory=list, compare=False)

    def children(self) -> Generator[Node, None, None]:
        yield from ()

    def walk(self) -> Generator[Node, None, None]:
        stack: list[Node] = [self]
        while stack:
            node = stack.pop()
            yield node
            for child in node.children():
                stack.append(child)

    def _adopt(self, *nodes: Node | None):
        for node in nodes:
            if node is not None:
                node.parent = self

    def __repr__(self):
        try:
            return self.synthesize()
        except Exception:
            name = type(self).__name__
            return F'{name}@{self.offset}'

    def synthesize(self) -> str:
        from refinery.lib.scripts.ps1.synth import Ps1Synthesizer
        return Ps1Synthesizer().convert(self)

Subclasses

Instance variables

var leading_comments

Comment strings associated with the node. Excluded from equality comparisons.

var offset

Position of the node; defaults to -1 when not set.

var parent

The enclosing node, or None if no parent has adopted this node. Excluded from equality comparisons.

Methods

def children(self)
Expand source code Browse git
def children(self) -> Generator[Node, None, None]:
    yield from ()
def walk(self)
Expand source code Browse git
def walk(self) -> Generator[Node, None, None]:
    stack: list[Node] = [self]
    while stack:
        node = stack.pop()
        yield node
        for child in node.children():
            stack.append(child)
def synthesize(self)
Expand source code Browse git
def synthesize(self) -> str:
    from refinery.lib.scripts.ps1.synth import Ps1Synthesizer
    return Ps1Synthesizer().convert(self)
class Expression (offset=-1, parent=None, leading_comments=<factory>)

Abstract base for all expression nodes.

Expand source code Browse git
class Expression(Node):
    """
    Abstract base for all expression nodes.
    """
    pass

Ancestors

Subclasses

Inherited members

class Statement (offset=-1, parent=None, leading_comments=<factory>)

Abstract base for all statement nodes.

Expand source code Browse git
class Statement(Node):
    """
    Abstract base for all statement nodes.
    """
    pass

Ancestors

Subclasses

Inherited members

class Block (offset=-1, parent=None, leading_comments=<factory>, body=<factory>)

Ordered sequence of statements.

Expand source code Browse git
@dataclass(repr=False)
class Block(Node):
    """
    Ordered sequence of statements.
    """
    body: list[Statement] = field(default_factory=list)

    def __post_init__(self):
        self._adopt(*self.body)

    def children(self) -> Generator[Node, None, None]:
        yield from self.body

Ancestors

Instance variables

var body

The statements of the block, in order. On construction, each statement's parent is set to the block.

Methods

def children(self)
Expand source code Browse git
def children(self) -> Generator[Node, None, None]:
    yield from self.body

Inherited members

class Script (offset=-1, parent=None, leading_comments=<factory>, body=<factory>)

Top-level node representing an entire script.

Expand source code Browse git
@dataclass(repr=False)
class Script(Node):
    """
    Top-level node representing an entire script.
    """
    body: list[Statement] = field(default_factory=list)

    def __post_init__(self):
        self._adopt(*self.body)

    def children(self) -> Generator[Node, None, None]:
        yield from self.body

Ancestors

Subclasses

Instance variables

var body

The top-level statements of the script, in order. On construction, each statement's parent is set to the script.

Methods

def children(self)
Expand source code Browse git
def children(self) -> Generator[Node, None, None]:
    yield from self.body

Inherited members

class Visitor

Dispatch-based tree walker. Subclasses define visit_ClassName methods; unhandled nodes fall through to generic_visit.

Expand source code Browse git
class Visitor:
    """
    Dispatch-based tree walker. Subclasses define visit_ClassName methods;
    unhandled nodes fall through to generic_visit.
    """

    def __init__(self):
        self._dispatch: dict[type[Node], Callable[[Node], Node | None]] = {}

    def visit(self, node: Node) -> Node | None:
        t = type(node)
        try:
            handler = self._dispatch[t]
        except KeyError:
            handler = getattr(self, F'visit_{t.__name__}', self.generic_visit)
            self._dispatch[t] = handler
        return handler(node)

    def generic_visit(self, node: Node) -> Node | None:
        for child in node.children():
            self.visit(child)

Subclasses

Methods

def visit(self, node)
Expand source code Browse git
def visit(self, node: Node) -> Node | None:
    t = type(node)
    try:
        handler = self._dispatch[t]
    except KeyError:
        handler = getattr(self, F'visit_{t.__name__}', self.generic_visit)
        self._dispatch[t] = handler
    return handler(node)
def generic_visit(self, node)
Expand source code Browse git
def generic_visit(self, node: Node) -> Node | None:
    for child in node.children():
        self.visit(child)
class Transformer

In-place tree rewriter. Each visit method may return a replacement node or None to keep the original. Tracks whether any transformation was applied via the changed flag.

Expand source code Browse git
class Transformer(Visitor):
    """
    In-place tree rewriter. Each visit method may return a replacement node
    or None to keep the original. Tracks whether any transformation was applied
    via the `changed` flag.
    """

    def __init__(self):
        super().__init__()
        self.changed = False

    def mark_changed(self):
        self.changed = True

    def generic_visit(self, node: Node):
        for field_name, kind in _classify_fields(type(node)):
            if kind == Kind.ChildNode:
                value = getattr(node, field_name)
                if isinstance(value, Node):
                    replacement = self.visit(value)
                    if replacement is not None:
                        replacement.parent = node
                        setattr(node, field_name, replacement)
                        self.mark_changed()
            elif kind == Kind.ChildList:
                items = getattr(node, field_name)
                new_list = []
                changed = False
                for item in items:
                    if isinstance(item, Node):
                        replacement = self.visit(item)
                        if replacement is not None:
                            replacement.parent = node
                            new_list.append(replacement)
                            changed = True
                        else:
                            new_list.append(item)
                    else:
                        new_list.append(item)
                if changed:
                    setattr(node, field_name, new_list)
                    self.mark_changed()
            elif kind == Kind.TupleList:
                items = getattr(node, field_name)
                new_list = []
                changed = False
                for item in items:
                    new_tuple = []
                    tuple_changed = False
                    for elem in item:
                        if isinstance(elem, Node):
                            replacement = self.visit(elem)
                            if replacement is not None:
                                replacement.parent = node
                                new_tuple.append(replacement)
                                tuple_changed = True
                            else:
                                new_tuple.append(elem)
                        else:
                            new_tuple.append(elem)
                    new_list.append(tuple(new_tuple) if tuple_changed else item)
                    changed = changed or tuple_changed
                if changed:
                    setattr(node, field_name, new_list)
                    self.mark_changed()
        return None

Ancestors

Subclasses

Methods

def mark_changed(self)
Expand source code Browse git
def mark_changed(self):
    self.changed = True
def generic_visit(self, node)
Expand source code Browse git
def generic_visit(self, node: Node):
    for field_name, kind in _classify_fields(type(node)):
        if kind == Kind.ChildNode:
            value = getattr(node, field_name)
            if isinstance(value, Node):
                replacement = self.visit(value)
                if replacement is not None:
                    replacement.parent = node
                    setattr(node, field_name, replacement)
                    self.mark_changed()
        elif kind == Kind.ChildList:
            items = getattr(node, field_name)
            new_list = []
            changed = False
            for item in items:
                if isinstance(item, Node):
                    replacement = self.visit(item)
                    if replacement is not None:
                        replacement.parent = node
                        new_list.append(replacement)
                        changed = True
                    else:
                        new_list.append(item)
                else:
                    new_list.append(item)
            if changed:
                setattr(node, field_name, new_list)
                self.mark_changed()
        elif kind == Kind.TupleList:
            items = getattr(node, field_name)
            new_list = []
            changed = False
            for item in items:
                new_tuple = []
                tuple_changed = False
                for elem in item:
                    if isinstance(elem, Node):
                        replacement = self.visit(elem)
                        if replacement is not None:
                            replacement.parent = node
                            new_tuple.append(replacement)
                            tuple_changed = True
                        else:
                            new_tuple.append(elem)
                    else:
                        new_tuple.append(elem)
                new_list.append(tuple(new_tuple) if tuple_changed else item)
                changed = changed or tuple_changed
            if changed:
                setattr(node, field_name, new_list)
                self.mark_changed()
    return None