Module refinery.lib.batch

Set Statement

There are two kinds of set statement: The quoted and the unquoted set. A quoted set looks like this:

set "name=var" (...)

It is interpreted as follows:

  • Everything between the first and the last quote in the command is extracted.
  • The resulting string is split at the first equals symbol.
  • The LHS is the variable name.
  • The RHS is unescaped once not respecting quotes, then becomes the variable content.

Examples

> set "name="a"^^"b"c
> echo %name%
"a"^"b

> set "name=a"^^"b"c
> echo %name%
a"^"b

Note how the trailing c is always discarded because it occurs after the last quote. The unquoted set looks like this:

set name=var

It is parsed as follows:

  • The entire command is parsed and unescaped respecting quotes as usual.
  • The set expression starts with the first non-whitespace character after the set keyword.
  • This expression is split at the first equals symbol.
  • The LHS is the variable name.
  • The RHS is unescaped once respecting quotes, then becomes the variable content.

Redirection (here, of output) may occur in a set line, basically anywhere:

> set 1>"NUL" "var=val
> echo %var%
val
Expand source code Browse git
"""

## Set Statement

There are two kinds of set statement:
The quoted and the unquoted set.
A quoted set looks like this:

    set "name=var" (...)

It is interpreted as follows:

- Everything between the first and the last quote in the command is extracted.
- The resulting string is split at the first equals symbol.
- The LHS is the variable name.
- The RHS is unescaped once **not** respecting quotes, then becomes the variable content.

Examples:

    > set  "name="a"^^"b"c
    > echo %name%
    "a"^"b

    > set  "name=a"^^"b"c
    > echo %name%
    a"^"b

Note how the trailing c is always discarded because it occurs after the last quote.
The unquoted set looks like this:

    set name=var

It is parsed as follows:

- The entire command is parsed and unescaped respecting quotes as usual.
- The set expression starts with the first non-whitespace character after the set keyword.
- This expression is split at the first equals symbol.
- The LHS is the variable name.
- The RHS is unescaped once respecting quotes, then becomes the variable content.

Redirection (here, of output) may occur in a set line, basically anywhere:

    > set 1>"NUL" "var=val
    > echo %var%
    val
"""
from __future__ import annotations

from .emulator import BatchEmulator
from .lexer import BatchLexer
from .parser import BatchParser
from .state import BatchState

# Public API of the batch emulation package.
__all__ = ['BatchEmulator', 'BatchLexer', 'BatchParser', 'BatchState']

Sub-modules

refinery.lib.batch.const
refinery.lib.batch.emulator
refinery.lib.batch.lexer
refinery.lib.batch.model
refinery.lib.batch.parser
refinery.lib.batch.state
refinery.lib.batch.util

Classes

class BatchEmulator (data, state=None)
Expand source code Browse git
class BatchEmulator:
    """
    Emulate a batch script by walking the statements produced by a `BatchParser`.
    The emulation methods are generators: any command the emulator does not
    handle itself is yielded as a string (via `str(command)`).
    """

    class _register:
        """
        Decorator recording a method as the handler for one AST node type;
        `emulate_statement` dispatches through the resulting `handlers` table.
        """
        handlers: ClassVar[dict[type[AstNode], Callable[[BatchEmulator, AstNode], Generator[str]]]] = {}

        def __init__(self, node_type: type[AstNode]):
            self.node_type = node_type

        def __call__(self, handler):
            self.handlers[self.node_type] = handler
            return handler

    def __init__(self, data: str | buf | BatchParser, state: BatchState | None = None):
        self.stack = []
        self.parser = BatchParser(data, state)

    @property
    def state(self):
        """The `BatchState` shared with the parser."""
        return self.parser.state

    @property
    def environment(self):
        """The active environment variable mapping of the state."""
        return self.state.environment

    @property
    def delayexpand(self):
        """Whether delayed expansion (``!name!``) is currently enabled."""
        return self.state.delayexpand

    def delay_expand(self, block: str | RedirectIO):
        """
        Perform delayed expansion on one command token: every ``!name!`` is
        replaced by the environment value stored under the upper-cased name, or
        by the empty string if there is none. Redirection tokens are returned
        unchanged.
        """
        if isinstance(block, RedirectIO):
            return block

        def expansion(match: re.Match[str]):
            name = match.group(1)
            return self.environment.get(name.upper(), '')

        return re.sub(r'!([^!:\n]*)!', expansion, block)

    def execute_set(self, cmd: EmulatorCommand):
        """
        Execute a SET command: the plain form, the quoted form (see the module
        documentation), or arithmetic assignment via ``/A``. Prompting via ``/P``
        raises NotImplementedError; malformed argument lists raise
        EmulatorException. Assigning empty content deletes the variable.
        """
        if not (args := cmd.args):
            raise EmulatorException('Empty SET instruction')

        arithmetic = False
        quote_mode = False

        if args[0].upper() == '/P':
            raise NotImplementedError('Prompt SET not implemented.')
        elif args[0].upper() == '/A':
            arithmetic = True
        elif len(args) not in (1, 3):
            raise EmulatorException(F'SET instruction with {len(args)} arguments unexpected.')

        if arithmetic:
            integers = {}
            updated = {}
            expressions = ''.join(args[1:])
            for name, value in self.environment.items():
                try:
                    integers[name] = batchint(value)
                except ValueError:
                    pass
            for part in expressions.split(','):
                name, _, expression = part.strip().partition('=')
                # Variable names are case-insensitive and the environment keeps
                # them upper-cased (see the branch below); without normalizing,
                # `set /a x=1` would store a key that can never be read back.
                name = name.strip().upper()
                expression = cautious_eval_or_default(expression, environment=integers)
                if expression is not None:
                    integers[name] = expression
                    updated[name] = str(expression)
            self.environment.update(updated)
        else:
            if (n := len(args)) >= 2 and args[1] == '=':
                name, _, content = args
            elif (assignment := args[-1]).startswith('"'):
                if n != 1:
                    raise EmulatorException('Invalid SET from Lexer.')
                quote_mode = True
                # Quoted form: only the text up to the last quote counts.
                assignment, _, unquoted = assignment[1:].rpartition('"')
                assignment = assignment or unquoted
                name, _, content = assignment.partition('=')
            else:
                name, _, content = ''.join(args).partition('=')
            name = name.upper()
            _, content = uncaret(content, quote_mode)
            if not content:
                self.environment.pop(name, None)
            else:
                self.environment[name] = content

    def execute_command(self, ast_command: AstCommand):
        """
        Execute a single command. SET, GOTO, CALL, SETLOCAL, ENDLOCAL, EXIT,
        CD/CHDIR, PUSHD, POPD and ECHO with output redirection are emulated; any
        other command is yielded as a string. GOTO and EXIT are signalled by
        raising `Goto` and `Exit`, respectively.
        """
        if self.delayexpand:
            ast_command.tokens[:] = (self.delay_expand(token) for token in ast_command.tokens)
        command = EmulatorCommand(ast_command)
        verb = command.verb.upper().strip()
        if verb == 'SET':
            self.execute_set(command)
        elif verb == 'GOTO':
            label, *_ = command.argument_string.split(maxsplit=1)
            if label.startswith(':'):
                if label.upper() == ':EOF':
                    # Ends the current run without requesting a full exit.
                    raise Exit(self.state.ec, False)
                label = label[1:]
            raise Goto(label)
        elif verb == 'CALL':
            empty, colon, label = command.argument_string.partition(':')
            if empty or not colon:
                raise EmulatorException(F'Invalid CALL label: {label}')
            # Only the first word names the label; the rest are arguments.
            label, *_ = label.split(maxsplit=1) or ['']
            try:
                offset = self.parser.lexer.labels[label.upper()]
            except KeyError as KE:
                raise InvalidLabel(label) from KE
            emu = BatchEmulator(self.parser)
            yield from emu.emulate(offset, called=True)
        elif verb == 'SETLOCAL':
            setting = command.argument_string.strip().upper()
            delay = {
                'DISABLEDELAYEDEXPANSION': False,
                'ENABLEDELAYEDEXPANSION' : True,
            }.get(setting, self.state.delayexpand)
            cmdxt = {
                'DISABLEEXTENSIONS': False,
                'ENABLEEXTENSIONS' : True,
            }.get(setting, self.state.ext_setting)
            self.state.delayexpands.append(delay)
            self.state.ext_settings.append(cmdxt)
            self.state.environments.append(dict(self.environment))
        elif verb == 'ENDLOCAL' and len(self.state.environments) > 1:
            # Undo everything SETLOCAL pushed, including the extensions setting.
            self.state.environments.pop()
            self.state.delayexpands.pop()
            self.state.ext_settings.pop()
        elif verb == 'EXIT':
            exit = True
            token = 0
            for arg in command.args:
                if arg.upper() == '/B':
                    exit = False
                    continue
                token = arg
                break
            try:
                code = int(token)
            except ValueError:
                code = 0
            raise Exit(code, exit)
        elif verb == 'CD' or verb == 'CHDIR':
            self.state.cwd = command.argument_string
        elif verb == 'PUSHD':
            # The working directory lives on the state; the emulator has no cwd.
            directory = command.argument_string
            self.state.dirstack.append(self.state.cwd)
            self.state.cwd = directory.rstrip()
        elif verb == 'POPD':
            try:
                self.state.cwd = self.state.dirstack.pop()
            except IndexError:
                pass
        elif verb == 'ECHO':
            # Only the first output redirection is honored; without one, the
            # command is emitted as regular output.
            for io in command.redirects:
                if io.type == Redirect.In:
                    continue
                if isinstance(path := io.target, str):
                    path = unquote(path.lstrip())
                    method = (
                        self.state.append_file
                    ) if io.type == Redirect.OutAppend else (
                        self.state.create_file
                    )
                    method(path, command.argument_string)
                break
            else:
                yield str(command)
        else:
            yield str(command)

    @_register(AstPipeline)
    def emulate_pipeline(self, pipeline: AstPipeline):
        for part in pipeline.parts:
            yield from self.execute_command(part)

    @_register(AstSequence)
    def emulate_sequence(self, sequence: AstSequence):
        yield from self.emulate_statement(sequence.head)
        for cs in sequence.tail:
            if cs.condition == AstCondition.Failure:
                if self.state.ec == 0:
                    continue
            if cs.condition == AstCondition.Success:
                if self.state.ec != 0:
                    continue
            yield from self.emulate_statement(cs.statement)

    @_register(AstIf)
    def emulate_if(self, _if: AstIf):
        """
        Evaluate an IF statement and emulate its THEN or ELSE branch. For the
        comparison forms, the ``/I`` switch (`casefold`) makes the comparison of
        string operands case-insensitive for every operator, not only ``==``.
        """
        if _if.variant == AstIfVariant.ErrorLevel:
            condition = _if.var_int <= self.state.ec
        elif _if.variant == AstIfVariant.CmdExtVersion:
            condition = _if.var_int <= self.state.extensions_version
        elif _if.variant == AstIfVariant.Exist:
            condition = self.state.exists_file(_if.var_str)
        elif _if.variant == AstIfVariant.Defined:
            condition = _if.var_str.upper() in self.state.environment
        else:
            lhs = _if.lhs
            rhs = _if.rhs
            cmp = _if.cmp
            assert lhs is not None
            assert rhs is not None
            if _if.casefold:
                if isinstance(lhs, str):
                    lhs = lhs.casefold()
                if isinstance(rhs, str):
                    rhs = rhs.casefold()
            if cmp == AstIfCmp.STR or cmp == AstIfCmp.EQU:
                condition = lhs == rhs
            elif cmp == AstIfCmp.GTR:
                condition = lhs > rhs
            elif cmp == AstIfCmp.GEQ:
                condition = lhs >= rhs
            elif cmp == AstIfCmp.NEQ:
                condition = lhs != rhs
            elif cmp == AstIfCmp.LSS:
                condition = lhs < rhs
            elif cmp == AstIfCmp.LEQ:
                condition = lhs <= rhs
            else:
                raise RuntimeError(cmp)
        if _if.negated:
            condition = not condition

        if condition:
            yield from self.emulate_statement(_if.then_do)
        elif (_else := _if.else_do):
            yield from self.emulate_statement(_else)

    @_register(AstFor)
    def emulate_for(self, _for: AstFor):
        yield from ()

    @_register(AstGroup)
    def emulate_group(self, group: AstGroup):
        for sequence in group.sequences:
            yield from self.emulate_sequence(sequence)

    @_register(AstLabel)
    def emulate_label(self, label: AstLabel):
        yield from ()

    def emulate_statement(self, statement: AstStatement):
        try:
            handler = self._register.handlers[statement.__class__]
        except KeyError:
            raise RuntimeError(statement)
        yield from handler(self, statement)

    def emulate(self, offset: int = 0, name: str | None = None, command_line: str = '', called: bool = False):
        if name:
            self.state.name = name
        self.state.command_line = command_line
        length = len(self.parser.lexer.code)
        labels = self.parser.lexer.labels

        while offset < length:
            try:
                for sequence in self.parser.parse(offset):
                    yield from self.emulate_sequence(sequence)
            except Goto as goto:
                try:
                    offset = labels[goto.label.upper()]
                except KeyError:
                    raise InvalidLabel(goto.label) from goto
                continue
            except Exit as exit:
                self.state.ec = exit.code
                if exit.exit and called:
                    raise
                else:
                    break
            else:
                break

Instance variables

var state
Expand source code Browse git
@property
def state(self):
    """The `BatchState` held by the underlying parser."""
    parser = self.parser
    return parser.state
var environment
Expand source code Browse git
@property
def environment(self):
    """The environment variable mapping of the current state."""
    current = self.state
    return current.environment
var delayexpand
Expand source code Browse git
@property
def delayexpand(self):
    """Whether delayed (``!name!``) expansion is active in the current state."""
    current = self.state
    return current.delayexpand

Methods

def delay_expand(self, block)
Expand source code Browse git
def delay_expand(self, block: str | RedirectIO):
    """
    Replace each ``!name!`` in a token by the environment value stored under the
    upper-cased name, or by nothing if it is undefined. Redirection tokens pass
    through untouched.
    """
    if not isinstance(block, RedirectIO):
        return re.sub(
            r'!([^!:\n]*)!',
            lambda hit: self.environment.get(hit.group(1).upper(), ''),
            block,
        )
    return block
def execute_set(self, cmd)
Expand source code Browse git
def execute_set(self, cmd: EmulatorCommand):
    """
    Execute a SET command: the plain form, the quoted form (see the module
    documentation), or arithmetic assignment via ``/A``. Prompting via ``/P``
    raises NotImplementedError; malformed argument lists raise
    EmulatorException. Assigning empty content deletes the variable.
    """
    if not (args := cmd.args):
        raise EmulatorException('Empty SET instruction')

    arithmetic = False
    quote_mode = False

    if args[0].upper() == '/P':
        raise NotImplementedError('Prompt SET not implemented.')
    elif args[0].upper() == '/A':
        arithmetic = True
    elif len(args) not in (1, 3):
        raise EmulatorException(F'SET instruction with {len(args)} arguments unexpected.')

    if arithmetic:
        integers = {}
        updated = {}
        expressions = ''.join(args[1:])
        for name, value in self.environment.items():
            try:
                integers[name] = batchint(value)
            except ValueError:
                pass
        for part in expressions.split(','):
            name, _, expression = part.strip().partition('=')
            # Variable names are case-insensitive and the environment keeps them
            # upper-cased (see the branch below); without normalizing, `set /a x=1`
            # would store a key that can never be read back.
            name = name.strip().upper()
            expression = cautious_eval_or_default(expression, environment=integers)
            if expression is not None:
                integers[name] = expression
                updated[name] = str(expression)
        self.environment.update(updated)
    else:
        if (n := len(args)) >= 2 and args[1] == '=':
            name, _, content = args
        elif (assignment := args[-1]).startswith('"'):
            if n != 1:
                raise EmulatorException('Invalid SET from Lexer.')
            quote_mode = True
            # Quoted form: only the text up to the last quote counts.
            assignment, _, unquoted = assignment[1:].rpartition('"')
            assignment = assignment or unquoted
            name, _, content = assignment.partition('=')
        else:
            name, _, content = ''.join(args).partition('=')
        name = name.upper()
        _, content = uncaret(content, quote_mode)
        if not content:
            self.environment.pop(name, None)
        else:
            self.environment[name] = content
def execute_command(self, ast_command)
Expand source code Browse git
def execute_command(self, ast_command: AstCommand):
    """
    Execute a single command. SET, GOTO, CALL, SETLOCAL, ENDLOCAL, EXIT,
    CD/CHDIR, PUSHD, POPD and ECHO with output redirection are emulated; any
    other command is yielded as a string. GOTO and EXIT are signalled by raising
    `Goto` and `Exit`, respectively.
    """
    if self.delayexpand:
        ast_command.tokens[:] = (self.delay_expand(token) for token in ast_command.tokens)
    command = EmulatorCommand(ast_command)
    verb = command.verb.upper().strip()
    if verb == 'SET':
        self.execute_set(command)
    elif verb == 'GOTO':
        label, *_ = command.argument_string.split(maxsplit=1)
        if label.startswith(':'):
            if label.upper() == ':EOF':
                # Ends the current run without requesting a full exit.
                raise Exit(self.state.ec, False)
            label = label[1:]
        raise Goto(label)
    elif verb == 'CALL':
        empty, colon, label = command.argument_string.partition(':')
        if empty or not colon:
            raise EmulatorException(F'Invalid CALL label: {label}')
        # Only the first word names the label; the rest are arguments.
        label, *_ = label.split(maxsplit=1) or ['']
        try:
            offset = self.parser.lexer.labels[label.upper()]
        except KeyError as KE:
            raise InvalidLabel(label) from KE
        emu = BatchEmulator(self.parser)
        yield from emu.emulate(offset, called=True)
    elif verb == 'SETLOCAL':
        setting = command.argument_string.strip().upper()
        delay = {
            'DISABLEDELAYEDEXPANSION': False,
            'ENABLEDELAYEDEXPANSION' : True,
        }.get(setting, self.state.delayexpand)
        cmdxt = {
            'DISABLEEXTENSIONS': False,
            'ENABLEEXTENSIONS' : True,
        }.get(setting, self.state.ext_setting)
        self.state.delayexpands.append(delay)
        self.state.ext_settings.append(cmdxt)
        self.state.environments.append(dict(self.environment))
    elif verb == 'ENDLOCAL' and len(self.state.environments) > 1:
        # Undo everything SETLOCAL pushed, including the extensions setting.
        self.state.environments.pop()
        self.state.delayexpands.pop()
        self.state.ext_settings.pop()
    elif verb == 'EXIT':
        exit = True
        token = 0
        for arg in command.args:
            if arg.upper() == '/B':
                exit = False
                continue
            token = arg
            break
        try:
            code = int(token)
        except ValueError:
            code = 0
        raise Exit(code, exit)
    elif verb == 'CD' or verb == 'CHDIR':
        self.state.cwd = command.argument_string
    elif verb == 'PUSHD':
        # The working directory lives on the state; the emulator has no cwd.
        directory = command.argument_string
        self.state.dirstack.append(self.state.cwd)
        self.state.cwd = directory.rstrip()
    elif verb == 'POPD':
        try:
            self.state.cwd = self.state.dirstack.pop()
        except IndexError:
            pass
    elif verb == 'ECHO':
        # Only the first output redirection is honored; without one, the
        # command is emitted as regular output.
        for io in command.redirects:
            if io.type == Redirect.In:
                continue
            if isinstance(path := io.target, str):
                path = unquote(path.lstrip())
                method = (
                    self.state.append_file
                ) if io.type == Redirect.OutAppend else (
                    self.state.create_file
                )
                method(path, command.argument_string)
            break
        else:
            yield str(command)
    else:
        yield str(command)
def emulate_pipeline(self, pipeline)
Expand source code Browse git
@_register(AstPipeline)
def emulate_pipeline(self, pipeline: AstPipeline):
    """Execute the commands of a pipeline one after another, passing on their output."""
    for command_node in pipeline.parts:
        yield from self.execute_command(command_node)
def emulate_sequence(self, sequence)
Expand source code Browse git
@_register(AstSequence)
def emulate_sequence(self, sequence: AstSequence):
    """
    Emulate the head statement, then each chained statement. A statement joined
    by a failure condition (``||``) only runs after a non-zero error level; one
    joined by a success condition (``&&``) only runs after a zero error level.
    """
    yield from self.emulate_statement(sequence.head)
    for link in sequence.tail:
        failed = self.state.ec != 0
        if link.condition == AstCondition.Failure and not failed:
            continue
        if link.condition == AstCondition.Success and failed:
            continue
        yield from self.emulate_statement(link.statement)
def emulate_if(self, _if)
Expand source code Browse git
@_register(AstIf)
def emulate_if(self, _if: AstIf):
    """
    Evaluate an IF statement and emulate its THEN or ELSE branch.

    ERRORLEVEL and CMDEXTVERSION are true when the given number is less than or
    equal to the current value; EXIST and DEFINED test for a file and a variable.
    For the comparison forms, the ``/I`` switch (`casefold`) makes the comparison
    of string operands case-insensitive for every operator, not only ``==``.
    """
    if _if.variant == AstIfVariant.ErrorLevel:
        condition = _if.var_int <= self.state.ec
    elif _if.variant == AstIfVariant.CmdExtVersion:
        condition = _if.var_int <= self.state.extensions_version
    elif _if.variant == AstIfVariant.Exist:
        condition = self.state.exists_file(_if.var_str)
    elif _if.variant == AstIfVariant.Defined:
        condition = _if.var_str.upper() in self.state.environment
    else:
        lhs = _if.lhs
        rhs = _if.rhs
        cmp = _if.cmp
        assert lhs is not None
        assert rhs is not None
        if _if.casefold:
            if isinstance(lhs, str):
                lhs = lhs.casefold()
            if isinstance(rhs, str):
                rhs = rhs.casefold()
        if cmp == AstIfCmp.STR or cmp == AstIfCmp.EQU:
            condition = lhs == rhs
        elif cmp == AstIfCmp.GTR:
            condition = lhs > rhs
        elif cmp == AstIfCmp.GEQ:
            condition = lhs >= rhs
        elif cmp == AstIfCmp.NEQ:
            condition = lhs != rhs
        elif cmp == AstIfCmp.LSS:
            condition = lhs < rhs
        elif cmp == AstIfCmp.LEQ:
            condition = lhs <= rhs
        else:
            raise RuntimeError(cmp)
    if _if.negated:
        condition = not condition

    if condition:
        yield from self.emulate_statement(_if.then_do)
    elif (_else := _if.else_do):
        yield from self.emulate_statement(_else)
def emulate_for(self, _for)
Expand source code Browse git
@_register(AstFor)
def emulate_for(self, _for: AstFor):
    """FOR loops currently produce no output and have no effect."""
    return
    yield
def emulate_group(self, group)
Expand source code Browse git
@_register(AstGroup)
def emulate_group(self, group: AstGroup):
    """Emulate every sequence of a parenthesized group in order."""
    sequences = group.sequences
    for member in sequences:
        yield from self.emulate_sequence(member)
def emulate_label(self, label)
Expand source code Browse git
@_register(AstLabel)
def emulate_label(self, label: AstLabel):
    """Reaching a label statement produces no output."""
    return
    yield
def emulate_statement(self, statement)
Expand source code Browse git
def emulate_statement(self, statement: AstStatement):
    """
    Dispatch *statement* to the handler registered for its exact class; an
    unregistered statement type raises RuntimeError.
    """
    registry = self._register.handlers
    kind = statement.__class__
    if kind not in registry:
        raise RuntimeError(statement)
    yield from registry[kind](self, statement)
def emulate(self, offset=0, name=None, command_line='', called=False)
Expand source code Browse git
def emulate(self, offset: int = 0, name: str | None = None, command_line: str = '', called: bool = False):
    """
    Emulate the script starting at code *offset* and yield whatever the
    statement handlers produce.

    *command_line* is always stored on the state and *name* only when given.
    A `Goto` restarts parsing at the label's offset (an unknown label raises
    InvalidLabel). An `Exit` stores its code as the error level and ends this
    run; it is re-raised only when it requests a full exit (`exit.exit`) and
    this run was started by CALL (*called*), so that it ends the caller, too.
    """
    if name:
        self.state.name = name
    self.state.command_line = command_line
    length = len(self.parser.lexer.code)
    labels = self.parser.lexer.labels

    # The loop only repeats after a GOTO; every other outcome breaks out.
    while offset < length:
        try:
            for sequence in self.parser.parse(offset):
                yield from self.emulate_sequence(sequence)
        except Goto as goto:
            try:
                offset = labels[goto.label.upper()]
            except KeyError:
                raise InvalidLabel(goto.label) from goto
            continue
        except Exit as exit:
            self.state.ec = exit.code
            if exit.exit and called:
                raise
            else:
                break
        else:
            break
class BatchLexer (data, state=None)
Expand source code Browse git
class BatchLexer:

    # Label name -> code offset. The emulator looks labels up by their
    # upper-cased name; lexers built from another lexer share this mapping.
    labels: dict[str, int]
    # The script being lexed. `current_char` indexes it by offset, and its
    # elements are compared against integer constants such as PERCENT.
    code: memoryview

    # Scratch state for the variable reference being lexed. When a `%` starts
    # a variable, `check_variable_start` sets `var_cmdarg` to a fresh ArgVar
    # and `var_offset` to the cursor offset. `reset` clears all four (None/-1).
    var_cmdarg: ArgVar | None
    var_resume: int
    var_offset: int
    var_dollar: int

    # A redirection whose target has not been read yet; the next token flushed
    # by `emit_token` becomes its target.
    pending_redirect: RedirectIO | None

    # The current read position, and a snapshot taken by `quick_save` that
    # `quick_load` restores.
    cursor: BatchLexerCursor
    resume: BatchLexerCursor | None

    class _register:
        """
        Decorator mapping one or more lexer modes to a handler method; `tokens`
        looks up the handler for the current mode on every character.
        """
        # A handler is given the current mode and char. It returns a boolean indicating
        # whether or not the character was processed and may be consumed.
        handlers: ClassVar[dict[Mode, Callable[
            [BatchLexer, Mode, int], Generator[str | Ctrl, None, bool]
        ]]] = {}

        def __init__(self, *modes: Mode):
            self.modes = modes

        def __call__(self, handler):
            for mode in self.modes:
                self.handlers[mode] = handler
            return handler

    def __init__(self, data: str | buf | BatchLexer, state: BatchState | None = None):
        if isinstance(data, BatchLexer):
            if state is not None:
                raise NotImplementedError
            self.code = data.code
            self.labels = data.labels
            self.state = data.state
        else:
            if state is None:
                state = BatchState()
            self.state = state
            self.preparse(data)

    def parse_label_abort(self):
        self.mode_finish()

    def parse_group(self):
        self.group += 1

    def parse_label(self):
        """Enter LABEL mode; this is only permitted from the base TEXT mode."""
        current = self.mode
        if current == Mode.Text and len(self.modes) == 1:
            self.mode_switch(Mode.Label)
        else:
            raise EmulatorException(F'Switching to LABEL while in mode {current.name}')

    def parse_set(self):
        """Enter the SET-started mode; this is only permitted from the base TEXT mode."""
        current = self.mode
        if current == Mode.Text and len(self.modes) == 1:
            self.mode_switch(Mode.SetStarted)
        else:
            raise EmulatorException(F'Switching to SET while in mode {current.name}')

    @property
    def environment(self):
        return self.state.environment

    def parse_arg_variable(self, var: ArgVar):
        """
        %* in a batch script refers to all the arguments (e.g. %1 %2 %3
            %4 %5 ...)
        Substitution of batch parameters (%n) has been enhanced.  You can
        now use the following optional syntax:
            %~1         - expands %1 removing any surrounding quotes (")
            %~f1        - expands %1 to a fully qualified path name
            %~d1        - expands %1 to a drive letter only
            %~p1        - expands %1 to a path only
            %~n1        - expands %1 to a file name only
            %~x1        - expands %1 to a file extension only
            %~s1        - expanded path contains short names only
            %~a1        - expands %1 to file attributes
            %~t1        - expands %1 to date/time of file
            %~z1        - expands %1 to size of file
            %~$PATH:1   - searches the directories listed in the PATH
                           environment variable and expands %1 to the fully
                           qualified name of the first one found.  If the
                           environment variable name is not defined or the
                           file is not found by the search, then this
                           modifier expands to the empty string
        The modifiers can be combined to get compound results:
            %~dp1       - expands %1 to a drive letter and path only
            %~nx1       - expands %1 to a file name and extension only
            %~dp$PATH:1 - searches the directories listed in the PATH
                           environment variable for %1 and expands to the
                           drive letter and path of the first one found.
            %~ftza1     - expands %1 to a DIR like output line
        In the above examples %1 and PATH can be replaced by other
        valid values.  The %~ syntax is terminated by a valid argument
        number.  The %~ modifiers may not be used with %*
        """
        state = self.state

        if (k := var.offset) is (...):
            return state.command_line
        if (j := k - 1) < 0:
            argval = state.name
        elif j < len(args := state.args):
            argval = args[j]
        else:
            return ''

        if var.flags.StripQuotes and argval.startswith('"') and argval.endswith('"'):
            argval = argval[1:-1]
        with io.StringIO() as output:
            if var.flags.StripQuotes:
                ...
            if var.flags.FullPath:
                ...
            if var.flags.DriveLetter:
                ...
            if var.flags.PathOnly:
                ...
            if var.flags.NameOnly:
                ...
            if var.flags.Extension:
                ...
            if var.flags.ShortName:
                ...
            if var.flags.Attributes:
                ...
            if var.flags.DateTime:
                ...
            if var.flags.FileSize:
                ...
            output.write(argval)
            return output.getvalue()

    def reset(self, offset: int):
        """Reinitialize all per-run lexer state so that lexing starts at *offset*."""
        self.modes = [Mode.Text]
        self.quote = self.caret = self.white = False
        self.first = True
        self.group = 0
        self.cursor = BatchLexerCursor(offset)
        self.resume = None
        self.var_resume = self.var_offset = self.var_dollar = -1
        self.var_cmdarg = None
        self.pending_redirect = None

    def mode_finish(self):
        modes = self.modes
        if len(modes) <= 1:
            raise RuntimeError('Trying to exit base mode.')
        self.modes.pop()

    def mode_switch(self, mode: Mode):
        self.modes.append(mode)

    @property
    def mode(self):
        return self.modes[-1]

    @mode.setter
    def mode(self, value: Mode):
        self.modes[-1] = value

    @property
    def substituting(self):
        return self.cursor.substituting

    @property
    def eof(self):
        return (c := self.cursor).offset >= len(self.code) and not c.subst_buffer

    def quick_save(self):
        self.resume = self.cursor.copy()

    def quick_load(self):
        if (resume := self.resume) is None:
            raise RuntimeError
        self.cursor = resume
        self.resume = None

    def current_char(self, lookahead=0):
        """
        Return the character `lookahead` positions ahead of the cursor.

        While the cursor has a non-empty substitution buffer, characters are read
        from it, starting at `subst_offset`. A lookahead past the end of the
        buffer continues in the code, where the first character after the buffer
        is `code[cursor.offset]`. Reading past the end of the code raises
        UnexpectedEOF.
        """
        if not (subst := self.cursor.subst_buffer):
            offset = self.cursor.offset + lookahead
        else:
            offset = self.cursor.subst_offset
            if lookahead:
                offset += lookahead
            if offset >= (n := len(subst)):
                # Spill over from the substitution buffer into the code.
                offset -= n
                offset += self.cursor.offset
            else:
                return self.cursor.subst_buffer[offset]
        try:
            return self.code[offset]
        except IndexError:
            raise UnexpectedEOF

    def consume_char(self):
        """
        Advance by one character. While a substitution buffer is active, this
        advances within the buffer, clearing it and resetting `subst_offset` to
        -1 once it is exhausted. Otherwise it advances in the code: moving to
        exactly `len(code)` (the end-of-input position) is allowed, and anything
        further raises EOFError.
        """
        if subst := self.cursor.subst_buffer:
            offset = self.cursor.subst_offset + 1
            if offset >= len(subst):
                del subst[:]
                self.cursor.subst_offset = -1
            else:
                self.cursor.subst_offset = offset
        else:
            offset = self.cursor.offset + 1
            if offset > len(self.code):
                raise EOFError('Consumed a character beyond EOF.')
            self.cursor.offset = offset

    def peek_char(self):
        try:
            return self.current_char(1)
        except UnexpectedEOF:
            return None

    def next_char(self):
        self.consume_char()
        return self.current_char()

    def parse_env_variable(self, var: str):
        """
        Expand the body of a ``%name%`` or ``%name:modifier%`` reference.

        Without a modifier, or when the value from `state.envar` is empty, that
        value is returned unchanged. Two modifier forms are supported:

        - ``old=new`` replaces *old* with *new*; a leading ``~`` on *old* limits
          the replacement to the first occurrence.
        - ``~offset[,length]`` extracts a substring. A negative offset counts
          from the end of the value; a negative length drops that many
          characters from the end, e.g. ``%PATH:~0,-2%``.

        Any other modifier raises EmulatorException.
        """
        name, _, modifier = var.partition(':')
        base = self.state.envar(name)
        if not modifier or not base:
            return base
        if '=' in modifier:
            old, _, new = modifier.partition('=')
            kwargs = {}
            if old.startswith('~'):
                old = old[1:]
                kwargs.update(count=1)
            return base.replace(old, new, **kwargs)
        if not modifier.startswith('~'):
            raise EmulatorException
        offset, _, length = modifier[1:].partition(',')
        offset = batchint(offset)
        if offset < 0:
            offset = max(0, len(base) + offset)
        if not length:
            end = len(base)
        elif (count := batchint(length)) < 0:
            # A negative length is relative to the end of the value; clamp at 0
            # so that a too-large value yields an empty result rather than a
            # Python negative index counting from the end again.
            end = max(0, len(base) + count)
        else:
            end = offset + count
        return base[offset:end]

    def emit_token(self):
        """
        Flush the token accumulated in the cursor's buffer.

        The buffer is converted to text with `u16` (presumably a UTF-16 decode,
        to be confirmed), and only non-empty text is emitted. If a redirection is
        pending, the text becomes that redirection's target, the RedirectIO mode
        is left, and the RedirectIO object is emitted instead of the string. If
        the lexer is not in RedirectIO mode at that point, RuntimeError is
        raised. The buffer is always cleared, and `first` is reset even when
        nothing was emitted.
        """
        if (buffer := self.cursor.token) and (token := u16(buffer)):
            if (io := self.pending_redirect):
                if self.mode != Mode.RedirectIO:
                    raise RuntimeError
                self.mode_finish()
                self.pending_redirect = None
                # The text read so far is the redirection target.
                io.target, token = token, io
            yield token
        del buffer[:]
        self.first = False

    def line_break(self):
        """Reset the per-line flags and emit a NewLine control token."""
        self.quote = False
        self.white = True
        self.first = True
        yield Ctrl.NewLine

    def tokens(self, offset: int):
        """
        Lex the preparsed code from the given offset and yield the resulting tokens.

        The lexer state is reset first. On every step, the handler registered for
        the current mode receives the current character; when the handler returns a
        true value, the character is consumed. Whatever remains in the token buffer
        once the code is exhausted is emitted at the end.
        """
        self.reset(offset)
        dispatch = self._register.handlers
        peek = self.current_char
        advance = self.consume_char
        end = len(self.code)
        while self.cursor.offset < end:
            char = peek()
            mode = self.mode
            handler = dispatch[mode]
            if (yield from handler(self, mode, char)):
                advance()
        yield from self.emit_token()

    def check_variable_start(self, char: int):
        """
        Handle a percent sign that may start a variable reference.

        Returns False when the character is not a percent sign, or when substituted
        text is currently being read. A doubled percent sign is consumed completely
        and appended to the token as one literal percent sign. Otherwise the lexer
        enters VarStarted mode with a fresh ArgVar and remembers the offset of the
        first character after the percent sign. Both of these cases return True.
        """
        if char != PERCENT or self.cursor.substituting:
            return False
        if self.next_char() != PERCENT:
            self.mode_switch(Mode.VarStarted)
            self.var_cmdarg = ArgVar()
            self.var_offset = self.cursor.offset
            return True
        self.consume_char()
        self.cursor.token.append(PERCENT)
        return True

    def check_line_break(self, mode: Mode, char: int):
        """
        Handle a line break character.

        Returns False for any other character. An unescaped line break flushes the
        current token, emits the line break control, and drops every mode above the
        base mode. In both the escaped and unescaped case the line break is consumed
        and True is returned.
        """
        if char != LINEBREAK:
            return False
        escaped = self.caret
        if not escaped:
            yield from self.emit_token()
            yield from self.line_break()
            del self.modes[1:]
        # When escaped, the caret is deliberately left set: it is not reset until
        # the next character has been processed.
        self.consume_char()
        return True

    def check_command_separators(self, mode: Mode, char: int):
        """
        Handle characters that end the current command.

        A closing parenthesis closes the innermost open group, if there is one. Any
        character listed in SeparatorEscalation ends the command. If the same
        character follows immediately, the second control token of the pair is
        emitted instead of the first (presumably `&` vs. `&&` and `|` vs. `||`).

        Returns True if the character was handled and consumed, False otherwise.
        Raises UnexpectedFirstToken when a separator occurs where the first token of
        a command is expected.
        """
        if char == PAREN_CLOSE and (g := self.group) > 0:
            yield from self.emit_token()
            yield Ctrl.EndGroup
            self.consume_char()
            self.group = g - 1
            return True
        try:
            one, two = SeparatorEscalation[char]
        except KeyError:
            return False
        if self.first:
            raise UnexpectedFirstToken(char)
        # Flush the token before leaving the current mode. With a pending
        # redirection, emit_token requires RedirectIO to still be the active mode
        # and leaves that mode itself. Leaving it first would make emit_token raise.
        yield from self.emit_token()
        if self.mode != Mode.Text:
            self.mode_finish()
        if self.next_char() == char:
            self.consume_char()
            yield two
        else:
            yield one
        self.first = False
        return True

    def check_quote_start(self, char: int):
        """
        Begin a quoted section if the character is a double quote.

        The quote is appended to the token and consumed, the lexer enters Quote
        mode, and the caret and first-token flags are cleared. Returns whether a
        quoted section was started.
        """
        if char == QUOTE:
            self.cursor.token.append(char)
            self.mode_switch(Mode.Quote)
            self.caret = self.first = False
            self.consume_char()
            return True
        return False

    def check_redirect_io(self, char: int):
        """
        Handle an input or output redirection operator.

        Returns False without consuming anything if the character is not in ANGLES.
        Otherwise the operator is consumed: ANGLE_OPEN selects input redirection, any
        other angle selects output, and a doubled ANGLE_CLOSE selects appending
        output. If the token buffer holds a single digit, that digit names the
        redirected stream and is removed; otherwise the stream defaults to 0 for
        input and 1 for output. The token buffer is then flushed.

        If the operator is followed by an ampersand and a digit, a complete
        RedirectIO onto that handle is yielded right away. Otherwise the redirection
        is stored as pending and the lexer enters RedirectIO mode, so the next
        emitted token becomes its target. Raises UnexpectedToken if the ampersand is
        not followed by a digit.
        """
        if char not in ANGLES:
            return False

        is_output = char != ANGLE_OPEN
        buffer = self.cursor.token
        if len(buffer) == 1 and (digit := buffer[0] - ZERO) in range(10):
            del buffer[:]
            stream = digit
        else:
            stream = 1 if is_output else 0

        char = self.next_char()
        if is_output and char == ANGLE_CLOSE:
            kind = Redirect.OutAppend
            char = self.next_char()
        elif is_output:
            kind = Redirect.Out
        else:
            kind = Redirect.In

        yield from self.emit_token()

        if char == AMPERSAND:
            char = self.next_char()
            if char not in range(ZERO, NINE + 1):
                raise UnexpectedToken(char)
            self.consume_char()
            yield RedirectIO(kind, stream, char - ZERO)
        else:
            self.pending_redirect = RedirectIO(kind, stream)
            self.mode_switch(Mode.RedirectIO)

        return True

    @_register(
        Mode.VarStarted,
        Mode.VarDollar,
        Mode.VarColon,
    )
    def gobble_var(self, mode: Mode, char: int):
        """
        Read the characters of a variable reference that began with a percent sign.

        Two kinds of reference are recognized:

        - Environment variables like `%name%` or `%name:modifier%`. They end at the
          next percent sign and are expanded by parse_env_variable.
        - Argument references like `%1`, `%*`, `%~1`, `%~dp1` or `%~$PATH:1`. They end
          at the argument digit and are expanded by parse_arg_variable. As long as the
          characters read so far can still form such a reference, self.var_cmdarg
          holds the ArgVar being built and accumulates its modifier flags. Otherwise
          it is None.

        When a reference is complete, its expansion is placed in the cursor's
        substitution buffer and the variable mode is left. If a line break comes
        first, the cursor is moved back to var_resume (or to var_offset when no
        resume point was recorded) and the variable mode is left.

        Returns True if the caller should consume the character, False if the
        cursor was already repositioned.
        """
        yield from ()

        def done():
            self.mode_finish()
            self.var_cmdarg = None
            self.var_resume = -1
            self.var_offset = -1
            return False

        var_offset = self.var_offset
        var_resume = self.var_resume
        var_cmdarg = self.var_cmdarg
        current = self.cursor.offset
        variable = None

        if self.substituting:
            raise RuntimeError('Nested variable substitution.')

        if char == LINEBREAK:
            # unterminated reference: rewind so the text is lexed again
            if var_resume < 0:
                var_resume = var_offset
            if var_resume < 0:
                raise RuntimeError
            self.cursor.offset = var_resume
            return done()

        if char == PERCENT:
            var_name = u16(self.code[var_offset:self.cursor.offset])
            variable = u16(self.parse_env_variable(var_name))
        elif var_cmdarg:
            if ZERO <= char <= NINE:
                var_cmdarg.offset = char - ZERO
                variable = u16(self.parse_arg_variable(var_cmdarg))
            elif char == ASTERIX and var_offset == current:
                # %* with the asterisk directly after the percent sign
                var_cmdarg.offset = (...)
                variable = u16(self.parse_arg_variable(var_cmdarg))

        if variable is not None:
            self.consume_char()
            self.cursor.subst_buffer.extend(variable)
            self.cursor.subst_offset = 0
            return done()

        if mode == Mode.VarColon:
            # With a colon, the argument index must follow immediately: %~$PATH:0
            # If there is anything between colon and digit, it is not an argument variable.
            self.var_cmdarg = None
        if mode == Mode.VarDollar:
            if char == COLON:
                if var_cmdarg:
                    var_cmdarg.path = u16(self.code[self.var_dollar:current])
                self.var_resume = current
        if mode == Mode.VarStarted:
            if char == DOLLAR:
                self.var_dollar = current
                self.mode = Mode.VarDollar
                return True
            if char == COLON:
                self.var_cmdarg = None
                self.mode = Mode.VarColon
                self.var_resume = current
                return True
            if not var_cmdarg:
                return True
            try:
                flag = ArgVarFlags.FromToken(char)
            except KeyError:
                self.var_cmdarg = None
                return True
            if flag == ArgVarFlags.StripQuotes:
                if var_cmdarg.flags > 0:
                    # the tilde must come first and may occur only once
                    self.var_cmdarg = None
                else:
                    var_cmdarg.flags |= flag
            elif ArgVarFlags.StripQuotes not in var_cmdarg.flags:
                # every other modifier requires a preceding tilde
                self.var_cmdarg = None
            else:
                var_cmdarg.flags |= flag
        return True

    @_register(Mode.Label)
    def gobble_label(self, mode: Mode, char: int):
        """
        Collect the rest of a label line into the current token.

        Returns True if the caller should consume the character. A line break is
        consumed by check_line_break, in which case False is returned.
        """
        at_line_end = yield from self.check_line_break(mode, char)
        if not at_line_end:
            self.cursor.token.append(char)
        return not at_line_end

    @_register(Mode.Quote)
    def gobble_quote(self, mode: Mode, char: int):
        """
        Collect characters inside a double-quoted section.

        Line breaks and variable references are delegated to their checks. Every
        other character is appended to the token, and a closing quote leaves Quote
        mode. Returns True if the caller should consume the character.
        """
        handled = yield from self.check_line_break(mode, char)
        if handled or self.check_variable_start(char):
            return False
        token = self.cursor.token
        token.append(char)
        if char == QUOTE:
            # closing quote: return to the enclosing mode
            self.mode_finish()
        return True

    @_register(Mode.Whitespace)
    def gobble_whitespace(self, mode: Mode, char: int):
        """
        Accumulate a run of whitespace into the current token.

        At the first non-whitespace character, the run is emitted as a token of its
        own and Whitespace mode is left. That character is not consumed.
        """
        if char not in WHITESPACE:
            yield from self.emit_token()
            self.mode_finish()
            return False
        self.cursor.token.append(char)
        return True

    @_register(Mode.SetQuoted)
    def gobble_quoted_set(self, mode: Mode, char: int):
        """
        Collect the argument of a quoted set statement: set "name=value" ...

        Every quote is appended to the token. The lexer state right after the most
        recent quote is saved with quick_save, and characters after it are collected
        provisionally. When the statement ends at an unescaped `|` or `&`, or at a
        line break, the saved state is restored with quick_load. This discards
        everything after the last quote, and the discarded text is lexed again in
        the enclosing mode.

        After a quote has been seen, a caret escapes exactly one following character,
        so the caret flag is cleared after any character other than another caret.
        A caret right before the line break continues the statement on the next line.

        Returns True if the caller should consume the character.
        """
        if char == QUOTE:
            self.consume_char()
            self.cursor.token.append(QUOTE)
            # an escape ends at this quote; remember the state right after it
            self.caret = False
            self.quick_save()
            return False

        if char == LINEBREAK:
            if self.resume is None:
                # no closing quote: the argument ends with the line
                yield from self.emit_token()
                yield from self.line_break()
                self.mode_finish()
                return True
            if self.caret:
                # an escaped line break continues the statement
                self.caret = False
                return True
            # Discard everything after the last quote. The restored cursor points
            # right after that quote, so return False: the remaining text and this
            # line break are lexed again in the enclosing mode.
            self.quick_load()
            yield from self.emit_token()
            self.mode_finish()
            return False

        if self.resume is not None:
            if char == CARET:
                self.caret = not self.caret
            else:
                escaped, self.caret = self.caret, False
                if char in (PIPE, AMPERSAND) and not escaped:
                    self.quick_load()
                    yield from self.emit_token()
                    self.mode_finish()
                    # after a quick load, the ending quote was already consumed.
                    return False

        self.cursor.token.append(char)
        return True

    @_register(
        Mode.Text,
        Mode.SetStarted,
        Mode.SetRegular,
        Mode.RedirectIO,
    )
    def gobble_txt(self, mode: Mode, char: int):
        """
        Handle one character of unquoted command text.

        This handler serves plain text, the unquoted parts of a set statement, and
        the target of a pending redirection. The checks run in a fixed order: line
        breaks, variable references, the opening quote of a quoted set, a character
        escaped by a preceding caret, the equals sign of an unquoted set, quotes,
        carets, whitespace, command separators, redirections, and finally control
        characters that only matter in Text mode. A character that matches none of
        these is appended to the current token.

        Returns True if the caller should consume the character. Returns False if it
        was already consumed, or if the lexer state changed and the current
        character must be dispatched again.
        """
        token = self.cursor.token

        if (yield from self.check_line_break(mode, char)):
            return False

        if self.check_variable_start(char):
            return False

        # A quote that opens the argument of a set statement selects the quoted set handler.
        if not token and char == QUOTE and mode == Mode.SetStarted:
            self.caret = False
            token.append(char)
            self.mode = Mode.SetQuoted
            return True

        # The previous character was a caret: take this one literally.
        if self.caret:
            token.append(char)
            self.caret = False
            self.consume_char()
            return False

        # In an unquoted set statement, the first equals sign separates name and value.
        if char == EQUALS and mode == Mode.SetStarted:
            yield from self.emit_token()
            yield '='
            self.mode = Mode.SetRegular
            return True

        if self.check_quote_start(char):
            return False

        # The caret itself is not added to the token; it escapes the next character.
        if char == CARET:
            self.caret = True
            self.first = False
            return True

        # Whitespace separates tokens only in plain Text mode.
        if char in WHITESPACE and mode == Mode.Text:
            yield from self.emit_token()
            token.append(char)
            self.mode_switch(Mode.Whitespace)
            return True

        if (yield from self.check_command_separators(mode, char)):
            return False

        if (yield from self.check_redirect_io(char)):
            return False

        if mode == Mode.Text:
            if char == PAREN_OPEN:
                self.first = False
                yield from self.emit_token()
                yield Ctrl.NewGroup
                return True
            if char == PAREN_CLOSE:
                self.first = False
                yield from self.emit_token()
                yield Ctrl.EndGroup
                return True
            if char == COLON:
                self.first = False
                yield from self.emit_token()
                yield Ctrl.Label
                return True
            if char == EQUALS:
                yield from self.emit_token()
                # next_char consumes the first '='. For a single '=', the following
                # character is left unconsumed; for '==', both are consumed.
                if self.next_char() != EQUALS:
                    yield Ctrl.Equals
                    return False
                else:
                    yield Ctrl.IsEqualTo
                    return True

        self.cursor.token.append(char)
        return True

    def _decode(self, data: buf):
        if data[:3] == B'\xEF\xBB\xBF':
            return codecs.decode(data[3:], 'utf8')
        elif data[:2] == B'\xFF\xFE':
            return codecs.decode(data[2:], 'utf-16le')
        elif data[:2] == B'\xFE\xFF':
            return codecs.decode(data[2:], 'utf-16be')
        else:
            return codecs.decode(data, 'cp1252')

    @staticmethod
    def label(text: str):
        """
        Normalize the text following a colon into an upper-case label name.

        Leading whitespace is removed, and the text is split at spaces, tabs and
        vertical tabs. Each word is passed through uncaret. While uncaret's first
        result is true (presumably signalling a trailing escape; confirm against
        uncaret), the last character of the word is dropped and the following
        separator stays part of the label. The first word without that flag ends
        the label.
        """
        pieces = re.split('([\x20\t\v])', text.lstrip())
        for index in range(0, len(pieces), 2):
            escaped, word = uncaret(pieces[index], True)
            if escaped:
                pieces[index] = word[:-1]
                continue
            pieces[index] = word
            del pieces[index + 1:]
            break
        return ''.join(pieces).upper()

    def preparse(self, text: str | buf):
        """
        Prepare script text for lexing.

        Byte input is decoded with _decode first. The text is stripped and split into
        lines at runs of CR/LF characters, which drops empty lines. The lines are
        stored in self.code as a memoryview of UTF-16 code units, with LINEBREAK
        after every line. When the first non-whitespace character of a line is a
        colon, the label name derived from the rest of the line is mapped to the
        offset of that colon in self.labels. For duplicate labels, the first
        definition wins, and empty label names are ignored.
        """
        self.labels = {}
        if not isinstance(text, str):
            text = self._decode(text)
        units = array.array('H')
        for line in re.split(r'[\r\n]+', text.strip()):
            start = len(units)
            encoded = memoryview(line.encode('utf-16le')).cast('H')
            lead = re.search('[^\\s]', line)
            if lead:
                p = lead.start()
                if encoded[p] == COLON:
                    name = self.label(u16(encoded[p + 1:]))
                    if name:
                        self.labels.setdefault(name, start + p)
            units.extend(encoded)
            units.append(LINEBREAK)
        self.code = memoryview(units)

    # Every lexer mode needs exactly one registered handler; fail at class creation otherwise.
    if set(_register.handlers).symmetric_difference(Mode):
        raise NotImplementedError('Not all handlers were implemented.')

Class variables

var labels

Dictionary that maps normalized label names to the code offset of the label's colon; filled by preparse.

var code

Memoryview over the preparsed script as UTF-16 code units, with a line break after every line; set by preparse.

var var_cmdarg

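The ArgVar for the argument reference (such as %1 or %~dp1) currently being read, or None when the characters read so far cannot form one.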

var var_resume

Code offset where lexing resumes if the current variable reference turns out to be unterminated; -1 when unset.

var var_offset

Code offset of the first character after the percent sign that started the current variable reference; -1 when unset.

var var_dollar

Code offset of the dollar sign in a %~$VAR:n argument reference; -1 when unset.

var pending_redirect

RedirectIO that is waiting for its target (the next emitted token), or None if no redirection is pending.

var cursor

The BatchLexerCursor holding the current code offset, the token buffer, and the substitution buffer.

var resume

Copy of the cursor saved by quick_save while a quoted set statement is read and restored by quick_load; None if nothing is saved.

Static methods

def label(text)
Expand source code Browse git
@staticmethod
def label(text: str):
    """
    Normalize the text following a colon into an upper-case label name.

    Leading whitespace is removed, and the text is split at spaces, tabs and
    vertical tabs. Each word is passed through uncaret. While uncaret's first
    result is true (presumably signalling a trailing escape; confirm against
    uncaret), the last character of the word is dropped and the following
    separator stays part of the label. The first word without that flag ends
    the label.
    """
    pieces = re.split('([\x20\t\v])', text.lstrip())
    for index in range(0, len(pieces), 2):
        escaped, word = uncaret(pieces[index], True)
        if escaped:
            pieces[index] = word[:-1]
            continue
        pieces[index] = word
        del pieces[index + 1:]
        break
    return ''.join(pieces).upper()

Instance variables

var environment
Expand source code Browse git
@property
def environment(self):
    """
    The `environment` attribute of the lexer's batch state.
    """
    state = self.state
    return state.environment
var mode
Expand source code Browse git
@property
def mode(self):
    """
    The active lexer mode, i.e. the top entry of the mode stack.
    """
    top = self.modes[-1]
    return top
var substituting
Expand source code Browse git
@property
def substituting(self):
    """
    Mirrors the cursor's `substituting` flag: whether substituted variable text is being read.
    """
    cursor = self.cursor
    return cursor.substituting
var eof
Expand source code Browse git
@property
def eof(self):
    """
    Whether the lexer is exhausted: the cursor is at or beyond the end of the code
    and no substituted text is left to read.
    """
    cursor = self.cursor
    if cursor.offset < len(self.code):
        return False
    return not cursor.subst_buffer

Methods

def parse_label_abort(self)
Expand source code Browse git
def parse_label_abort(self):
    """
    Abort label parsing by leaving the current lexer mode (presumably the Label
    mode entered through parse_label).
    """
    finish = self.mode_finish
    finish()
def parse_group(self)
Expand source code Browse git
def parse_group(self):
    """
    Record that one more parenthesized group has been opened.
    """
    self.group = self.group + 1
def parse_label(self)
Expand source code Browse git
def parse_label(self):
    """
    Switch the lexer into Label mode.

    Raises EmulatorException unless the lexer is in Text mode with nothing else on
    the mode stack.
    """
    current = self.mode
    if current == Mode.Text and len(self.modes) == 1:
        self.mode_switch(Mode.Label)
        return
    raise EmulatorException(F'Switching to LABEL while in mode {current.name}')
def parse_set(self)
Expand source code Browse git
def parse_set(self):
    """
    Switch the lexer into SetStarted mode to read the argument of a set statement.

    Raises EmulatorException unless the lexer is in Text mode with nothing else on
    the mode stack.
    """
    current = self.mode
    if current == Mode.Text and len(self.modes) == 1:
        self.mode_switch(Mode.SetStarted)
        return
    raise EmulatorException(F'Switching to SET while in mode {current.name}')
def parse_arg_variable(self, var)

%* in a batch script refers to all the arguments (e.g. %1 %2 %3 %4 %5 …). Substitution of batch parameters (%n) has been enhanced. You can now use the following optional syntax: %~1 - expands %1 removing any surrounding quotes (") %~f1 - expands %1 to a fully qualified path name %~d1 - expands %1 to a drive letter only %~p1 - expands %1 to a path only %~n1 - expands %1 to a file name only %~x1 - expands %1 to a file extension only %~s1 - expanded path contains short names only %~a1 - expands %1 to file attributes %~t1 - expands %1 to date/time of file %~z1 - expands %1 to size of file %~$PATH:1 - searches the directories listed in the PATH environment variable and expands %1 to the fully qualified name of the first one found. If the environment variable name is not defined or the file is not found by the search, then this modifier expands to the empty string. The modifiers can be combined to get compound results: %~dp1 - expands %1 to a drive letter and path only %~nx1 - expands %1 to a file name and extension only %~dp$PATH:1 - searches the directories listed in the PATH environment variable for %1 and expands to the drive letter and path of the first one found. %~ftza1 - expands %1 to a DIR-like output line. In the above examples %1 and PATH can be replaced by other valid values. The %~ syntax is terminated by a valid argument number. The %~ modifiers may not be used with %*.

Expand source code Browse git
def parse_arg_variable(self, var: ArgVar):
    """
    %* in a batch script refers to all the arguments (e.g. %1 %2 %3
        %4 %5 ...)
    Substitution of batch parameters (%n) has been enhanced.  You can
    now use the following optional syntax:
        %~1         - expands %1 removing any surrounding quotes (")
        %~f1        - expands %1 to a fully qualified path name
        %~d1        - expands %1 to a drive letter only
        %~p1        - expands %1 to a path only
        %~n1        - expands %1 to a file name only
        %~x1        - expands %1 to a file extension only
        %~s1        - expanded path contains short names only
        %~a1        - expands %1 to file attributes
        %~t1        - expands %1 to date/time of file
        %~z1        - expands %1 to size of file
        %~$PATH:1   - searches the directories listed in the PATH
                       environment variable and expands %1 to the fully
                       qualified name of the first one found.  If the
                       environment variable name is not defined or the
                       file is not found by the search, then this
                       modifier expands to the empty string
    The modifiers can be combined to get compound results:
        %~dp1       - expands %1 to a drive letter and path only
        %~nx1       - expands %1 to a file name and extension only
        %~dp$PATH:1 - searches the directories listed in the PATH
                       environment variable for %1 and expands to the
                       drive letter and path of the first one found.
        %~ftza1     - expands %1 to a DIR like output line
    In the above examples %1 and PATH can be replaced by other
    valid values.  The %~ syntax is terminated by a valid argument
    number.  The %~ modifiers may not be used with %*

    Emulation notes: an offset of Ellipsis (`%*`) returns state.command_line.
    Offset 0 returns state.name. Offsets past the end of state.args expand to the
    empty string. Only the `~` modifier (strip surrounding quotes) is emulated;
    all other modifiers are accepted but currently leave the value unchanged.
    """
    state = self.state

    if (k := var.offset) is (...):
        return state.command_line
    if (j := k - 1) < 0:
        argval = state.name
    elif j < len(args := state.args):
        argval = args[j]
    else:
        return ''

    # Test flag membership. Reading `var.flags.StripQuotes` would look up the
    # enum member itself, which is always truthy, instead of testing the flag.
    if ArgVarFlags.StripQuotes in var.flags and argval.startswith('"') and argval.endswith('"'):
        argval = argval[1:-1]

    # TODO: the path modifiers (f, d, p, n, x, s), the file modifiers (a, t, z)
    # and the $VAR: search are not emulated yet; the value is returned unchanged.
    return argval
def reset(self, offset)
Expand source code Browse git
def reset(self, offset: int):
    """
    Reset all lexer state and place a fresh cursor at the given code offset.
    """
    self.modes = [Mode.Text]
    self.group = 0
    self.quote = self.caret = self.white = False
    self.first = True
    self.cursor = BatchLexerCursor(offset)
    self.resume = None
    self.pending_redirect = None
    self.var_cmdarg = None
    self.var_resume = self.var_offset = self.var_dollar = -1
def mode_finish(self)
Expand source code Browse git
def mode_finish(self):
    """
    Pop the active mode off the mode stack.

    Raises RuntimeError when only the base mode is left, because it must never be
    removed.
    """
    stack = self.modes
    if len(stack) > 1:
        stack.pop()
        return
    raise RuntimeError('Trying to exit base mode.')
def mode_switch(self, mode)
Expand source code Browse git
def mode_switch(self, mode: Mode):
    """
    Push a new mode onto the mode stack, making it the active mode.
    """
    stack = self.modes
    stack.append(mode)
def quick_save(self)
Expand source code Browse git
def quick_save(self):
    """
    Remember a copy of the current cursor so that quick_load can restore it.
    """
    snapshot = self.cursor.copy()
    self.resume = snapshot
def quick_load(self)
Expand source code Browse git
def quick_load(self):
    """
    Restore the cursor saved by quick_save and forget the saved copy.

    Raises RuntimeError if nothing was saved.
    """
    saved = self.resume
    if saved is None:
        raise RuntimeError
    self.resume = None
    self.cursor = saved
def current_char(self, lookahead=0)
Expand source code Browse git
def current_char(self, lookahead=0):
    """
    Return the character `lookahead` positions ahead of the cursor.

    While substituted text is pending, positions are counted through the
    substitution buffer first and continue into the code after it. Raises
    UnexpectedEOF when the position lies beyond the end of the code.
    """
    cursor = self.cursor
    pending = cursor.subst_buffer
    if pending:
        position = cursor.subst_offset + lookahead
        overflow = position - len(pending)
        if overflow < 0:
            return pending[position]
        position = cursor.offset + overflow
    else:
        position = cursor.offset + lookahead
    try:
        return self.code[position]
    except IndexError:
        raise UnexpectedEOF
def consume_char(self)
Expand source code Browse git
def consume_char(self):
    """
    Advance past one character.

    Pending substituted text is consumed first; when its last character is
    consumed, the substitution buffer is emptied and its offset reset to -1.
    Otherwise the code offset advances. Raises EOFError when that would move more
    than one position past the end of the code.
    """
    cursor = self.cursor
    pending = cursor.subst_buffer
    if not pending:
        position = cursor.offset + 1
        if position > len(self.code):
            raise EOFError('Consumed a character beyond EOF.')
        cursor.offset = position
        return
    position = cursor.subst_offset + 1
    if position < len(pending):
        cursor.subst_offset = position
    else:
        del pending[:]
        cursor.subst_offset = -1
def peek_char(self)
Expand source code Browse git
def peek_char(self):
    """
    Look one character past the cursor without moving it.

    Returns the character, or None when the lookahead would run past the end of
    the code.
    """
    try:
        upcoming = self.current_char(1)
    except UnexpectedEOF:
        upcoming = None
    return upcoming
def next_char(self)
Expand source code Browse git
def next_char(self):
    """
    Advance the cursor by one character and return the character now under it.
    """
    self.consume_char()
    upcoming = self.current_char()
    return upcoming
def parse_env_variable(self, var)
Expand source code Browse git
def parse_env_variable(self, var: str):
    """
    Expand the text between the percent signs of an environment variable reference.

    The text up to the first colon is the variable name. Anything after it is a
    modifier, interpreted as documented by `SET /?`:

    - `old=new` replaces every occurrence of `old` with `new`. As in cmd, the
      search does not distinguish upper and lower case.
    - `*old=new` replaces everything from the start of the value up to and
      including the first occurrence of `old` with `new`. If `old` does not occur,
      the value is returned unchanged.
    - `~offset[,length]` extracts a substring. A negative offset counts from the
      end of the value. A negative length is added to the length of the value to
      obtain the end position, so `~0,-2` removes the last two characters.

    Without a modifier, or when the variable value is empty, the value is
    returned as-is. Any other modifier raises EmulatorException.
    """
    name, _, modifier = var.partition(':')
    base = self.state.envar(name)
    if not modifier or not base:
        return base
    if '=' in modifier:
        old, _, new = modifier.partition('=')
        if old.startswith('*'):
            hit = re.search(re.escape(old[1:]), base, re.IGNORECASE)
            if hit is None:
                return base
            return new + base[hit.end():]
        # A replacement function inserts `new` literally; a replacement string
        # would have its backslash escapes interpreted by the re module.
        return re.sub(re.escape(old), lambda _: new, base, flags=re.IGNORECASE)
    if not modifier.startswith('~'):
        raise EmulatorException
    offset, _, length = modifier[1:].partition(',')
    size = len(base)
    offset = batchint(offset)
    if offset < 0:
        offset = max(0, size + offset)
    if not length:
        end = size
    elif (length := batchint(length)) < 0:
        # a negative length is relative to the end of the value; clamp at 0 so
        # that Python slicing does not count from the end a second time
        end = max(0, size + length)
    else:
        end = offset + length
    return base[offset:end]
def emit_token(self)
Expand source code Browse git
def emit_token(self):
    """
    Flush the current token buffer.

    If the buffer decodes to a non-empty string, that string is yielded. When a
    redirection is pending, the string becomes the redirection's target instead:
    the RedirectIO mode is left, the pending redirection is cleared, and the
    redirection object is yielded in place of the string. Raises RuntimeError if a
    redirection is pending while the lexer is not in RedirectIO mode. In all cases
    the buffer is cleared and self.first is set to False.
    """
    buffer = self.cursor.token
    if buffer:
        text = u16(buffer)
        if text:
            redirect = self.pending_redirect
            if redirect:
                if self.mode != Mode.RedirectIO:
                    raise RuntimeError
                self.mode_finish()
                self.pending_redirect = None
                redirect.target = text
                text = redirect
            yield text
    del buffer[:]
    self.first = False
def line_break(self)
Expand source code Browse git
def line_break(self):
    """
    Reset the per-line lexer flags and yield the NewLine control token.
    """
    self.quote = False
    self.white = True
    self.first = True
    yield Ctrl.NewLine
def tokens(self, offset)
Expand source code Browse git
def tokens(self, offset: int):
    """
    Lex the preparsed code from the given offset and yield the resulting tokens.

    The lexer state is reset first. On every step, the handler registered for the
    current mode receives the current character; when the handler returns a true
    value, the character is consumed. Whatever remains in the token buffer once
    the code is exhausted is emitted at the end.
    """
    self.reset(offset)
    dispatch = self._register.handlers
    peek = self.current_char
    advance = self.consume_char
    end = len(self.code)
    while self.cursor.offset < end:
        char = peek()
        mode = self.mode
        handler = dispatch[mode]
        if (yield from handler(self, mode, char)):
            advance()
    yield from self.emit_token()
def check_variable_start(self, char)
Expand source code Browse git
def check_variable_start(self, char: int):
    """
    Handle a percent sign that may start a variable reference.

    Returns False when the character is not a percent sign, or when substituted
    text is currently being read. A doubled percent sign is consumed completely
    and appended to the token as one literal percent sign. Otherwise the lexer
    enters VarStarted mode with a fresh ArgVar and remembers the offset of the
    first character after the percent sign. Both of these cases return True.
    """
    if char != PERCENT or self.cursor.substituting:
        return False
    if self.next_char() != PERCENT:
        self.mode_switch(Mode.VarStarted)
        self.var_cmdarg = ArgVar()
        self.var_offset = self.cursor.offset
        return True
    self.consume_char()
    self.cursor.token.append(PERCENT)
    return True
def check_line_break(self, mode, char)
Expand source code Browse git
def check_line_break(self, mode: Mode, char: int):
    """
    Handle a line break character.

    Returns False for any other character. An unescaped line break flushes the
    current token, emits the line break control, and drops every mode above the
    base mode. In both the escaped and unescaped case the line break is consumed
    and True is returned.
    """
    if char != LINEBREAK:
        return False
    escaped = self.caret
    if not escaped:
        yield from self.emit_token()
        yield from self.line_break()
        del self.modes[1:]
    # When escaped, the caret is deliberately left set: it is not reset until the
    # next character has been processed.
    self.consume_char()
    return True
def check_command_separators(self, mode, char)
Expand source code Browse git
def check_command_separators(self, mode: Mode, char: int):
    """
    Handle characters that end the current command.

    A closing parenthesis closes the innermost open group, if there is one. Any
    character listed in SeparatorEscalation ends the command. If the same
    character follows immediately, the second control token of the pair is
    emitted instead of the first (presumably `&` vs. `&&` and `|` vs. `||`).

    Returns True if the character was handled and consumed, False otherwise.
    Raises UnexpectedFirstToken when a separator occurs where the first token of a
    command is expected.
    """
    if char == PAREN_CLOSE and (g := self.group) > 0:
        yield from self.emit_token()
        yield Ctrl.EndGroup
        self.consume_char()
        self.group = g - 1
        return True
    try:
        one, two = SeparatorEscalation[char]
    except KeyError:
        return False
    if self.first:
        raise UnexpectedFirstToken(char)
    # Flush the token before leaving the current mode. With a pending redirection,
    # emit_token requires RedirectIO to still be the active mode and leaves that
    # mode itself. Leaving it first would make emit_token raise.
    yield from self.emit_token()
    if self.mode != Mode.Text:
        self.mode_finish()
    if self.next_char() == char:
        self.consume_char()
        yield two
    else:
        yield one
    self.first = False
    return True
def check_quote_start(self, char)
Expand source code Browse git
def check_quote_start(self, char: int):
    """
    Begin a quoted section if the character is a double quote.

    The quote is appended to the token and consumed, the lexer enters Quote mode,
    and the caret and first-token flags are cleared. Returns whether a quoted
    section was started.
    """
    if char == QUOTE:
        self.cursor.token.append(char)
        self.mode_switch(Mode.Quote)
        self.caret = self.first = False
        self.consume_char()
        return True
    return False
def check_redirect_io(self, char)
Expand source code Browse git
def check_redirect_io(self, char: int):
    """
    Handle an input or output redirection operator.

    Returns False without consuming anything if the character is not in ANGLES.
    Otherwise the operator is consumed: ANGLE_OPEN selects input redirection, any
    other angle selects output, and a doubled ANGLE_CLOSE selects appending
    output. If the token buffer holds a single digit, that digit names the
    redirected stream and is removed; otherwise the stream defaults to 0 for input
    and 1 for output. The token buffer is then flushed.

    If the operator is followed by an ampersand and a digit, a complete RedirectIO
    onto that handle is yielded right away. Otherwise the redirection is stored as
    pending and the lexer enters RedirectIO mode, so the next emitted token becomes
    its target. Raises UnexpectedToken if the ampersand is not followed by a digit.
    """
    if char not in ANGLES:
        return False

    is_output = char != ANGLE_OPEN
    buffer = self.cursor.token
    if len(buffer) == 1 and (digit := buffer[0] - ZERO) in range(10):
        del buffer[:]
        stream = digit
    else:
        stream = 1 if is_output else 0

    char = self.next_char()
    if is_output and char == ANGLE_CLOSE:
        kind = Redirect.OutAppend
        char = self.next_char()
    elif is_output:
        kind = Redirect.Out
    else:
        kind = Redirect.In

    yield from self.emit_token()

    if char == AMPERSAND:
        char = self.next_char()
        if char not in range(ZERO, NINE + 1):
            raise UnexpectedToken(char)
        self.consume_char()
        yield RedirectIO(kind, stream, char - ZERO)
    else:
        self.pending_redirect = RedirectIO(kind, stream)
        self.mode_switch(Mode.RedirectIO)

    return True
def gobble_var(self, mode, char)
Expand source code Browse git
@_register(
    Mode.VarStarted,
    Mode.VarDollar,
    Mode.VarColon,
)
def gobble_var(self, mode: Mode, char: int):
    """
    Read the characters of a variable reference that began with a percent sign.

    Two kinds of reference are recognized:

    - Environment variables like `%name%` or `%name:modifier%`. They end at the
      next percent sign and are expanded by parse_env_variable.
    - Argument references like `%1`, `%*`, `%~1`, `%~dp1` or `%~$PATH:1`. They end
      at the argument digit and are expanded by parse_arg_variable. As long as the
      characters read so far can still form such a reference, self.var_cmdarg
      holds the ArgVar being built and accumulates its modifier flags. Otherwise
      it is None.

    When a reference is complete, its expansion is placed in the cursor's
    substitution buffer and the variable mode is left. If a line break comes
    first, the cursor is moved back to var_resume (or to var_offset when no resume
    point was recorded) and the variable mode is left.

    Returns True if the caller should consume the character, False if the cursor
    was already repositioned.
    """
    yield from ()

    def done():
        self.mode_finish()
        self.var_cmdarg = None
        self.var_resume = -1
        self.var_offset = -1
        return False

    var_offset = self.var_offset
    var_resume = self.var_resume
    var_cmdarg = self.var_cmdarg
    current = self.cursor.offset
    variable = None

    if self.substituting:
        raise RuntimeError('Nested variable substitution.')

    if char == LINEBREAK:
        # unterminated reference: rewind so the text is lexed again
        if var_resume < 0:
            var_resume = var_offset
        if var_resume < 0:
            raise RuntimeError
        self.cursor.offset = var_resume
        return done()

    if char == PERCENT:
        var_name = u16(self.code[var_offset:self.cursor.offset])
        variable = u16(self.parse_env_variable(var_name))
    elif var_cmdarg:
        if ZERO <= char <= NINE:
            var_cmdarg.offset = char - ZERO
            variable = u16(self.parse_arg_variable(var_cmdarg))
        elif char == ASTERIX and var_offset == current:
            # %* with the asterisk directly after the percent sign
            var_cmdarg.offset = (...)
            variable = u16(self.parse_arg_variable(var_cmdarg))

    if variable is not None:
        self.consume_char()
        self.cursor.subst_buffer.extend(variable)
        self.cursor.subst_offset = 0
        return done()

    if mode == Mode.VarColon:
        # With a colon, the argument index must follow immediately: %~$PATH:0
        # If there is anything between colon and digit, it is not an argument variable.
        self.var_cmdarg = None
    if mode == Mode.VarDollar:
        if char == COLON:
            if var_cmdarg:
                var_cmdarg.path = u16(self.code[self.var_dollar:current])
            self.var_resume = current
    if mode == Mode.VarStarted:
        if char == DOLLAR:
            self.var_dollar = current
            self.mode = Mode.VarDollar
            return True
        if char == COLON:
            self.var_cmdarg = None
            self.mode = Mode.VarColon
            self.var_resume = current
            return True
        if not var_cmdarg:
            return True
        try:
            flag = ArgVarFlags.FromToken(char)
        except KeyError:
            self.var_cmdarg = None
            return True
        if flag == ArgVarFlags.StripQuotes:
            if var_cmdarg.flags > 0:
                # the tilde must come first and may occur only once
                self.var_cmdarg = None
            else:
                var_cmdarg.flags |= flag
        elif ArgVarFlags.StripQuotes not in var_cmdarg.flags:
            # every other modifier requires a preceding tilde
            self.var_cmdarg = None
        else:
            var_cmdarg.flags |= flag
    return True
def gobble_label(self, mode, char)
Expand source code Browse git
@_register(Mode.Label)
def gobble_label(self, mode: Mode, char: int):
    """
    Collect the rest of a label line into the current token.

    Returns True if the caller should consume the character. A line break is
    consumed by check_line_break, in which case False is returned.
    """
    at_line_end = yield from self.check_line_break(mode, char)
    if not at_line_end:
        self.cursor.token.append(char)
    return not at_line_end
def gobble_quote(self, mode, char)
Expand source code Browse git
@_register(Mode.Quote)
def gobble_quote(self, mode: Mode, char: int):
    """
    Collect characters inside a double-quoted section.

    Line breaks and variable references are delegated to their checks. Every
    other character is appended to the token, and a closing quote leaves Quote
    mode. Returns True if the caller should consume the character.
    """
    handled = yield from self.check_line_break(mode, char)
    if handled or self.check_variable_start(char):
        return False
    token = self.cursor.token
    token.append(char)
    if char == QUOTE:
        # closing quote: return to the enclosing mode
        self.mode_finish()
    return True
def gobble_whitespace(self, mode, char)
Expand source code Browse git
@_register(Mode.Whitespace)
def gobble_whitespace(self, mode: Mode, char: int):
    """
    Accumulate a run of whitespace into the current token.

    At the first non-whitespace character, the run is emitted as a token of its
    own and Whitespace mode is left. That character is not consumed.
    """
    if char not in WHITESPACE:
        yield from self.emit_token()
        self.mode_finish()
        return False
    self.cursor.token.append(char)
    return True
def gobble_quoted_set(self, mode, char)
Expand source code Browse git
@_register(Mode.SetQuoted)
def gobble_quoted_set(self, mode: Mode, char: int):
    """
    Collect the argument of a quoted set statement: set "name=value" ...

    Every quote is appended to the token. The lexer state right after the most
    recent quote is saved with quick_save, and characters after it are collected
    provisionally. When the statement ends at an unescaped `|` or `&`, or at a line
    break, the saved state is restored with quick_load. This discards everything
    after the last quote, and the discarded text is lexed again in the enclosing
    mode.

    After a quote has been seen, a caret escapes exactly one following character,
    so the caret flag is cleared after any character other than another caret. A
    caret right before the line break continues the statement on the next line.

    Returns True if the caller should consume the character.
    """
    if char == QUOTE:
        self.consume_char()
        self.cursor.token.append(QUOTE)
        # an escape ends at this quote; remember the state right after it
        self.caret = False
        self.quick_save()
        return False

    if char == LINEBREAK:
        if self.resume is None:
            # no closing quote: the argument ends with the line
            yield from self.emit_token()
            yield from self.line_break()
            self.mode_finish()
            return True
        if self.caret:
            # an escaped line break continues the statement
            self.caret = False
            return True
        # Discard everything after the last quote. The restored cursor points right
        # after that quote, so return False: the remaining text and this line break
        # are lexed again in the enclosing mode.
        self.quick_load()
        yield from self.emit_token()
        self.mode_finish()
        return False

    if self.resume is not None:
        if char == CARET:
            self.caret = not self.caret
        else:
            escaped, self.caret = self.caret, False
            if char in (PIPE, AMPERSAND) and not escaped:
                self.quick_load()
                yield from self.emit_token()
                self.mode_finish()
                # after a quick load, the ending quote was already consumed.
                return False

    self.cursor.token.append(char)
    return True
def gobble_txt(self, mode, char)
Expand source code Browse git
@_register(
    Mode.Text,
    Mode.SetStarted,
    Mode.SetRegular,
    Mode.RedirectIO,
)
def gobble_txt(self, mode: Mode, char: int):
    """
    Handle one character of unquoted command text.

    This handler serves plain text, the unquoted parts of a set statement, and the
    target of a pending redirection. The checks run in a fixed order: line breaks,
    variable references, the opening quote of a quoted set, a character escaped by
    a preceding caret, the equals sign of an unquoted set, quotes, carets,
    whitespace, command separators, redirections, and finally control characters
    that only matter in Text mode. A character that matches none of these is
    appended to the current token.

    Returns True if the caller should consume the character. Returns False if it
    was already consumed, or if the lexer state changed and the current character
    must be dispatched again.
    """
    token = self.cursor.token

    if (yield from self.check_line_break(mode, char)):
        return False

    if self.check_variable_start(char):
        return False

    # A quote that opens the argument of a set statement selects the quoted set handler.
    if not token and char == QUOTE and mode == Mode.SetStarted:
        self.caret = False
        token.append(char)
        self.mode = Mode.SetQuoted
        return True

    # The previous character was a caret: take this one literally.
    if self.caret:
        token.append(char)
        self.caret = False
        self.consume_char()
        return False

    # In an unquoted set statement, the first equals sign separates name and value.
    if char == EQUALS and mode == Mode.SetStarted:
        yield from self.emit_token()
        yield '='
        self.mode = Mode.SetRegular
        return True

    if self.check_quote_start(char):
        return False

    # The caret itself is not added to the token; it escapes the next character.
    if char == CARET:
        self.caret = True
        self.first = False
        return True

    # Whitespace separates tokens only in plain Text mode.
    if char in WHITESPACE and mode == Mode.Text:
        yield from self.emit_token()
        token.append(char)
        self.mode_switch(Mode.Whitespace)
        return True

    if (yield from self.check_command_separators(mode, char)):
        return False

    if (yield from self.check_redirect_io(char)):
        return False

    if mode == Mode.Text:
        if char == PAREN_OPEN:
            self.first = False
            yield from self.emit_token()
            yield Ctrl.NewGroup
            return True
        if char == PAREN_CLOSE:
            self.first = False
            yield from self.emit_token()
            yield Ctrl.EndGroup
            return True
        if char == COLON:
            self.first = False
            yield from self.emit_token()
            yield Ctrl.Label
            return True
        if char == EQUALS:
            yield from self.emit_token()
            # next_char consumes the first '='. For a single '=', the following
            # character is left unconsumed; for '==', both are consumed.
            if self.next_char() != EQUALS:
                yield Ctrl.Equals
                return False
            else:
                yield Ctrl.IsEqualTo
                return True

    self.cursor.token.append(char)
    return True
def preparse(self, text)
Expand source code Browse git
def preparse(self, text: str | buf):
    """
    Pre-scan the whole script before lexing.

    The input is decoded to text if it is not already a string. It is then split into
    lines; runs of CR/LF characters collapse, so blank lines disappear. All lines are
    stored back to back as UTF-16 code units in `self.code`, each followed by a LINEBREAK
    unit. While scanning, every line whose first non-whitespace character is a colon is
    handed to `self.label`. When that returns a truthy name, the code-unit offset of the
    colon is recorded in `self.labels`; only the first offset seen for a name is kept.
    """
    self.labels = {}

    if not isinstance(text, str):
        text = self._decode(text)

    code = array.array('H')

    for line in re.split(r'[\r\n]+', text.strip()):
        units = memoryview(line.encode('utf-16le')).cast('H')
        line_start = len(code)
        first = re.search('[^\\s]', line)
        if first is not None:
            # Only whitespace precedes this character and all whitespace is in the BMP,
            # so its character index equals its UTF-16 code-unit index.
            column = first.start()
            if units[column] == COLON:
                name = self.label(u16(units[column + 1:]))
                if name:
                    self.labels.setdefault(name, line_start + column)
        code.extend(units)
        code.append(LINEBREAK)

    self.code = memoryview(code)
class BatchParser (data, state=None)
Expand source code Browse git
class BatchParser:
    """
    Recursive-descent parser over the token stream of a `BatchLexer`. It produces AST
    nodes for labels, IF statements, groups, pipelines of commands, and sequences of
    statements joined by the conditions recognized by `AstCondition.Try`.
    """

    def __init__(self, data: str | buf | BatchParser, state: BatchState | None = None):
        # A parser built from another parser reuses that parser's lexer as its source;
        # passing a separate state in that case is not supported.
        if isinstance(data, BatchParser):
            if state is not None:
                raise NotImplementedError
            src = data.lexer
        else:
            src = data
        self.lexer = BatchLexer(src, state)

    @property
    def state(self):
        """The state object held by the underlying lexer."""
        return self.lexer.state

    @property
    def environment(self):
        """The active environment of the lexer's state."""
        return self.state.environment

    def command(self, tokens: LookAhead, in_group: bool) -> AstCommand | None:
        """
        Collect the tokens of a single command. The command ends at a command separator,
        a conditional operator, a pipe, a line break, or end of file. Inside a group, a
        closing parenthesis also ends it. Returns None if no tokens were collected.
        """
        ast = AstCommand(tokens.offset())
        tok = tokens.peek()
        cmd = ast.tokens
        # SET lines need special lexing rules, which must be enabled before the rest
        # of the line is tokenized.
        if tok.upper() == 'SET':
            self.lexer.parse_set()
        while tok not in (
            Ctrl.CommandSeparator,
            Ctrl.RunOnFailure,
            Ctrl.RunOnSuccess,
            Ctrl.Pipe,
            Ctrl.NewLine,
            Ctrl.EndOfFile,
        ):
            if in_group and tok == Ctrl.EndGroup:
                break
            cmd.append(tok)
            tokens.pop()
            tok = tokens.peek()
        if ast.tokens:
            return ast

    def pipeline(self, tokens: LookAhead, in_group: bool) -> AstPipeline | None:
        """
        Parse one or more commands joined by pipes. Returns None if there is no leading
        command; raises UnexpectedToken if a pipe is not followed by a command.
        """
        if head := self.command(tokens, in_group):
            node = AstPipeline(head.offset, [head])
            while tokens.pop(Ctrl.Pipe):
                if cmd := self.command(tokens, in_group):
                    node.parts.append(cmd)
                    continue
                raise UnexpectedToken(tokens.peek())
            return node

    def ifthen(self, tokens: LookAhead, in_group: bool) -> AstIf | None:
        """
        Parse an IF statement, including the optional `/I` and `NOT` modifiers and an
        optional ELSE branch. Returns None if the next token is not IF.
        """
        offset = tokens.offset()

        if not tokens.pop_string('IF'):
            return None

        casefold = False
        negated = False
        lhs = None
        rhs = None

        token = tokens.word()

        if token.upper() == '/I':
            casefold = True
            token = tokens.word()
        if token.upper() == 'NOT':
            negated = True
            token = tokens.word()
        try:
            variant = AstIfVariant(token.upper())
        except Exception:
            # Not a keyword variant: this is a comparison of the form `lhs CMP rhs`.
            tokens.skip_space()
            variant = None
            lhs = token
            cmp = next(tokens)
            try:
                cmp = AstIfCmp(cmp)
            except Exception:
                raise UnexpectedToken(cmp)
            # Operators other than the plain string comparison require extensions.
            if cmp != AstIfCmp.STR and self.state.extensions_version < 1:
                raise UnexpectedToken(cmp)

            rhs = tokens.consume_nonspace_words()

            # Operands are converted to integers only if both of them parse as such.
            try:
                ilh = batchint(lhs)
                irh = batchint(rhs)
            except ValueError:
                pass
            else:
                lhs = ilh
                rhs = irh
        else:
            lhs = unquote(tokens.consume_nonspace_words())
            rhs = None
            cmp = None
            try:
                lhs = batchint(lhs)
            except ValueError:
                pass

        then_do = self.sequence(tokens, in_group)

        if then_do is None:
            raise UnexpectedToken(tokens.peek())

        tokens.skip_space()

        if tokens.peek().upper() == 'ELSE':
            tokens.pop()
            else_do = self.sequence(tokens, in_group)
        else:
            else_do = None

        return AstIf(
            offset,
            then_do,
            else_do,
            variant,
            casefold,
            negated,
            cmp,
            lhs, rhs # type:ignore
        )

    def forloop(self, tokens: LookAhead, in_group: bool) -> AstFor | None:
        """
        Stub for FOR loops. It consumes a leading FOR keyword but does not build a loop
        node yet, and it always returns None.
        """
        # TODO: FOR loops are not parsed yet. `statement` falls through to `pipeline`
        # for the remainder of the line.
        if not tokens.pop_string('FOR'):
            return None
        return None

    def group(self, tokens: LookAhead) -> AstGroup | None:
        """
        Parse a parenthesized group of sequences. Returns None if the next token does
        not open a group.
        """
        offset = tokens.offset()
        if tokens.pop(Ctrl.NewGroup):
            self.lexer.parse_group()
            sequences: list[AstSequence] = []
            while not tokens.pop(Ctrl.EndGroup) and (sequence := self.sequence(tokens, True)):
                sequences.append(sequence)
            return AstGroup(offset, sequences)

    def label(self, tokens: LookAhead) -> AstLabel | None:
        """
        Parse a label definition.

        The lexer is first switched into label mode. If the next token is not a label
        marker, that mode is aborted and None is returned. Otherwise the label's offset
        is checked against the label table the lexer built during its pre-scan.

        Raises RuntimeError if the label occurs before the offset recorded for it, and
        KeyError if the label is missing from the table.
        """
        offset = tokens.offset()
        lexer = self.lexer
        lexer.parse_label()
        if not tokens.pop(Ctrl.Label):
            lexer.parse_label_abort()
            return None
        line = tokens.word()
        label = lexer.label(line)
        # The pre-scan keeps only the first offset seen for each label name (setdefault),
        # so a repeated definition later in the script legitimately sits at a larger
        # offset. Only a label found before the recorded offset means the pre-scan and
        # the parser disagree.
        if (x := lexer.labels[label]) > offset:
            raise RuntimeError(F'Expected offset for label {label} to be {offset}, got {x} instead.')
        return AstLabel(offset, line, label)

    def statement(self, tokens: LookAhead, in_group: bool):
        """
        Parse one statement, trying in order: label, IF, group, FOR, and finally a
        pipeline.
        """
        if s := self.label(tokens):
            return s
        if s := self.ifthen(tokens, in_group):
            return s
        if s := self.group(tokens):
            return s
        if s := self.forloop(tokens, in_group):
            return s
        return self.pipeline(tokens, in_group)

    def sequence(self, tokens: LookAhead, in_group: bool) -> AstSequence | None:
        """
        Parse a statement followed by any number of conditionally chained statements.
        Returns None if no leading statement could be parsed; raises EmulatorException
        if a condition is not followed by a statement.
        """
        tokens.skip_space()
        head = self.statement(tokens, in_group)
        if head is None:
            return None
        node = AstSequence(head.offset, head)
        tokens.skip_space()
        while condition := AstCondition.Try(tokens.peek()):
            tokens.pop()
            tokens.skip_space()
            if not (statement := self.statement(tokens, in_group)):
                raise EmulatorException('Failed to parse conditional statement.')
            node.tail.append(
                AstConditionalStatement(statement.offset, condition, statement))
            tokens.skip_space()
        return node

    def parse(self, offset: int):
        """Lazily yield top-level sequences, starting at the given code offset."""
        tokens = LookAhead(self.lexer, offset)
        while sequence := self.sequence(tokens, False):
            yield sequence

Instance variables

var state
Expand source code Browse git
@property
def state(self):
    """Shortcut to the state object owned by this parser's lexer."""
    lexer = self.lexer
    return lexer.state
var environment
Expand source code Browse git
@property
def environment(self):
    """Active environment mapping of the parser's state."""
    current_state = self.state
    return current_state.environment

Methods

def command(self, tokens, in_group)
Expand source code Browse git
def command(self, tokens: LookAhead, in_group: bool) -> AstCommand | None:
    """
    Gather the tokens of one command.

    Gathering stops at a command separator, conditional operator, pipe, line break, or
    end of file; inside a group a closing parenthesis also stops it. The result is None
    when nothing was gathered.
    """
    node = AstCommand(tokens.offset())
    current = tokens.peek()
    # A SET line has its own lexing rules; switch the lexer before reading further.
    if current.upper() == 'SET':
        self.lexer.parse_set()
    terminators = (
        Ctrl.CommandSeparator,
        Ctrl.RunOnFailure,
        Ctrl.RunOnSuccess,
        Ctrl.Pipe,
        Ctrl.NewLine,
        Ctrl.EndOfFile,
    )
    while current not in terminators and not (in_group and current == Ctrl.EndGroup):
        node.tokens.append(current)
        tokens.pop()
        current = tokens.peek()
    return node if node.tokens else None
def pipeline(self, tokens, in_group)
Expand source code Browse git
def pipeline(self, tokens: LookAhead, in_group: bool) -> AstPipeline | None:
    """
    Parse commands separated by pipes.

    Returns None when there is no first command. A pipe that is not followed by a
    command raises UnexpectedToken.
    """
    first = self.command(tokens, in_group)
    if not first:
        return None
    result = AstPipeline(first.offset, [first])
    while tokens.pop(Ctrl.Pipe):
        part = self.command(tokens, in_group)
        if not part:
            raise UnexpectedToken(tokens.peek())
        result.parts.append(part)
    return result
def ifthen(self, tokens, in_group)
Expand source code Browse git
def ifthen(self, tokens: LookAhead, in_group: bool) -> AstIf | None:
    """
    Parse an IF statement.

    Handles the optional `/I` (case-insensitive) and `NOT` modifiers, either a keyword
    variant (any value accepted by `AstIfVariant`) or a comparison `lhs CMP rhs`, the
    THEN sequence, and an optional ELSE sequence.

    Returns None if the next token is not IF. Raises UnexpectedToken when the comparison
    operator is not recognized, when a non-string comparison is used while
    `extensions_version < 1`, or when no THEN sequence follows.
    """
    offset = tokens.offset()

    if not tokens.pop_string('IF'):
        return None

    casefold = False
    negated = False
    lhs = None
    rhs = None

    token = tokens.word()

    if token.upper() == '/I':
        casefold = True
        token = tokens.word()
    if token.upper() == 'NOT':
        negated = True
        token = tokens.word()
    try:
        variant = AstIfVariant(token.upper())
    except Exception:
        # The word is not a keyword variant, so it is the left operand of a comparison.
        tokens.skip_space()
        variant = None
        lhs = token
        cmp = next(tokens)
        try:
            cmp = AstIfCmp(cmp)
        except Exception:
            raise UnexpectedToken(cmp)
        # Operators other than the plain string comparison require command extensions.
        if cmp != AstIfCmp.STR and self.state.extensions_version < 1:
            raise UnexpectedToken(cmp)

        rhs = tokens.consume_nonspace_words()

        # Both operands become integers only if both of them parse as batch integers;
        # otherwise both remain strings.
        try:
            ilh = batchint(lhs)
            irh = batchint(rhs)
        except ValueError:
            pass
        else:
            lhs = ilh
            rhs = irh
    else:
        # Keyword variant: a single (unquoted) argument, converted to an integer if it
        # parses as one.
        lhs = unquote(tokens.consume_nonspace_words())
        rhs = None
        cmp = None
        try:
            lhs = batchint(lhs)
        except ValueError:
            pass

    then_do = self.sequence(tokens, in_group)

    if then_do is None:
        raise UnexpectedToken(tokens.peek())

    tokens.skip_space()

    if tokens.peek().upper() == 'ELSE':
        tokens.pop()
        else_do = self.sequence(tokens, in_group)
    else:
        else_do = None

    return AstIf(
        offset,
        then_do,
        else_do,
        variant,
        casefold,
        negated,
        cmp,
        lhs, rhs # type:ignore
    )
def forloop(self, tokens, in_group)
Expand source code Browse git
def forloop(self, tokens: LookAhead, in_group: bool) -> AstFor | None:
    """
    Placeholder for FOR-loop parsing: a leading FOR keyword is consumed, but no loop
    node is produced and the result is always None.
    """
    # TODO: build an AstFor node; the outcome is currently the same whether or not the
    # keyword was present.
    tokens.pop_string('FOR')
    return None
def group(self, tokens)
Expand source code Browse git
def group(self, tokens: LookAhead) -> AstGroup | None:
    """
    Parse a parenthesized block of sequences.

    Returns None when the next token does not open a group. Sequences are collected
    until the closing parenthesis or until no further sequence can be parsed.
    """
    start = tokens.offset()
    if not tokens.pop(Ctrl.NewGroup):
        return None
    self.lexer.parse_group()
    body: list[AstSequence] = []
    while not tokens.pop(Ctrl.EndGroup):
        item = self.sequence(tokens, True)
        if not item:
            break
        body.append(item)
    return AstGroup(start, body)
def label(self, tokens)
Expand source code Browse git
def label(self, tokens: LookAhead) -> AstLabel | None:
    """
    Parse a label definition.

    The lexer is first switched into label mode. If the next token is not a label
    marker, that mode is aborted and None is returned. Otherwise the label's offset is
    checked against the label table the lexer built during its pre-scan.

    Raises RuntimeError if the label occurs before the offset recorded for it, and
    KeyError if the label is missing from the table.
    """
    offset = tokens.offset()
    lexer = self.lexer
    lexer.parse_label()
    if not tokens.pop(Ctrl.Label):
        lexer.parse_label_abort()
        return None
    line = tokens.word()
    label = lexer.label(line)
    # The pre-scan keeps only the first offset seen for each label name (setdefault), so
    # a repeated definition later in the script legitimately sits at a larger offset.
    # Only a label found before the recorded offset means the pre-scan and the parser
    # disagree.
    if (x := lexer.labels[label]) > offset:
        raise RuntimeError(F'Expected offset for label {label} to be {offset}, got {x} instead.')
    return AstLabel(offset, line, label)
def statement(self, tokens, in_group)
Expand source code Browse git
def statement(self, tokens: LookAhead, in_group: bool):
    """
    Parse a single statement.

    The specialized parsers are tried in a fixed order (label, IF, group, FOR); the
    first truthy result wins. If none matches, the input is parsed as a pipeline.
    """
    attempts = (
        lambda: self.label(tokens),
        lambda: self.ifthen(tokens, in_group),
        lambda: self.group(tokens),
        lambda: self.forloop(tokens, in_group),
    )
    for attempt in attempts:
        if node := attempt():
            return node
    return self.pipeline(tokens, in_group)
def sequence(self, tokens, in_group)
Expand source code Browse git
def sequence(self, tokens: LookAhead, in_group: bool) -> AstSequence | None:
    """
    Parse a leading statement plus any statements chained to it by conditions.

    Returns None if no leading statement is found. Raises EmulatorException if a
    condition recognized by `AstCondition.Try` is not followed by a statement.
    """
    tokens.skip_space()
    first = self.statement(tokens, in_group)
    if first is None:
        return None
    result = AstSequence(first.offset, first)
    tokens.skip_space()
    while True:
        condition = AstCondition.Try(tokens.peek())
        if not condition:
            break
        tokens.pop()
        tokens.skip_space()
        follower = self.statement(tokens, in_group)
        if not follower:
            raise EmulatorException('Failed to parse conditional statement.')
        result.tail.append(AstConditionalStatement(follower.offset, condition, follower))
        tokens.skip_space()
    return result
def parse(self, offset)
Expand source code Browse git
def parse(self, offset: int):
    """Generate the top-level sequences of the script, beginning at `offset`."""
    stream = LookAhead(self.lexer, offset)
    while True:
        seq = self.sequence(stream, False)
        if not seq:
            return
        yield seq
class BatchState (delayed_expansion=False, extensions_enabled=True, extensions_version=2, environment=None, file_system=None, username='Administrator', hostname=None, now=None, cwd='C:\\')
Expand source code Browse git
class BatchState:

    name: str
    args: list[str]

    environments: list[dict[str, str]]
    delayexpands: list[bool]
    ext_settings: list[bool]
    file_system: dict[str, str]

    def __init__(
        self,
        delayed_expansion: bool = False,
        extensions_enabled: bool = True,
        extensions_version: int = 2,
        environment: dict | None = None,
        file_system: dict | None = None,
        username: str = 'Administrator',
        hostname: str | None = None,
        now: int | float | str | datetime | None = None,
        cwd: str = 'C:\\',
    ):
        self.delayed_expansion = delayed_expansion
        self.extensions_version = extensions_version
        self.extensions_enabled = extensions_enabled
        self.file_system_seed = file_system or {}
        self.environment_seed = environment or {}
        if hostname is None:
            hostname = str(uuid4())
        for key, value in _DEFAULT_ENVIRONMENT.items():
            self.environment_seed.setdefault(
                key.upper(),
                value.format(h=hostname, u=username)
            )
        if isinstance(now, str):
            now = isodate(now)
        if isinstance(now, (int, float)):
            now = date_from_timestamp(now)
        if now is None:
            now = datetime.now()
        self.cwd = cwd
        self.now = now
        self.hostname = hostname
        self.username = username
        seed(self.now.timestamp())
        self.reset()

    @property
    def cwd(self):
        return self._cwd

    @cwd.setter
    def cwd(self, new: str):
        new = new.replace('/', '\\')
        if not new.endswith('\\'):
            new = F'{new}\\'
        if not ntpath.isabs(new):
            new = ntpath.join(self.cwd, new)
        if not ntpath.isabs(new):
            raise ValueError(F'Invalid absolute path: {new}')
        self._cwd = ntpath.normcase(ntpath.normpath(new))

    @property
    def ec(self) -> int:
        return self.errorlevel

    @ec.setter
    def ec(self, value: int | None):
        ec = value or 0
        self.environment['ERRORLEVEL'] = str(ec)
        self.errorlevel = ec

    def reset(self):
        self.labels = {}
        self.environments = [dict(self.environment_seed)]
        self.delayexpands = [self.delayed_expansion]
        self.ext_settings = [self.extensions_enabled]
        self.file_system = dict(self.file_system_seed)
        self.dirstack = []
        self.linebreaks = []
        self.name = F'{uuid4()}.bat'
        self.args = []
        self._cmd = ''
        self.ec = None

    @property
    def command_line(self):
        return self._cmd

    @command_line.setter
    def command_line(self, value: str):
        self._cmd = value
        self.args = value.split()

    def envar(self, name: str) -> str:
        name = name.upper()
        if name in (e := self.environment):
            return e[name]
        elif name == 'DATE':
            return self.now.strftime('%Y-%m-%d')
        elif name == 'TIME':
            time = self.now.strftime('%M:%S,%f')
            return F'{self.now.hour:2d}:{time:.8}'
        elif name == 'RANDOM':
            return str(randrange(0, 32767))
        elif name == 'ERRORLEVEL':
            return str(self.ec)
        elif name == 'CD':
            return self.cwd
        elif name == 'CMDCMDLINE':
            line = self.envar('COMSPEC')
            if args := self.args:
                args = ' '.join(args)
                line = F'{line} /c "{args}"'
            return line
        elif name == 'CMDEXTVERSION':
            return str(self.extensions_version)
        elif name == 'HIGHESTNUMANODENUMBER':
            return '0'
        else:
            return ''

    def _resolved(self, path: str) -> str:
        if not ntpath.isabs(path):
            path = F'{self.cwd}{path}'
        return ntpath.normcase(ntpath.normpath(path))

    def create_file(self, path: str, data: str = ''):
        self.file_system[self._resolved(path)] = data

    def append_file(self, path: str, data: str):
        path = self._resolved(path)
        if left := self.file_system.get(path, None):
            data = F'{left}{data}'
        self.file_system[path] = data

    def remove_file(self, path: str):
        self.file_system.pop(self._resolved(path), None)

    def ingest_file(self, path: str) -> str | None:
        return self.file_system.get(self._resolved(path))

    def exists_file(self, path: str) -> bool:
        return self._resolved(path) in self.file_system

    @property
    def environment(self):
        return self.environments[-1]

    @property
    def delayexpand(self):
        return self.delayexpands[-1]

    @property
    def ext_setting(self):
        return self.ext_settings[-1]

Class variables

var name

Name of the emulated batch script; `reset()` assigns a random `<uuid>.bat` name.

var args

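Command-line arguments of the script; `reset()` clears the list, and assigning `command_line` fills it by splitting the new value on whitespace.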

var environments

Stack of environment-variable dictionaries; the last entry is the active environment returned by the `environment` property.

var delayexpands

Stack of delayed-expansion flags; the last entry is the active setting returned by the `delayexpand` property.

var ext_settings

Stack of command-extension flags; the last entry is the active setting returned by the `ext_setting` property.

var file_system

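In-memory file system mapping file paths to their contents; `reset()` initializes it as a copy of the file-system seed given to the constructor.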

Instance variables

var cwd
Expand source code Browse git
@property
def cwd(self):
    """The normalized current working directory."""
    current = self._cwd
    return current
var ec
Expand source code Browse git
@property
def ec(self) -> int:
    """The current error level, as stored in `errorlevel`."""
    level = self.errorlevel
    return level
var command_line
Expand source code Browse git
@property
def command_line(self):
    """The raw command line the script was started with."""
    raw = self._cmd
    return raw
var environment
Expand source code Browse git
@property
def environment(self):
    """The innermost (active) entry of the environment stack."""
    stack = self.environments
    return stack[-1]
var delayexpand
Expand source code Browse git
@property
def delayexpand(self):
    """The innermost (active) delayed-expansion flag."""
    stack = self.delayexpands
    return stack[-1]
var ext_setting
Expand source code Browse git
@property
def ext_setting(self):
    """The innermost (active) command-extension flag."""
    stack = self.ext_settings
    return stack[-1]

Methods

def reset(self)
Expand source code Browse git
def reset(self):
    """
    Return the state to its initial configuration.

    The environment and file-system seeds are copied, the delayed-expansion and
    extension stacks are restarted, and labels, directory stack, line breaks, arguments,
    and command line are cleared. The script gets a fresh random `<uuid>.bat` name, and
    finally `ec` is assigned None.
    """
    self.labels = {}
    # The environment must exist before `ec` is assigned below; the class's `ec` setter
    # writes ERRORLEVEL into it.
    self.environments = [dict(self.environment_seed)]
    self.delayexpands = [self.delayed_expansion]
    self.ext_settings = [self.extensions_enabled]
    self.file_system = dict(self.file_system_seed)
    self.dirstack, self.linebreaks, self.args = [], [], []
    self.name = F'{uuid4()}.bat'
    self._cmd = ''
    self.ec = None
def envar(self, name)
Expand source code Browse git
def envar(self, name: str) -> str:
    """
    Return the value of a variable, looked up case-insensitively.

    Variables set in the active environment take precedence over the dynamic variables
    computed below (DATE, TIME, RANDOM, ERRORLEVEL, CD, CMDCMDLINE, CMDEXTVERSION,
    HIGHESTNUMANODENUMBER). Unknown names expand to the empty string.
    """
    name = name.upper()
    if name in (e := self.environment):
        return e[name]
    elif name == 'DATE':
        return self.now.strftime('%Y-%m-%d')
    elif name == 'TIME':
        time = self.now.strftime('%M:%S,%f')
        return F'{self.now.hour:2d}:{time:.8}'
    elif name == 'RANDOM':
        # cmd documents %RANDOM% as a number between 0 and 32767 inclusive, and randrange
        # excludes its upper bound.
        return str(randrange(0, 32768))
    elif name == 'ERRORLEVEL':
        return str(self.ec)
    elif name == 'CD':
        return self.cwd
    elif name == 'CMDCMDLINE':
        line = self.envar('COMSPEC')
        if args := self.args:
            args = ' '.join(args)
            line = F'{line} /c "{args}"'
        return line
    elif name == 'CMDEXTVERSION':
        return str(self.extensions_version)
    elif name == 'HIGHESTNUMANODENUMBER':
        return '0'
    else:
        return ''
def create_file(self, path, data='')
Expand source code Browse git
def create_file(self, path: str, data: str = ''):
    """Store `data` as the full contents of the file at `path`, replacing any old contents."""
    key = self._resolved(path)
    self.file_system[key] = data
def append_file(self, path, data)
Expand source code Browse git
def append_file(self, path: str, data: str):
    """Add `data` to the end of the file at `path`, creating the file if it is missing or empty."""
    key = self._resolved(path)
    existing = self.file_system.get(key, None)
    self.file_system[key] = F'{existing}{data}' if existing else data
def remove_file(self, path)
Expand source code Browse git
def remove_file(self, path: str):
    """Delete the file at `path`; nothing happens if it does not exist."""
    key = self._resolved(path)
    self.file_system.pop(key, None)
def ingest_file(self, path)
Expand source code Browse git
def ingest_file(self, path: str) -> str | None:
    """Read the contents of the file at `path`; returns None when there is no such file."""
    key = self._resolved(path)
    return self.file_system.get(key)
def exists_file(self, path)
Expand source code Browse git
def exists_file(self, path: str) -> bool:
    """Report whether the in-memory file system has an entry for `path`."""
    key = self._resolved(path)
    return key in self.file_system