Module refinery.lib.batch

Set Statement

There are two kinds of set statement: The quoted and the unquoted set. A quoted set looks like this:

set "name=var" (...)

It is interpreted as follows:

  • Everything between the first and the last quote in the command is extracted.
  • The resulting string is split at the first equals symbol.
  • The LHS is the variable name.
  • The RHS is unescaped once not respecting quotes, then becomes the variable content.

Examples

> set "name="a"^^"b"c
> echo %name%
"a"^"b

> set "name=a"^^"b"c
> echo %name%
a"^"b

Note how the trailing c is always discarded because it occurs after the last quote. The unquoted set looks like this:

set name=var

It is parsed as follows:

  • The entire command is parsed and unescaped respecting quotes as usual.
  • The set expression starts with the first non-whitespace character after the set keyword.
  • This expression is split at the first equals symbol.
  • The LHS is the variable name.
  • The RHS is unescaped once respecting quotes, then becomes the variable content.

Input redirection may occur in a set line, basically anywhere:

> set 1>"NUL" "var=val
> echo %var%
val
Expand source code Browse git
"""

## Set Statement

There are two kinds of set statement:
The quoted and the unquoted set.
A quoted set looks like this:

    set "name=var" (...)

It is interpreted as follows:

- Everything between the first and the last quote in the command is extracted.
- The resulting string is split at the first equals symbol.
- The LHS is the variable name.
- The RHS is unescaped once **not** respecting quotes, then becomes the variable content.

Examples:

    > set  "name="a"^^"b"c
    > echo %name%
    "a"^"b

    > set  "name=a"^^"b"c
    > echo %name%
    a"^"b

Note how the trailing c is always discarded because it occurs after the last quote.
The unquoted set looks like this:

    set name=var

It is parsed as follows:

- The entire command is parsed and unescaped respecting quotes as usual.
- The set expression starts with the first non-whitespace character after the set keyword.
- This expression is split at the first equals symbol.
- The LHS is the variable name.
- The RHS is unescaped once respecting quotes, then becomes the variable content.

Input redirection may occur in a set line, basically anywhere:

    > set 1>"NUL" "var=val
    > echo %var%
    val
"""
from __future__ import annotations

from .emulator import BatchEmulator
from .lexer import BatchLexer
from .parser import BatchParser
from .state import BatchState

# Names exported by `from refinery.lib.batch import *`; each one is imported from a sub-module above.
__all__ = [
    'BatchEmulator',
    'BatchLexer',
    'BatchParser',
    'BatchState',
]

Sub-modules

refinery.lib.batch.const
refinery.lib.batch.emulator
refinery.lib.batch.help
refinery.lib.batch.lexer
refinery.lib.batch.model
refinery.lib.batch.parser
refinery.lib.batch.state
refinery.lib.batch.synth
refinery.lib.batch.util

Classes

class BatchEmulator (data, state=None, cfg=None, std=None)
Expand source code Browse git
class BatchEmulator:

    class _node:
        """
        Decorator that files a method as the handler for one AST node class. All decorator
        instances write into the single class-level `handlers` dictionary.
        """
        handlers: ClassVar[dict[
            type[AstNode],
            Callable[[
                BatchEmulator,
                AstNode,
                IO,
                bool,
            ], Generator[SynNodeBase[AstNode] | Error]]
        ]] = {}

        def __init__(self, key: type[AstNode]):
            # The AST node class under which the decorated handler is registered.
            self.key = key

        def __call__(self, handler):
            # Register the handler, then return it untouched so it remains a normal method.
            registry = self.handlers
            registry[self.key] = handler
            return handler

    class _command:
        """
        Decorator that files a method as the handler for one command verb. Verbs are stored in
        upper case; all decorator instances write into the class-level `handlers` dictionary.
        """
        handlers: ClassVar[dict[
            str,
            Callable[[
                BatchEmulator,
                SynCommand,
                IO,
                bool,
            ], Generator[str, None, int | None] | int | ErrorZero | None]
        ]] = {}

        def __init__(self, key: str):
            # Normalize the verb once so that registration is case-insensitive.
            verb = key.upper()
            self.key = verb

        def __call__(self, handler):
            # A later registration for the same verb replaces an earlier one.
            registry = self.handlers
            registry[self.key] = handler
            return handler

    def __init__(
        self,
        data: str | buf | BatchParser,
        state: BatchState | None = None,
        cfg: BatchEmulatorConfig | None = None,
        std: IO | None = None,
    ):
        """
        Set up an emulator for the given batch source. A `BatchParser` is built from `data` and
        `state`; a default configuration and default `IO` object are created when `cfg` or `std`
        are not supplied.
        """
        self.parser = BatchParser(data, state)
        self.cfg = cfg or BatchEmulatorConfig()
        self.std = std or IO()
        self.stack = []
        self.block_labels = set()

    def spawn(self, data: str | buf | BatchParser, state: BatchState | None = None, std: IO | None = None):
        """
        Create another `BatchEmulator` for `data` that shares this emulator's configuration. The
        `state` and `std` arguments are forwarded unchanged to the new emulator.
        """
        return BatchEmulator(data, state=state, cfg=self.cfg, std=std)

    @property
    def state(self):
        return self.parser.state

    @property
    def environment(self):
        return self.state.environment

    @property
    def delayexpand(self):
        return self.state.delayexpand

    def clone_state(
        self,
        delayexpand: bool | None = None,
        cmdextended: bool | None = None,
        environment: dict | None | ellipsis = ...,
        filename: str | None = None,
    ):
        """
        Build a new `BatchState` derived from the current one. File system, user name, host name,
        clock and working directory are taken over from the current state.

        - `delayexpand` defaults to `False` when omitted.
        - `cmdextended` defaults to the current state's value when omitted.
        - `environment` defaults to a shallow copy of the current environment; any other value,
          including `None`, is passed through as given.
        - `filename` is passed through as given.
        """
        current = self.state
        delayed = False if delayexpand is None else delayexpand
        extended = current.cmdextended if cmdextended is None else cmdextended
        # The Ellipsis sentinel means "not given"; None is a legitimate value for the caller.
        variables = dict(current.environment) if environment is ... else environment
        return BatchState(
            delayed,
            extended,
            environment=variables,
            file_system=current.file_system,
            username=current.username,
            hostname=current.hostname,
            now=current.now,
            cwd=current.cwd,
            filename=filename,
        )

    def get_for_variable_regex(self, vars: Iterable[str]):
        return re.compile(RF'%((?:~[fdpnxsatz]*)?)((?:\\$\\w+)?)([{"".join(vars)}])')

    def expand_delayed_variables(self, block: str):
        def expansion(match: re.Match[str]):
            name = match.group(1)
            try:
                return parse(name)
            except MissingVariable:
                _, _, rest = name.partition(':')
                return rest
        parse = self.parser.lexer.parse_env_variable
        return re.sub(r'!([^!\n]*)!', expansion, block)

    def expand_forloop_variables(self, block: str, vars: dict[str, str] | None):
        def expansion(match: re.Match[str]):
            flags = ArgVarFlags.Empty
            for flag in match[1]:
                flags |= ArgVarFlags.FromToken(ord(flag))
            return _vars[match[3]]
        if not vars:
            return block
        _vars = vars
        return self.get_for_variable_regex(vars).sub(expansion, block)

    def contains_for_variable(self, ast: AstNode, vars: Iterable[str]):
        def check(token):
            if isinstance(token, list):
                return any(check(v) for v in token)
            if isinstance(token, dict):
                return any(check(v) for v in token.values())
            if isinstance(token, Enum):
                return False
            if isinstance(token, str):
                return bool(checker(token))
            if isinstance(token, AstNode):
                for tf in fields(token):
                    if tf.name == 'parent':
                        continue
                    if check(getattr(token, tf.name)):
                        return True
            return False
        checker = self.get_for_variable_regex(vars).search
        return check(ast) # type:ignore

    def expand_ast_node(self, ast: _T) -> _T:
        def expand(token):
            if isinstance(token, list):
                return [expand(v) for v in token]
            if isinstance(token, dict):
                return {k: expand(v) for k, v in token.items()}
            if isinstance(token, Enum):
                return token
            if isinstance(token, str):
                if delayexpand:
                    token = self.expand_delayed_variables(token)
                return self.expand_forloop_variables(token, variables)
            if isinstance(token, AstNode):
                new = {}
                for tf in fields(token):
                    value = getattr(token, tf.name)
                    if tf.name != 'parent':
                        value = expand(value)
                    new[tf.name] = value
                return token.__class__(**new)
            return token
        delayexpand = self.delayexpand
        variables = self.state.for_loop_variables
        if not variables and not delayexpand:
            return ast
        return expand(ast) # type:ignore

    def execute_find_or_findstr(self, cmd: SynCommand, std: IO, findstr: bool):
        """
        Shared implementation of FIND and FINDSTR. The command is yielded first, then the switches,
        the search strings and the file arguments are parsed and every input is searched line by
        line. The generator returns 0 if at least one line was selected and 1 otherwise; it also
        returns 1 for invalid switches, for a FIND search string that is not quoted, and when a
        file cannot be read.

        - Standard input is only searched when no file was specified.
        - Lines are prefixed with the file name when more than one file is searched or when a
          wildcard was used.
        - A line is selected once if any search string matches it, or if none matches with `/V`.
        - `/I` makes the search case-insensitive.
        - FINDSTR search strings that are not valid Python regular expressions are searched as
          literal text instead.
        """
        needles = []
        paths: list[str | ellipsis] = []
        flags = {}
        it = iter(cmd.args)
        yield cmd

        for arg in it:
            if not arg.startswith('/'):
                if findstr:
                    # FINDSTR treats a blank-separated search string as several alternatives.
                    needles.extend(unquote(arg).split())
                elif arg.startswith('"'):
                    # FIND always searches for the entire quoted string.
                    needles.append(unquote(arg))
                else:
                    return 1
                break
            name, has_param, value = arg[1:].partition(':')
            name = name.upper()
            if name in ('OFF', 'OFFLINE'):
                continue
            elif len(name) > 1:
                return 1
            elif name == 'C':
                needles.append(unquote(value))
            elif name == 'F' and findstr:
                if (p := self.state.ingest_file(value)) is None:
                    return 1
                paths.extend(p.splitlines(False))
            elif name == 'G' and findstr:
                if (n := self.state.ingest_file(value)) is None:
                    return 1
                needles.extend(n.splitlines(False))
            elif has_param:
                flags[name] = value
            else:
                flags[name] = True

        valid_flags = 'VNI'
        if findstr:
            valid_flags += 'BELRSXMOPADQ'
        if any(v not in valid_flags for v in flags):
            return 1

        prefix_filename = False
        state = self.state

        for arg in it:
            pattern = unquote(arg)
            if '*' in pattern or '?' in pattern:
                prefix_filename = True
                for path in state.file_system:
                    # NOTE(review): execute_del calls winfnmatch(pattern, path, cwd) with the
                    # first two arguments swapped; confirm which order the helper expects.
                    if winfnmatch(path, pattern, state.cwd):
                        paths.append(path)
            else:
                paths.append(pattern)

        if not paths:
            # Without any file argument, the text is read from standard input.
            paths.append(...)
        elif len(paths) > 1:
            prefix_filename = True

        if not needles:
            # Nothing to search for means that nothing can be found.
            return 1

        def anchored(expression: str):
            if 'X' in flags:
                return F'^{expression}$'
            if 'B' in flags:
                return F'^{expression}'
            if 'E' in flags:
                return F'{expression}$'
            return expression

        literal = not findstr or 'L' in flags
        re_flags = re.IGNORECASE if 'I' in flags else 0
        patterns = []

        for needle in needles:
            escaped = re.escape(needle)
            try:
                compiled = re.compile(anchored(escaped if literal else needle), re_flags)
            except re.error:
                compiled = re.compile(anchored(escaped), re_flags)
            patterns.append(compiled)

        _V = 'V' in flags # noqa; Prints only lines that do not contain a match.
        _P = 'P' in flags # noqa; Skip files with non-printable characters.
        _O = 'O' in flags # noqa; Prints character offset before each matching line.
        _N = 'N' in flags # noqa; Prints the line number before each line that matches.
        _M = 'M' in flags # noqa; Prints only the filename if a file contains a match.

        nothing_found = True

        for path in paths:
            if path is (...):
                data = std.i.read()
            else:
                data = state.ingest_file(path)
            if data is None:
                return 1
            if _P and not re.fullmatch('[\\s!-~]+', data):
                continue
            offset = 0
            for n, line in enumerate(data.splitlines(True), 1):
                # The line terminator is removed for matching so that end anchors work for
                # lines that end in CRLF; the original line is what gets printed.
                text = line.rstrip('\r\n')
                hit = next(filter(None, (p.search(text) for p in patterns)), None)
                if _V != bool(hit):
                    nothing_found = False
                    if _M:
                        # One mention of the file is enough; the remaining lines are skipped.
                        if path is not (...):
                            std.o.write(F'{path}\r\n')
                        break
                    output = line
                    if _O:
                        o = offset + (hit.start() if hit else 0)
                        output = F'{o}:{output}'
                    if _N:
                        output = F'{n}:{output}'
                    if prefix_filename:
                        output = F'{path}:{output}'
                    std.o.write(output)
                offset += len(line)

        return int(nothing_found)

    @_command('TYPE')
    def execute_type(self, cmd: SynCommand, std: IO, *_):
        path = cmd.argument_string.strip()
        data = self.state.ingest_file(path)
        if data is None:
            yield ErrorCannotFindFile
            return 1
        else:
            std.o.write(data)
            return 0

    @_command('FIND')
    def execute_find(self, cmd: SynCommand, std: IO, *_):
        """
        Run the shared FIND/FINDSTR implementation in FIND mode.
        """
        find = self.execute_find_or_findstr
        return find(cmd, std, False)

    @_command('FINDSTR')
    def execute_findstr(self, cmd: SynCommand, std: IO, *_):
        """
        Run the shared FIND/FINDSTR implementation in FINDSTR mode.
        """
        find = self.execute_find_or_findstr
        return find(cmd, std, True)

    @_command('SET')
    def execute_set(self, cmd: SynCommand, std: IO, *_):
        """
        Emulate the SET command in its plain, `/P` (prompt) and `/A` (arithmetic) forms.

        - Plain: the variable is assigned; an empty value removes it from the environment.
        - `/P`: one line is read from standard input and becomes the value, while the text after
          the equals sign is written to standard output as the prompt. `InputLocked` is raised
          when no complete line is available on an open input.
        - `/A`: the comma-separated expressions are evaluated in order. Each assignment is stored
          in the environment, and the value of the last expression is written to standard output.
          An expression that has no assignment is evaluated but not stored.

        The set of block labels is cleared whenever a SET is executed. `EmulatorException` is
        raised for a SET without arguments and for an arithmetic expression that accesses
        variables in an unexpected way.
        """
        if not (args := cmd.args):
            raise EmulatorException('Empty SET instruction')

        if cmd.verb.upper() != 'SET':
            raise RuntimeError

        # Since variables can be used in GOTO, a SET can be used to change the behavior of a GOTO.
        self.block_labels.clear()

        arithmetic = False
        quote_mode = False
        prompt = None

        it = iter(args)
        tk = next(it)

        if tk.upper() == '/P':
            if std.i.closed:
                prompt = ''
            elif not (prompt := std.i.readline()).endswith('\n'):
                raise InputLocked
            else:
                prompt = prompt.rstrip('\r\n')
            # A bare "SET /P" has no further token; a StopIteration must not escape from here
            # because it would surface as a RuntimeError inside this generator.
            tk = next(it, '')
        else:
            cmd.junk = not self.cfg.show_sets

        yield cmd

        if tk.upper() == '/A':
            arithmetic = True
            tk = next(it, '')

        args = [tk, *it, *cmd.trailing_spaces]

        if arithmetic:
            def defang(s: str):
                def r(m: re.Match[str]):
                    return F'_{prefix}{ord(m[0]):X}_'
                return re.sub(r'[^-\s()!~*/%+><&^|_\w]', r, s)
            def refang(s: str): # noqa
                def r(m: re.Match[str]):
                    return chr(int(m[1], 16))
                return re.sub(rf'_{prefix}([A-F0-9]+)_', r, s)
            # Characters that cannot occur in a Python identifier are replaced by a marker that
            # carries this random prefix; refang reverses the replacement.
            prefix = F'{uuid.uuid4().time_mid:X}'
            namespace = {}
            value = None
            if not (program := ''.join(args)):
                std.e.write('The syntax of the command is incorrect.\r\n')
                return ErrorZero.Val
            for assignment in program.split(','):
                assignment = assignment.strip()
                if not assignment:
                    std.e.write('Missing operand.\r\n')
                    return ErrorZero.Val
                # The minus sign has to be first in the character class; anywhere else it would
                # form a range and the "-=" operator would not be recognized.
                pieces = re.split(r'([-*+^|/%&]|<<|>>|)=', assignment, maxsplit=1)
                if len(pieces) == 3:
                    name, operator, definition = pieces
                else:
                    # A pure expression such as "1+2" has no assignment target.
                    name, operator, definition = '', '', assignment
                name = name.strip().upper()
                # A leading zero marks an octal literal in SET /A.
                definition = re.sub(r'\b0([0-7]+)\b', r'0o\1', definition)
                if operator:
                    definition = F'{name}{operator}({definition})'
                definition = defang(definition)
                expression = cautious_parse(definition)
                names = names_in_expression(expression)
                if names.stored or names.others:
                    raise EmulatorException('Arithmetic SET had unexpected variable access.')
                for var in names.loaded:
                    if var in namespace:
                        continue
                    # Each operand is looked up under its own (refanged) name; names that are
                    # undefined or not numeric count as zero.
                    try:
                        namespace[var] = batchint(self.environment[refang(var).upper()])
                    except (KeyError, ValueError):
                        namespace[var] = 0
                # NOTE(review): eval runs on script-controlled input; safety rests entirely on
                # cautious_parse and on the name check above.
                code = compile(expression, filename='[ast]', mode='eval')
                value = eval(code, namespace, {})
                if name:
                    self.environment[name] = str(value)
                    namespace[defang(name)] = value
            if value is None:
                std.e.write('The syntax of the command is incorrect.')
                return
            else:
                std.o.write(F'{value!s}\r\n')
        else:
            try:
                eq = args.index(Ctrl.Equals)
            except ValueError:
                assignment = cmd.argument_string
                if assignment.startswith('"'):
                    quote_mode = True
                    assignment, _, unquoted = assignment[1:].rpartition('"')
                    assignment = assignment or unquoted
                else:
                    assignment = ''.join(args)
                name, _, content = assignment.partition('=')
            else:
                content = ''.join(args[eq + 1:])
                # The index refers to the local argument list, which no longer contains a
                # leading /P switch; indexing cmd.args here would be off by one in that case.
                name = args[eq - 1] if eq else ''
            name = name.upper()
            trailing_caret, content = uncaret(content, quote_mode)
            if trailing_caret:
                content = content[:-1]
            if prompt is not None:
                if (qc := content.strip()).startswith('"'):
                    _, _, qc = qc. partition('"') # noqa
                    qc, _, r = qc.rpartition('"') # noqa
                    content = qc or r
                std.o.write(content)
                content = prompt
            if name:
                if content:
                    self.environment[name] = content
                else:
                    self.environment.pop(name, None)

    @_command('CALL')
    def execute_call(self, cmd: SynCommand, std: IO, *_):
        """
        Emulate CALL. An argument that starts with a colon names a label in the current script,
        which is run in a new emulator on the same parser; `InvalidLabel` is raised when the label
        is unknown. Any other argument is taken as the path of a batch file. If the state cannot
        provide that file, the command is yielded and nothing else happens; otherwise the file is
        run in a new emulator that uses the current environment dictionary.

        With the `skip_call` option the callee is executed without yielding its trace.
        """
        target = cmd.argument_string
        before, separator, label = target.partition(':')
        if separator and not before:
            labels = self.parser.lexer.labels
            try:
                offset = labels[label.upper()]
            except KeyError as KE:
                raise InvalidLabel(label) from KE
            emu = self.spawn(self.parser, std=std)
        else:
            path = target.strip()
            code = self.state.ingest_file(path)
            if code is None:
                yield cmd
                return
            offset = 0
            state = self.clone_state(environment=self.state.environment, filename=path)
            emu = self.spawn(code, std=std, state=state)
        if not self.cfg.skip_call:
            yield from emu.trace(offset, called=True)
        else:
            emu.execute(called=True)

    @_command('SETLOCAL')
    def execute_setlocal(self, cmd: SynCommand, *_):
        yield cmd
        setting = cmd.argument_string.strip().upper()
        delay = {
            'DISABLEDELAYEDEXPANSION': False,
            'ENABLEDELAYEDEXPANSION' : True,
        }.get(setting, self.state.delayexpand)
        cmdxt = {
            'DISABLEEXTENSIONS': False,
            'ENABLEEXTENSIONS' : True,
        }.get(setting, self.state.cmdextended)
        self.state.delayexpand_stack.append(delay)
        self.state.cmdextended_stack.append(cmdxt)
        self.state.environment_stack.append(dict(self.environment))

    @_command('ENDLOCAL')
    def execute_endlocal(self, cmd: SynCommand, *_):
        yield cmd
        if len(self.state.environment_stack) > 1:
            self.state.environment_stack.pop()
            self.state.delayexpand_stack.pop()

    @_command('GOTO')
    def execute_goto(self, cmd: SynCommand, std: IO, *_):
        """
        Emulate GOTO. With the `skip_goto` option the command is only yielded. Otherwise the first
        word of the label argument selects the target:

        - `:EOF` (with the colon) raises `Exit` with the current exit code.
        - A label that is contained in the set of block labels yields an error about an infinite
          loop.
        - Any other label raises `Goto`.

        When no label is given, or the label is empty, an error message is written to standard
        error and `AbortExecution` is raised.
        """
        if self.cfg.skip_goto:
            yield cmd
            return
        it = iter(cmd.args)
        mark = False
        label = None
        for token in it:
            if not isinstance(token, Ctrl):
                label = token
                break
            if token == Ctrl.Label:
                mark = True
                # The label is whatever follows the colon, which may be nothing at all.
                label = next(it, '')
                break
        # An empty label is treated like a missing one instead of failing on the unpacking.
        words = label.split(maxsplit=1) if label is not None else []
        if not words:
            std.e.write('No batch label specified to GOTO command.\r\n')
            raise AbortExecution
        label = words[0]
        key = label.upper()
        if mark and key == 'EOF':
            raise Exit(int(self.state.ec), False)
        if key not in self.block_labels:
            raise Goto(label)
        else:
            yield Error(F'Infinite Loop detected for label {key}')

    @_command('EXIT')
    def execute_exit(self, cmd: SynCommand, *_):
        it = iter(cmd.args)
        exit = True
        token = 0
        for arg in it:
            if arg.upper() == '/B':
                exit = False
                continue
            token = arg
            break
        try:
            code = int(token)
        except ValueError:
            code = 0
        yield cmd
        if self.cfg.skip_exit:
            return
        raise Exit(code, exit)

    @_command('CHDIR')
    @_command('CD')
    def execute_chdir(self, cmd: SynCommand, *_):
        yield cmd
        self.state.cwd = cmd.argument_string.strip()

    @_command('PUSHD')
    def execute_pushd(self, cmd: SynCommand, *_):
        yield cmd
        self.state.dirstack.append(self.state.cwd)
        self.execute_chdir(cmd)

    @_command('POPD')
    def execute_popd(self, cmd: SynCommand, *_):
        yield cmd
        try:
            self.state.cwd = self.state.dirstack.pop()
        except IndexError:
            pass

    @_command('ECHO')
    def execute_echo(self, cmd: SynCommand, std: IO, in_group: bool):
        cmdl = cmd.argument_string
        mode = cmdl.strip().lower()
        current_state = self.state.echo
        if mode == 'on':
            if self.cfg.show_nops or current_state is False:
                yield cmd
            self.state.echo = True
            return
        if mode == 'off':
            if self.cfg.show_nops or current_state is True:
                yield cmd
            self.state.echo = False
            return
        yield cmd
        if mode:
            if in_group and not cmdl.endswith(' '):
                cmdl += ' '
            std.o.write(F'{cmdl}\r\n')
        else:
            mode = 'on' if self.state.echo else 'off'
            std.o.write(F'ECHO is {mode}.\r\n')

    @_command('CLS')
    def execute_cls(self, cmd: SynCommand, *_):
        yield cmd

    @_command('ERASE')
    @_command('DEL')
    def execute_del(self, cmd: SynCommand, std: IO, *_):
        """
        Emulate DEL / ERASE: remove every file of the emulated file system that matches one of the
        given patterns and return 0. Leading switches are parsed first. When no pattern remains
        after the switches, a syntax error is yielded instead of the command and 1 is returned.

        With `/P`, a confirmation is read from standard input before a file is removed;
        `InputLocked` is raised when no complete line is available. The switches `/F`, `/S`, `/Q`
        and `/A` are accepted but have no effect.
        """
        flags = {}
        paths = []
        it = iter(cmd.args)
        for arg in it:
            if not arg.startswith('/') or len(arg) < 2:
                # The first argument that is not a switch starts the list of patterns.
                paths.append(arg)
                break
            flag = arg.upper()
            if flag[:3] == '/A:':
                flags['A'] = flag[3:]
            else:
                flags[flag[1]] = True
        paths.extend(it)
        if not paths:
            yield Error('The syntax of the command is incorrect')
            return 1
        yield cmd
        _P = 'P' in flags # Prompts for confirmation before deleting each file.
        state = self.state
        cwd = state.cwd
        for pattern in paths:
            # The file system is copied because files are removed during the iteration.
            for path in list(state.file_system):
                # NOTE(review): execute_find_or_findstr calls winfnmatch(path, pattern, cwd)
                # with the first two arguments swapped; confirm which order the helper expects.
                if not winfnmatch(pattern, path, cwd):
                    continue
                # NOTE(review): the prompt is keyed on the pattern and not on the matched path,
                # so a wildcard pattern presumably never prompts; confirm the intent.
                if _P and state.exists_file(pattern):
                    std.o.write(F'{pattern}, Delete (Y/N)? ')
                    decision = None
                    while decision not in ('y', 'n'):
                        confirmation = std.i.readline()
                        if not confirmation.endswith('\n'):
                            raise InputLocked
                        decision = confirmation[:1].lower()
                    if decision == 'n':
                        continue
                state.remove_file(path)
        return 0

    @_command('START')
    def execute_start(self, cmd: SynCommand, std: IO, *_):
        """
        Emulate START. After the command was yielded, the fragments that follow the verb are
        scanned for an optional quoted title, switches and the program to start. If the program
        is a file that the state can provide, it is traced as a batch script in a new emulator
        whose command line is made of the remaining fragments.

        - `/D` sets the working directory of the new emulator to the value that follows it.
        - `/I` passes `None` as the environment of the new state instead of a copy.
        - `/NODE`, `/AFFINITY` and `/MACHINE` consume the value that follows them.
        - All other switches are ignored.
        """
        yield cmd
        it = iter(cmd.ast.fragments)
        it = itertools.islice(it, cmd.argument_offset, None)

        def parameter():
            # The value of a switch is the next fragment that is not just whitespace. Running
            # out of fragments yields None rather than a StopIteration, which would surface
            # as a RuntimeError inside this generator.
            for fragment in it:
                if not fragment.isspace():
                    return fragment
            return None

        title = None
        start = None
        cwd = self.state.cwd
        env = ...
        for arg in it:
            if title is None:
                if '"' not in arg:
                    title = ''
                else:
                    title = unquote(arg)
                    continue
            if arg.isspace():
                continue
            if not arg.startswith('/'):
                start = unquote(arg)
                break
            if (flag := arg.upper()) in ('/NODE', '/AFFINITY', '/MACHINE'):
                parameter()
            elif flag == '/D':
                if (value := parameter()) is not None:
                    cwd = unquote(value)
            elif flag == '/I':
                env = None
        if start and (batch := self.state.ingest_file(start)):
            state = self.clone_state(environment=env)
            state.cwd = cwd
            state.command_line = _fuse(it).strip()
            shell = self.spawn(batch, state, std)
            yield from shell.trace()

    @_command('CMD')
    def execute_cmd(self, cmd: SynCommand, std: IO, *_):
        """
        Emulate CMD. After the command was yielded, the switches are parsed up to `/C`, `/K` or
        `/R`; everything after that switch is the command line, which is traced in a new emulator
        with a cloned state. Without such a switch, 0 is returned and nothing is executed.

        - `/Q` turns echo off in the new state.
        - `/U` selects the codec `utf-16le` instead of `cp1252`.
        - `/E:` and `/V:` set command extensions and delayed expansion.
        - `/S` forces the removal of the outer quotes of the command line. The quotes are also
          removed when the command line does not contain exactly two quotes, when the quoted text
          contains one of the characters `&<>()@^|`, or when it contains no whitespace.
        """
        yield cmd
        it = iter(cmd.ast.fragments)
        command = None
        quiet = False
        strip = False
        codec = 'cp1252'
        delayexpand = None
        cmdextended = None

        for arg in it:
            if arg.isspace() or not arg.startswith('/'):
                continue
            name, _, flag = arg[1:].partition(':')
            flag = flag.upper()
            name = name.upper()
            # This has to be a tuple: a substring test against 'CKR' would also accept the
            # empty string and combinations like 'CK'.
            if name in ('C', 'K', 'R'):
                command = _fuse(it)
                break
            elif name == 'Q':
                quiet = True
            elif name == 'S':
                strip = True
            elif name == 'U':
                codec = 'utf-16le'
            elif name == 'E':
                cmdextended = _onoff(flag)
            elif name == 'V':
                delayexpand = _onoff(flag)
        else:
            return 0

        if (stripped := re.search('^\\s*"(.*)"', command)) and (strip
            or command.count('"') != 2
            or re.search('[&<>()@^|]', stripped[1])
            or re.search('\\s', stripped[1]) is None
        ):
            command = stripped[1]

        state = self.clone_state(delayexpand=delayexpand, cmdextended=cmdextended)
        state.codec = codec
        state.echo = not quiet
        shell = self.spawn(command, state, std)
        yield from shell.trace()

    @_command('ARP')
    @_command('AT')
    @_command('ATBROKER')
    @_command('BGINFO')
    @_command('BITSADMIN')
    @_command('CERTUTIL')
    @_command('CLIP')
    @_command('CMSTP')
    @_command('COMPACT')
    @_command('CONTROL')
    @_command('CSCRIPT')
    @_command('CURL')
    @_command('DEFRAG')
    @_command('DISKSHADOW')
    @_command('ESENTUTL')
    @_command('EXPAND')
    @_command('EXPLORER')
    @_command('EXTRAC32')
    @_command('FODHELPER')
    @_command('FORFILES')
    @_command('FTP')
    @_command('HOSTNAME')
    @_command('HOSTNAME')
    @_command('INSTALLUTIL')
    @_command('IPCONFIG')
    @_command('LOGOFF')
    @_command('MAKECAB')
    @_command('MAVINJECT')
    @_command('MOUNTVOL')
    @_command('MSBUILD')
    @_command('MSHTA')
    @_command('MSIEXEC')
    @_command('MSTSC')
    @_command('NET')
    @_command('NET1')
    @_command('NETSH')
    @_command('NSLOOKUP')
    @_command('ODBCCONF')
    @_command('PATHPING')
    @_command('PING')
    @_command('POWERSHELL')
    @_command('PRESENTATIONHOST')
    @_command('PWSH')
    @_command('REG')
    @_command('REGSVR32')
    @_command('ROUTE')
    @_command('RUNDLL32')
    @_command('SCP')
    @_command('SDCLT')
    @_command('SETX')
    @_command('SFTP')
    @_command('SHUTDOWN')
    @_command('SSH')
    @_command('SUBST')
    @_command('SYNCAPPVPUBLISHINGSERVER')
    @_command('SYSTEMINFO')
    @_command('TAR')
    @_command('TELNET')
    @_command('TFTP')
    @_command('TIMEOUT')
    @_command('TRACERT')
    @_command('VSSADMIN')
    @_command('WBADMIN')
    @_command('WHERE')
    @_command('WHOAMI')
    @_command('WINRM')
    @_command('WINRS')
    @_command('WSCRIPT')
    def execute_unimplemented_program(self, cmd: SynCommand, *_):
        yield cmd
        return 0

    @_command('CLS')
    def execute_unimplemented_command_unmodified_ec(self, cmd: SynCommand, *_):
        yield cmd

    @_command('ASSOC')
    @_command('ATTRIB')
    @_command('BCDEDIT')
    @_command('BREAK')
    @_command('CACLS')
    @_command('CHCP')
    @_command('CHKDSK')
    @_command('CHKNTFS')
    @_command('COLOR')
    @_command('COMP')
    @_command('COMPACT')
    @_command('CONVERT')
    @_command('COPY')
    @_command('DATE')
    @_command('DIR')
    @_command('DISKPART')
    @_command('DOSKEY')
    @_command('DRIVERQUERY')
    @_command('FC')
    @_command('FORMAT')
    @_command('FSUTIL')
    @_command('FTYPE')
    @_command('GPRESULT')
    @_command('ICACLS')
    @_command('LABEL')
    @_command('MD')
    @_command('MKDIR')
    @_command('MKLINK')
    @_command('MODE')
    @_command('MORE')
    @_command('MOVE')
    @_command('OPENFILES')
    @_command('PATH')
    @_command('PAUSE')
    @_command('PRINT')
    @_command('PROMPT')
    @_command('RD')
    @_command('RECOVER')
    @_command('REN')
    @_command('RENAME')
    @_command('REPLACE')
    @_command('RMDIR')
    @_command('ROBOCOPY')
    @_command('SC')
    @_command('SCHTASKS')
    @_command('SHIFT')
    @_command('SHUTDOWN')
    @_command('SORT')
    @_command('SUBST')
    @_command('SYSTEMINFO')
    @_command('TASKKILL')
    @_command('TASKLIST')
    @_command('TIME')
    @_command('TITLE')
    @_command('TREE')
    @_command('TYPE')
    @_command('VER')
    @_command('VERIFY')
    @_command('VOL')
    @_command('WMIC')
    @_command('XCOPY')
    def execute_unimplemented_command(self, cmd: SynCommand, *_):
        yield cmd
        return 0

    @_command('REM')
    def execute_rem(self, cmd: SynCommand, *_):
        if self.cfg.show_comments:
            yield cmd

    @_command('HELP')
    def execute_help(self, cmd: SynCommand, std: IO, *_):
        """
        Emulate HELP: yield the command, write the `HELP` entry of `HelpOutput` to standard
        output and return 0.
        """
        yield cmd
        text = HelpOutput['HELP']
        std.o.write(text)
        return 0

    def execute_command(self, cmd: SynCommand, std: IO, in_group: bool):
        """
        Dispatch a single command to its registered handler. The steps are, in this order:

        1. The handler is looked up by the upper-cased verb. If there is none and the verb ends
           in an extension from `PATHEXT`, the lookup is repeated without that extension.
        2. Without a handler, the command is only yielded; depending on the verb, the exit code
           and the junk marker of the command are updated before that.
        3. The redirects of the command for the streams 0, 1 and 2 are applied to `std`.
        4. If `/?` is among the arguments, the help text for the verb is written to standard
           output, the exit code is set to 0, and the handler is not called.
        5. The handler is called. When it returns a generator, everything it yields is passed on
           and its return value becomes the result.
        6. The contents of output streams that were redirected to files are stored in the
           emulated file system.
        7. A result other than `None` becomes the new exit code.
        """
        verb = cmd.verb.upper().strip()
        handler = self._command.handlers.get(verb)

        if handler is None:
            # Retry without the file extension when the extension is listed in PATHEXT.
            base, ext = ntpath.splitext(verb)
            handler = None
            if any(ext == pe.upper() for pe in self.state.envar('PATHEXT', '').split(';')):
                handler = self._command.handlers.get(base)

        if handler is None:
            if self.state.exists_file(verb):
                self.state.ec = 0
            elif not indicators.winfpath.value.fullmatch(verb):
                # The verb is neither a known file nor matched by the path indicator.
                if '\uFFFD' in verb or not verb.isprintable():
                    self.state.ec = 9009
                    cmd.junk = True
                else:
                    cmd.junk = not self.cfg.show_junk
            yield cmd
            return

        # Maps a redirected stream number to the path of the file that receives its output.
        paths: dict[int, str] = {}

        for src, r in cmd.ast.redirects.items():
            # Only the three standard streams are handled; stream 0 must be an input redirect
            # and the other two must be output redirects.
            if not 0 <= src <= 2 or (src == 0) != r.is_input:
                continue
            if isinstance((target := r.target), str):
                if target.upper() == 'NUL':
                    std[src] = DevNull()
                else:
                    data = self.state.ingest_file(target)
                    if src == 0:
                        if data is None:
                            yield ErrorCannotFindFile
                            return
                        std.i = StringIO(data)
                    else:
                        if r.is_out_append:
                            # Start from the existing content and continue writing at its end.
                            buffer = StringIO(data)
                            buffer.seek(0, 2)
                        else:
                            buffer = StringIO()
                        std[src] = buffer
                        paths[src] = target
            elif src == 1 and target == 2:
                std.o = std.e
            elif src == 2 and target == 1:
                std.e = std.o

        if '/?' in cmd.args:
            std.o.write(HelpOutput[verb])
            self.state.ec = 0
            return

        if (result := handler(self, cmd, std, in_group)) is None:
            pass
        elif not isinstance(result, (int, ErrorZero)):
            # The handler is a generator: pass on what it yields and keep its return value.
            result = (yield from result)

        # Store what was written to the redirected output streams.
        for k, path in paths.items():
            self.state.create_file(path, std[k].getvalue())

        if result is not None:
            self.state.ec = result

    @_node(AstPipeline)
    def trace_pipeline(self, pipeline: AstPipeline, std: IO, in_group: bool):
        """
        Trace every part of a pipeline in order. The output of each part is captured in a fresh
        buffer and becomes the input of the next part; the final part writes to the output of the
        given `std`. A pipeline with more than one part is synthesized and emitted first.
        """
        parts = pipeline.parts
        final = len(parts)
        streams = IO(*std)
        if final > 1:
            yield synthesize(pipeline)
        for position, part in enumerate(parts, 1):
            if position > 1:
                # Feed the previous part's captured output into this part, from the start.
                streams.i = streams.o
                streams.i.seek(0)
            streams.o = std.o if position == final else StringIO()
            if not isinstance(part, AstGroup):
                command = synthesize(self.expand_ast_node(part))
                events = self.execute_command(command, streams, in_group)
            else:
                events = self.trace_group(part, streams, in_group)
            yield from events

    @_node(AstSequence)
    def trace_sequence(self, sequence: AstSequence, std: IO, in_group: bool):
        """
        Trace the head of a sequence and then each conditional statement in its tail. A statement
        conditioned on failure is skipped while the exit code is zero, and a statement conditioned
        on success is skipped while the exit code is nonzero.
        """
        yield from self.trace_statement(sequence.head, std, in_group)
        for conditional in sequence.tail:
            if conditional.condition == AstCondition.Failure and not self.state.ec:
                continue
            if conditional.condition == AstCondition.Success and self.state.ec:
                continue
            yield from self.trace_statement(conditional.statement, std, in_group)

    @_node(AstIf)
    def trace_if(self, _if: AstIf, std: IO, in_group: bool):
        """
        Emit the IF statement, evaluate its condition on the variable-expanded node and trace the
        branch that was selected. The set of block labels is cleared before evaluation.
        """
        yield synthesize(_if)
        _if = self.expand_ast_node(_if)
        self.block_labels.clear()

        variant = _if.variant
        if variant == AstIfVariant.ErrorLevel:
            outcome = _if.var_int <= self.state.ec
        elif variant == AstIfVariant.CmdExtVersion:
            outcome = _if.var_int <= self.state.extensions_version
        elif variant == AstIfVariant.Exist:
            outcome = self.state.exists_file(_if.var_str)
        elif variant == AstIfVariant.Defined:
            outcome = _if.var_str.upper() in self.state.environment
        else:
            lhs, rhs, cmp = _if.lhs, _if.rhs, _if.cmp
            assert lhs is not None
            assert rhs is not None
            if cmp == AstIfCmp.STR:
                if _if.casefold:
                    # Only string operands can be case-folded; other operands stay untouched.
                    lhs = lhs.casefold() if isinstance(lhs, str) else lhs
                    rhs = rhs.casefold() if isinstance(rhs, str) else rhs
                outcome = lhs == rhs
            else:
                comparisons = {
                    AstIfCmp.GTR: lambda a, b: a > b,
                    AstIfCmp.GEQ: lambda a, b: a >= b,
                    AstIfCmp.NEQ: lambda a, b: a != b,
                    AstIfCmp.EQU: lambda a, b: a == b,
                    AstIfCmp.LSS: lambda a, b: a < b,
                    AstIfCmp.LEQ: lambda a, b: a <= b,
                }
                compare = comparisons.get(cmp)
                if compare is None:
                    raise RuntimeError(cmp)
                outcome = compare(lhs, rhs)
        if _if.negated:
            outcome = not outcome

        if outcome:
            yield from self.trace_sequence(_if.then_do, std, in_group)
        elif (alternative := _if.else_do):
            yield from self.trace_sequence(alternative, std, in_group)

    @_node(AstFor)
    def trace_for(self, _for: AstFor, std: IO, in_group: bool):
        """
        Trace a FOR loop.

        A new set of loop variables is opened on the state, the loop body is traced once per
        iteration with the loop variables assigned, and the set is closed again at the end. For
        the file parsing variant, the input lines come from a spawned command, from the literal
        specification, or from the files in the virtual file system matching the specification;
        each line is tokenized according to the loop options.
        """
        state = self.state
        cwd = state.cwd
        vars = state.new_forloop()
        body = _for.body
        name = _for.variable
        vars[name] = ''

        if (
            self.contains_for_variable(body, vars)
                or _for.variant != AstForVariant.NumericLoop
                or len(_for.spec) != 1
        ):
            yield synthesize(_for)

        if _for.variant == AstForVariant.FileParsing:
            if _for.mode == AstForParserMode.Command:
                emulator = self.spawn(_for.specline, self.clone_state(filename=state.name))
                yield from emulator.trace()
                lines = emulator.std.o.getvalue().splitlines()
            elif _for.mode == AstForParserMode.Literal:
                lines = _for.spec
            else:
                def lines_from_files():
                    fs = state.file_system
                    for name in _for.spec:
                        for path, content in fs.items():
                            if not winfnmatch(path, name, cwd):
                                continue
                            yield from content.splitlines(False)
                lines = lines_from_files()
            opt = _for.options
            tokens = sorted(opt.tokens)
            # An empty delimiter set (as in "delims=") would compile to the invalid pattern '[]+'.
            # Without delimiters there is nothing to split at, so the line stays in one piece.
            if opt.delims:
                split = re.compile('[{}]+'.format(re.escape(opt.delims)))
            else:
                split = None
            count = tokens[-1] + 1 if tokens else 0
            first_variable = ord(name)
            if opt.asterisk:
                # The extra index refers to the unsplit remainder of the line.
                tokens.append(count)
            for n, line in enumerate(lines):
                if n < opt.skip:
                    continue
                if opt.comment and line.startswith(opt.comment):
                    continue
                if count and split is not None:
                    tokenized = split.split(line, maxsplit=count)
                else:
                    tokenized = (line,)
                for k, tok in enumerate(tokens):
                    # Tokens are assigned to consecutive variable letters, starting at the
                    # loop variable.
                    name = chr(first_variable + k)
                    if not name.isalpha():
                        raise EmulatorException('Ran out of variables in FOR-Loop.')
                    try:
                        vars[name] = tokenized[tok]
                    except IndexError:
                        vars[name] = ''
                yield from self.trace_sequence(body, std, in_group)
        else:
            for entry in _for.spec:
                vars[name] = entry
                yield from self.trace_sequence(body, std, in_group)
        state.end_forloop()

    @_node(AstGroup)
    def trace_group(self, group: AstGroup, std: IO, in_group: bool):
        """
        Trace each fragment of a group with the group flag set, then emit the synthesized group
        itself.
        """
        fragments = group.fragments
        for fragment in fragments:
            yield from self.trace_sequence(fragment, std, True)
        yield synthesize(group)

    @_node(AstLabel)
    def trace_label(self, label: AstLabel, *_):
        """
        Handle a label line. Labels that are comments are emitted only when comments are shown.
        Proper labels are emitted when labels are shown and are always recorded, in upper case,
        in the set of block labels.
        """
        if label.comment:
            if self.cfg.show_comments:
                yield synthesize(label)
            return
        if self.cfg.show_labels:
            yield synthesize(label)
        self.block_labels.add(label.label.upper())

    def trace_statement(self, statement: AstStatement, std: IO, in_group: bool):
        """
        Trace a single statement by delegating to the handler registered for its exact class.
        Raises a RuntimeError carrying the statement when no handler is registered.
        """
        registry = self._node.handlers
        try:
            handler = registry[statement.__class__]
        except KeyError:
            raise RuntimeError(statement)
        else:
            yield from handler(self, statement, std, in_group)

    def emulate_commands(self, allow_junk=False):
        """
        Run the trace and yield the string form of every synthesized command. Commands flagged
        as junk are left out unless `allow_junk` is set.
        """
        for event in self.trace():
            if isinstance(event, SynCommand) and (allow_junk or not event.junk):
                yield str(event)

    def emulate_to_depth(self, depth: int = 0):
        """
        Run the trace and yield the string form of every synthesized node whose AST depth does
        not exceed `depth`.
        """
        for event in self.trace():
            if isinstance(event, SynNodeBase) and event.ast.depth <= depth:
                yield str(event)

    def emulate(self, offset: int = 0):
        """
        Run the trace from the given offset and yield the string form of synthesized nodes,
        filtering out nodes that are related by descent to the node that was yielded last.
        """
        # The AST node of the most recently yielded item.
        last: AstNode | None = None
        # The AST node of the most recently seen junk command.
        junk: AstNode | None = None
        for syn in self.trace(offset):
            if not isinstance(syn, SynNodeBase):
                continue
            ast = syn.ast
            if isinstance(syn, SynCommand) and syn.junk:
                # Junk commands are never yielded, only remembered.
                junk = ast
                continue
            if junk is not None:
                # Skip a node that contains the last junk command, unless the last yielded node
                # is also one of its descendants.
                if junk.is_descendant_of(ast):
                    if not last or not last.is_descendant_of(ast):
                        continue
            if last is not None:
                if ast.is_descendant_of(last):
                    # we already synthesized a parent construct, like a FOR loop or IF block
                    continue
                if last.is_descendant_of(ast):
                    # we synthesized a command and no longer need to synthesize an AST node that
                    # wraps it, like a group
                    continue
            if isinstance(ast, AstPipeline):
                # A pipeline with a single part is skipped.
                if len(ast.parts) == 1:
                    continue
            if last is ast:
                raise RuntimeError('Emulator attempted to synthesize the same command twice.')
            last = ast
            yield str(syn)

    def execute(self, offset: int = 0, called: bool = False):
        """
        Run the trace to completion for its side effects, discarding everything it yields.
        """
        events = self.trace(offset, called=called)
        for _event in events:
            continue

    def trace(self, offset: int = 0, called: bool = False):
        """
        Parse and trace the script starting at `offset`, yielding everything the individual
        statements yield. A Goto restarts parsing at the offset of the target label and raises
        InvalidLabel when the label is unknown. An Exit stores its code as the exit code and
        ends the trace; it is re-raised when it is a full exit and this trace was called from
        another one. An AbortExecution sets the exit code to 1 and ends the trace.
        """
        script_name = self.state.name
        if script_name:
            # Make the script itself available in the virtual file system.
            self.state.create_file(script_name, self.parser.lexer.text)
        end = len(self.parser.lexer.code)
        labels = self.parser.lexer.labels

        while offset < end:
            try:
                for node in self.parser.parse(offset):
                    if isinstance(node, AstError):
                        yield Error(node.error)
                    else:
                        yield from self.trace_sequence(node, self.std, False)
            except Goto as jump:
                try:
                    offset = labels[jump.label.upper()]
                except KeyError:
                    raise InvalidLabel(jump.label) from jump
            except Exit as request:
                self.state.ec = request.code
                if request.exit and called:
                    raise
                break
            except AbortExecution:
                self.state.ec = 1
                break
            else:
                # Parsing reached the end without any jump.
                break

Instance variables

var state
Expand source code Browse git
@property
def state(self):
    """
    The state object held by the parser of this emulator.
    """
    return self.parser.state
var environment
Expand source code Browse git
@property
def environment(self):
    """
    Shortcut for the environment of the current state.
    """
    return self.state.environment
var delayexpand
Expand source code Browse git
@property
def delayexpand(self):
    """
    Shortcut for the delayed expansion setting of the current state.
    """
    return self.state.delayexpand

Methods

def spawn(self, data, state=None, std=None)
Expand source code Browse git
def spawn(self, data: str | buf | BatchParser, state: BatchState | None = None, std: IO | None = None):
    """
    Create a new emulator for the given data which shares the configuration of this one. The
    optional state and standard streams are passed through unchanged.
    """
    config = self.cfg
    return BatchEmulator(data, state, config, std)
def clone_state(self, delayexpand=None, cmdextended=None, environment=Ellipsis, filename=None)
Expand source code Browse git
def clone_state(
    self,
    delayexpand: bool | None = None,
    cmdextended: bool | None = None,
    environment: dict | None | ellipsis = ...,
    filename: str | None = None,
):
    """
    Build a new state derived from the current one. File system, user name, host name, time and
    working directory are taken over from the current state. Delayed expansion defaults to off,
    the extension setting defaults to the current one, and the environment defaults to a copy
    of the current environment unless a value (including None) is given explicitly.
    """
    current = self.state
    delay = False if delayexpand is None else delayexpand
    extended = current.cmdextended if cmdextended is None else cmdextended
    variables = dict(current.environment) if environment is ... else environment
    return BatchState(
        delay,
        extended,
        environment=variables,
        file_system=current.file_system,
        username=current.username,
        hostname=current.hostname,
        now=current.now,
        cwd=current.cwd,
        filename=filename,
    )
def get_for_variable_regex(self, vars)
Expand source code Browse git
def get_for_variable_regex(self, vars: Iterable[str]):
    """
    Compile a pattern that matches a reference to any of the given FOR loop variables. Group 1
    captures the optional tilde modifiers, group 3 captures the variable name.

    The variable names are escaped before they are placed in the character class: an unescaped
    name such as `^`, `]` or `-` would otherwise negate, terminate or extend the class.
    """
    # NOTE(review): inside this raw string, "\\$\\w+" is a literal backslash followed by an
    # end-of-string anchor, so the optional second group can only ever match the empty string.
    # It is reproduced unchanged here - confirm what it was meant to match.
    names = ''.join(re.escape(name) for name in vars)
    return re.compile(RF'%((?:~[fdpnxsatz]*)?)((?:\\$\\w+)?)([{names}])')
def expand_delayed_variables(self, block)
Expand source code Browse git
def expand_delayed_variables(self, block: str):
    """
    Replace every `!expression!` in the block with the result of parsing the expression as an
    environment variable. When the variable is missing, the replacement is whatever follows
    the first colon of the expression, which is empty if there is no colon.
    """
    resolve = self.parser.lexer.parse_env_variable

    def substitute(match: re.Match[str]):
        expression = match.group(1)
        try:
            return resolve(expression)
        except MissingVariable:
            return expression.partition(':')[2]

    return re.sub(r'!([^!\n]*)!', substitute, block)
def expand_forloop_variables(self, block, vars)
Expand source code Browse git
def expand_forloop_variables(self, block: str, vars: dict[str, str] | None):
    """
    Replace every reference to a FOR loop variable in the block with the value of that
    variable. The block is returned unchanged when there are no loop variables.
    """
    if not vars:
        return block
    values = vars

    def substitute(match: re.Match[str]):
        # The modifier flags are parsed from the match but do not affect the replacement.
        modifiers = ArgVarFlags.Empty
        for character in match[1]:
            modifiers |= ArgVarFlags.FromToken(ord(character))
        return values[match[3]]

    pattern = self.get_for_variable_regex(vars)
    return pattern.sub(substitute, block)
def contains_for_variable(self, ast, vars)
Expand source code Browse git
def contains_for_variable(self, ast: AstNode, vars: Iterable[str]):
    """
    Check whether any string inside the given AST node refers to one of the given FOR loop
    variables. Lists, dictionary values and the fields of nested AST nodes are searched
    recursively; the `parent` field of nodes is not followed and enumeration values never
    count as a reference.
    """
    search = self.get_for_variable_regex(vars).search

    def visit(item):
        # Enumerations are tested before strings so that string-valued members are ignored.
        if isinstance(item, Enum):
            return False
        if isinstance(item, str):
            return search(item) is not None
        if isinstance(item, list):
            children = item
        elif isinstance(item, dict):
            children = item.values()
        elif isinstance(item, AstNode):
            children = (getattr(item, tf.name) for tf in fields(item) if tf.name != 'parent')
        else:
            return False
        return any(visit(child) for child in children)

    return visit(ast) # type:ignore
def expand_ast_node(self, ast)
Expand source code Browse git
def expand_ast_node(self, ast: _T) -> _T:
    """
    Return a copy of the AST node in which all strings have delayed variables (if delayed
    expansion is enabled) and FOR loop variables expanded. The input itself is returned when
    there are no loop variables and delayed expansion is off. The `parent` field of nodes is
    carried over without being expanded.
    """
    delayexpand = self.delayexpand
    variables = self.state.for_loop_variables
    if not (variables or delayexpand):
        return ast

    def rebuild(item):
        # Enumerations are tested before strings so that string-valued members pass through.
        if isinstance(item, Enum):
            return item
        if isinstance(item, str):
            text = self.expand_delayed_variables(item) if delayexpand else item
            return self.expand_forloop_variables(text, variables)
        if isinstance(item, list):
            return [rebuild(child) for child in item]
        if isinstance(item, dict):
            return {key: rebuild(child) for key, child in item.items()}
        if not isinstance(item, AstNode):
            return item
        arguments = {}
        for tf in fields(item):
            value = getattr(item, tf.name)
            arguments[tf.name] = value if tf.name == 'parent' else rebuild(value)
        return item.__class__(**arguments)

    return rebuild(ast) # type:ignore
def execute_find_or_findstr(self, cmd, std, findstr)
Expand source code Browse git
def execute_find_or_findstr(self, cmd: SynCommand, std: IO, findstr: bool):
    """
    Shared implementation of FIND and FINDSTR.

    The leading switches are parsed first; the first argument that is not a switch provides
    the search strings and all remaining arguments are files or wildcard patterns. Matching
    lines are written to standard output. The generator emits the command and returns 0 when
    something was found and 1 otherwise, including all error cases.

    Every input line is tested against all search strings and written at most once; the /I
    switch makes the comparison case-insensitive; the offset printed for /O is counted per
    input on the unmodified lines; with /M a file name is written once per file.
    """
    needles = []
    # The Ellipsis entry stands for standard input.
    paths: list[str | ellipsis] = [...]
    flags = {}
    it = iter(cmd.args)
    yield cmd

    for arg in it:
        if not arg.startswith('/'):
            # FIND requires its search string to be quoted.
            if not findstr and not arg.startswith('"'):
                return 1
            needles.extend(unquote(arg).split())
            break
        name, has_param, value = arg[1:].partition(':')
        name = name.upper()
        if name in ('OFF', 'OFFLINE'):
            continue
        elif len(name) > 1:
            return 1
        elif name == 'C':
            needles.append(unquote(value))
        elif name == 'F' and findstr:
            if (p := self.state.ingest_file(value)) is None:
                return 1
            paths.extend(p.splitlines(False))
        elif name == 'G' and findstr:
            if (n := self.state.ingest_file(value)) is None:
                return 1
            needles.extend(n.splitlines(False))
        elif has_param:
            flags[name] = value
        else:
            flags[name] = True

    valid_flags = 'VNI'
    if findstr:
        valid_flags += 'BELRSXMOPADQ'

    for v in flags:
        if v not in valid_flags:
            return 1

    prefix_filename = False
    state = self.state

    for arg in it:
        pattern = unquote(arg)
        if '*' in pattern or '?' in pattern:
            prefix_filename = True
            for path in state.file_system:
                if winfnmatch(path, pattern, state.cwd):
                    paths.append(path)
        else:
            paths.append(pattern)

    if len(paths) > 1:
        prefix_filename = True

    for n, needle in enumerate(needles):
        # FIND always searches literally; FINDSTR does so only with the /L switch.
        if not findstr or 'L' in flags:
            needle = re.escape(needle)
        if 'X' in flags:
            needle = F'^{needle}$'
        elif 'B' in flags:
            needle = F'^{needle}'
        elif 'E' in flags:
            needle = F'{needle}$'
        needles[n] = needle

    regex_flags = re.IGNORECASE if 'I' in flags else 0

    _V = 'V' in flags # noqa; Prints only lines that do not contain a match.
    _P = 'P' in flags # noqa; Skip files with non-printable characters.
    _O = 'O' in flags # noqa; Prints character offset before each matching line.
    _N = 'N' in flags # noqa; Prints the line number before each line that matches.
    _M = 'M' in flags # noqa; Prints only the filename if a file contains a match.

    nothing_found = True

    for path in paths:
        if path is (...):
            data = std.i.read()
        else:
            data = state.ingest_file(path)
        if data is None:
            return 1
        if _P and not re.fullmatch('[\\s!-~]+', data):
            continue
        if not needles:
            # Without any search string, nothing can be reported for this input.
            continue
        offset = 0
        for n, line in enumerate(data.splitlines(True), 1):
            start = offset
            offset += len(line)
            hit = None
            for needle in needles:
                hit = re.search(needle, line, regex_flags)
                if hit:
                    break
            if _V == bool(hit):
                continue
            nothing_found = False
            if _M:
                # Only the file name is reported, and only once per file.
                if path is not (...):
                    std.o.write(path)
                break
            if _O:
                o = start + (hit.start() if hit else 0)
                line = F'{o}:{line}'
            if _N:
                line = F'{n}:{line}'
            if prefix_filename:
                line = F'{path}:{line}'
            std.o.write(line)

    return int(nothing_found)
def execute_type(self, cmd, std, *_)
Expand source code Browse git
@_command('TYPE')
def execute_type(self, cmd: SynCommand, std: IO, *_):
    """
    Handler for TYPE: write the contents of the named file to standard output and return 0.
    When the file cannot be read, the file-not-found error is emitted and 1 is returned.
    """
    path = cmd.argument_string.strip()
    contents = self.state.ingest_file(path)
    if contents is not None:
        std.o.write(contents)
        return 0
    yield ErrorCannotFindFile
    return 1
def execute_find(self, cmd, std, *_)
Expand source code Browse git
@_command('FIND')
def execute_find(self, cmd: SynCommand, std: IO, *_):
    """
    Handler for FIND: delegates to the shared FIND/FINDSTR implementation with `findstr`
    disabled and returns the generator it produces.
    """
    return self.execute_find_or_findstr(cmd, std, findstr=False)
def execute_findstr(self, cmd, std, *_)
Expand source code Browse git
@_command('FINDSTR')
def execute_findstr(self, cmd: SynCommand, std: IO, *_):
    """
    Handler for FINDSTR: delegates to the shared FIND/FINDSTR implementation with `findstr`
    enabled and returns the generator it produces.
    """
    return self.execute_find_or_findstr(cmd, std, findstr=True)
def execute_set(self, cmd, std, *_)
Expand source code Browse git
@_command('SET')
def execute_set(self, cmd: SynCommand, std: IO, *_):
    """
    Handler for SET in its three forms: plain assignment, prompted assignment (/P) and
    arithmetic (/A).

    For /P, one line is read from standard input and becomes the variable content, while the
    text after the equals sign is written to standard output as the prompt. For /A, each
    comma-separated assignment is evaluated and stored, and the last value is written to
    standard output. Otherwise the variable is set to the given content, or removed when the
    content is empty.

    Raises EmulatorException for a SET without arguments or for an arithmetic expression with
    unexpected variable access, and InputLocked when /P cannot read a complete line.
    """
    if not (args := cmd.args):
        raise EmulatorException('Empty SET instruction')

    if cmd.verb.upper() != 'SET':
        raise RuntimeError

    # Since variables can be used in GOTO, a SET can be used to change the behavior of a GOTO.
    self.block_labels.clear()

    arithmetic = False
    quote_mode = False
    prompt = None

    it = iter(args)
    tk = next(it)

    if tk.upper() == '/P':
        if std.i.closed:
            prompt = ''
        elif not (prompt := std.i.readline()).endswith('\n'):
            raise InputLocked
        else:
            prompt = prompt.rstrip('\r\n')
        tk = next(it)
    else:
        cmd.junk = not self.cfg.show_sets

    yield cmd

    if tk.upper() == '/A':
        arithmetic = True
        try:
            tk = next(it)
        except StopIteration:
            tk = ''

    # From here on, args no longer contains the /P or /A switch.
    args = [tk, *it, *cmd.trailing_spaces]

    if arithmetic:
        def defang(s: str):
            # Encode characters that are not valid in the parsed expression as _<prefix><hex>_.
            def r(m: re.Match[str]):
                return F'_{prefix}{ord(m[0]):X}_'
            return re.sub(r'[^-\s()!~*/%+><&^|_\w]', r, s)
        def refang(s: str): # noqa
            # Inverse of defang.
            def r(m: re.Match[str]):
                return chr(int(m[1], 16))
            return re.sub(rf'_{prefix}([A-F0-9]+)_', r, s)
        prefix = F'{uuid.uuid4().time_mid:X}'
        namespace = {}
        translate = {}
        value = None
        if not (program := ''.join(args)):
            std.e.write('The syntax of the command is incorrect.\r\n')
            return ErrorZero.Val
        for assignment in program.split(','):
            assignment = assignment.strip()
            if not assignment:
                std.e.write('Missing operand.\r\n')
                return ErrorZero.Val
            # The hyphen is placed last in the character class so that it is a literal minus
            # and "-=" is recognized; written as "%-&" it would denote a character range.
            name, operator, definition = re.split(r'([*+^|/%&-]|<<|>>|)=', assignment, maxsplit=1)
            name = name.upper()
            # Batch octal literals (leading zero) are rewritten to the 0o notation.
            definition = re.sub(r'\b0([0-7]+)\b', r'0o\1', definition)
            if operator:
                # Compound assignment: NAME op= EXPR becomes NAME op (EXPR).
                definition = F'{name}{operator}({definition})'
            definition = defang(definition)
            expression = cautious_parse(definition)
            names = names_in_expression(expression)
            if names.stored or names.others:
                raise EmulatorException('Arithmetic SET had unexpected variable access.')
            for var in names.loaded:
                # Resolve each variable that the expression reads through its own name, not
                # through the name of the variable being assigned.
                original = refang(var).upper()
                translate[original] = var
                if var in namespace:
                    continue
                try:
                    namespace[var] = batchint(self.environment[original])
                except (KeyError, ValueError):
                    # Undefined or non-numeric variables count as zero.
                    namespace[var] = 0
            code = compile(expression, filename='[ast]', mode='eval')
            value = eval(code, namespace, {})
            self.environment[name] = str(value)
            namespace[defang(name)] = value
        if value is None:
            std.e.write('The syntax of the command is incorrect.')
            return
        else:
            std.o.write(F'{value!s}\r\n')
    else:
        try:
            eq = args.index(Ctrl.Equals)
        except ValueError:
            assignment = cmd.argument_string
            if assignment.startswith('"'):
                # Quoted set: take everything between the first and the last quote.
                quote_mode = True
                assignment, _, unquoted = assignment[1:].rpartition('"')
                assignment = assignment or unquoted
            else:
                assignment = ''.join(args)
            name, _, content = assignment.partition('=')
        else:
            with StringIO() as io:
                for k in range(eq + 1, len(args)):
                    io.write(args[k])
                content = io.getvalue()
                # The index eq refers to the local args list; cmd.args is shifted by one
                # position whenever a /P switch was consumed above.
                name = args[eq - 1] if eq else ''
        name = name.upper()
        trailing_caret, content = uncaret(content, quote_mode)
        if trailing_caret:
            content = content[:-1]
        if prompt is not None:
            if (qc := content.strip()).startswith('"'):
                _, _, qc = qc. partition('"') # noqa
                qc, _, r = qc.rpartition('"') # noqa
                content = qc or r
            # For /P, the text after the equals sign is the prompt shown to the user and the
            # line read from standard input becomes the variable content.
            std.o.write(content)
            content = prompt
        if name:
            if content:
                self.environment[name] = content
            else:
                self.environment.pop(name, None)
def execute_call(self, cmd, std, *_)
Expand source code Browse git
@_command('CALL')
def execute_call(self, cmd: SynCommand, std: IO, *_):
    """
    Handler for CALL. A target that starts with a colon is a label in the current script and
    is run in a new emulator on the same parser, starting at the label's offset. Any other
    target is treated as a batch file: it is read from the virtual file system and run in a
    new emulator that shares the current environment; if the file is unknown, the command is
    just emitted. With `skip_call` configured, the callee is executed without forwarding
    what it yields.

    Raises InvalidLabel when the label does not exist.
    """
    cmdl = cmd.argument_string
    empty, colon, label = cmdl.partition(':')
    if colon and not empty:
        try:
            offset = self.parser.lexer.labels[label.upper()]
        except KeyError as KE:
            raise InvalidLabel(label) from KE
        emu = self.spawn(self.parser, std=std)
    else:
        offset = 0
        path = cmdl.strip()
        code = self.state.ingest_file(path)
        if code is None:
            yield cmd
            return
        state = self.clone_state(environment=self.state.environment, filename=path)
        emu = self.spawn(code, std=std, state=state)
    if self.cfg.skip_call:
        # The offset has to be passed here as well: without it, a call to a label would run
        # the script from the very beginning instead of from the label.
        emu.execute(offset, called=True)
    else:
        yield from emu.trace(offset, called=True)
def execute_setlocal(self, cmd, *_)
Expand source code Browse git
@_command('SETLOCAL')
def execute_setlocal(self, cmd: SynCommand, *_):
    """
    Handler for SETLOCAL: push a new scope consisting of the delayed expansion setting, the
    command extension setting and a copy of the environment.

    Each whitespace-separated argument is evaluated on its own, so that several options can be
    given in one command (e.g. `SETLOCAL EnableExtensions EnableDelayedExpansion`). Settings
    that are not mentioned keep their current value; unknown arguments are ignored.
    """
    yield cmd
    delay = self.state.delayexpand
    cmdxt = self.state.cmdextended
    for setting in cmd.argument_string.upper().split():
        if setting == 'ENABLEDELAYEDEXPANSION':
            delay = True
        elif setting == 'DISABLEDELAYEDEXPANSION':
            delay = False
        elif setting == 'ENABLEEXTENSIONS':
            cmdxt = True
        elif setting == 'DISABLEEXTENSIONS':
            cmdxt = False
    self.state.delayexpand_stack.append(delay)
    self.state.cmdextended_stack.append(cmdxt)
    self.state.environment_stack.append(dict(self.environment))
def execute_endlocal(self, cmd, *_)
Expand source code Browse git
@_command('ENDLOCAL')
def execute_endlocal(self, cmd: SynCommand, *_):
    """
    Handler for ENDLOCAL: discard the innermost scope, unless only the outermost scope is
    left. All three stacks that SETLOCAL pushes to are popped, including the command extension
    setting.
    """
    yield cmd
    if len(self.state.environment_stack) > 1:
        self.state.environment_stack.pop()
        self.state.delayexpand_stack.pop()
        # Guarded separately so that the outermost extension setting is never removed.
        if len(self.state.cmdextended_stack) > 1:
            self.state.cmdextended_stack.pop()
def execute_goto(self, cmd, std, *_)
Expand source code Browse git
@_command('GOTO')
def execute_goto(self, cmd: SynCommand, std: IO, *_):
    """
    Handler for GOTO. With `skip_goto` configured the command is only emitted. Otherwise the
    target label is determined and a Goto is raised for it. `GOTO :EOF` raises an Exit with
    the current exit code instead. A jump to a label that is in the set of block labels is
    reported as an infinite loop and not performed.

    Raises AbortExecution when no label is given at all.
    """
    if self.cfg.skip_goto:
        yield cmd
        return
    it = iter(cmd.args)
    # Whether the label was introduced by a label marker token.
    mark = False
    for label in it:
        if not isinstance(label, Ctrl):
            break
        if label == Ctrl.Label:
            mark = True
            for label in it:
                break
            else:
                label = ''
            break
    else:
        std.e.write('No batch label specified to GOTO command.\r\n')
        raise AbortExecution
    # Only the first word is the label. An empty or blank target has no first word; it
    # becomes the empty label rather than failing to unpack.
    label = next(iter(label.split(maxsplit=1)), '')
    key = label.upper()
    if mark and key == 'EOF':
        raise Exit(int(self.state.ec), False)
    if key not in self.block_labels:
        raise Goto(label)
    else:
        yield Error(F'Infinite Loop detected for label {key}')
def execute_exit(self, cmd, *_)
Expand source code Browse git
@_command('EXIT')
def execute_exit(self, cmd: SynCommand, *_):
    """
    Handler for EXIT. Leading /B switches turn the exit into a batch-only exit; the first
    other argument is the exit code, which counts as 0 when it is not an integer. The command
    is emitted, and unless `skip_exit` is configured, an Exit with the code is raised.
    """
    terminate = True
    token = 0
    for arg in cmd.args:
        if arg.upper() != '/B':
            token = arg
            break
        terminate = False
    try:
        code = int(token)
    except ValueError:
        code = 0
    yield cmd
    if not self.cfg.skip_exit:
        raise Exit(code, terminate)
def execute_chdir(self, cmd, *_)
Expand source code Browse git
@_command('CHDIR')
@_command('CD')
def execute_chdir(self, cmd: SynCommand, *_):
    """
    Handler for CD and CHDIR: emit the command and make the stripped argument string the new
    working directory.
    """
    yield cmd
    target = cmd.argument_string.strip()
    self.state.cwd = target
def execute_pushd(self, cmd, *_)
Expand source code Browse git
@_command('PUSHD')
def execute_pushd(self, cmd: SynCommand, *_):
    """
    Handler for PUSHD: emit the command, remember the current working directory on the
    directory stack, and make the stripped argument string the new working directory.
    """
    yield cmd
    self.state.dirstack.append(self.state.cwd)
    # The directory is changed here directly. The CHDIR handler is a generator, so merely
    # calling it would neither run its body nor change the directory, and iterating it would
    # emit the command a second time.
    self.state.cwd = cmd.argument_string.strip()
def execute_popd(self, cmd, *_)
Expand source code Browse git
@_command('POPD')
def execute_popd(self, cmd: SynCommand, *_):
    """
    Handler for POPD: emit the command and restore the most recently pushed working
    directory. Nothing changes when the directory stack is empty.
    """
    yield cmd
    try:
        previous = self.state.dirstack.pop()
    except IndexError:
        return
    self.state.cwd = previous
def execute_echo(self, cmd, std, in_group)
Expand source code Browse git
@_command('ECHO')
def execute_echo(self, cmd: SynCommand, std: IO, in_group: bool):
    """
    Handler for ECHO. The arguments ON and OFF switch the echo state; such a command is only
    emitted when it changes the state or when no-ops are shown. Any other non-blank argument
    is written to standard output, with a trailing space ensured inside a group. Without an
    argument, the current echo state is reported.
    """
    text = cmd.argument_string
    mode = text.strip().lower()
    previous = self.state.echo
    if mode in ('on', 'off'):
        enable = mode == 'on'
        if self.cfg.show_nops or previous is (not enable):
            yield cmd
        self.state.echo = enable
        return
    yield cmd
    if not mode:
        status = 'on' if self.state.echo else 'off'
        std.o.write(F'ECHO is {status}.\r\n')
        return
    if in_group and not text.endswith(' '):
        text += ' '
    std.o.write(F'{text}\r\n')
def execute_cls(self, cmd, *_)
Expand source code Browse git
@_command('CLS')
def execute_cls(self, cmd: SynCommand, *_):
    """
    Handler for CLS: the command is only emitted and has no further effect.
    """
    yield cmd
def execute_del(self, cmd, std, *_)
Expand source code Browse git
@_command('ERASE')
@_command('DEL')
def execute_del(self, cmd: SynCommand, std: IO, *_):
    """
    Handler for DEL and ERASE: remove every file in the virtual file system that matches one
    of the given patterns. With /P, a confirmation is read from standard input before a file
    is removed. The generator returns 0 on success and 1 after a syntax error, which is
    reported both for a command without arguments and for one that consists of switches only.

    Raises InputLocked when a confirmation cannot be read as a complete line.
    """
    if not cmd.args:
        yield Error('The syntax of the command is incorrect')
        return 1
    yield cmd
    flags = {}
    paths = []
    it = iter(cmd.args)
    for arg in it:
        if not arg.startswith('/') or len(arg) < 2:
            # The first argument that is not a switch starts the list of patterns.
            paths = [arg, *it]
            break
        flag = arg.upper()
        if flag[:3] == '/A:':
            flags['A'] = flag[3:]
        else:
            flags[flag[1]] = True
    if not paths:
        # Only switches were given. Reading past the end of the arguments would otherwise
        # surface as a RuntimeError from within this generator.
        yield Error('The syntax of the command is incorrect')
        return 1
    # The switches /F (force), /S (recurse) and /Q (quiet) are accepted, but they have no
    # effect on the emulation.
    _P = 'P' in flags # Prompts for confirmation before deleting each file.
    state = self.state
    cwd = state.cwd
    for pattern in paths:
        for path in list(state.file_system):
            # Argument order (path, pattern, cwd) as at the other call sites of winfnmatch.
            if not winfnmatch(path, pattern, cwd):
                continue
            if _P and state.exists_file(pattern):
                std.o.write(F'{pattern}, Delete (Y/N)? ')
                decision = None
                while decision not in ('y', 'n'):
                    confirmation = std.i.readline()
                    if not confirmation.endswith('\n'):
                        raise InputLocked
                    decision = confirmation[:1].lower()
                if decision == 'n':
                    continue
            state.remove_file(path)
    return 0
def execute_start(self, cmd, std, *_)
Expand source code Browse git
@_command('START')
def execute_start(self, cmd: SynCommand, std: IO, *_):
    """
    Handler for START. The fragments after the verb are scanned for an optional quoted title,
    switches, and the program to start. If the program is a batch file known to the virtual
    file system, it is traced in a new emulator whose state is cloned from the current one,
    with the working directory from /D and, for /I, no inherited environment.

    A switch that expects a value but is the last fragment is tolerated: the missing value is
    ignored instead of ending the generator with an error.
    """
    yield cmd
    it = iter(cmd.ast.fragments)
    it = itertools.islice(it, cmd.argument_offset, None)
    title = None
    start = None
    cwd = self.state.cwd
    env = ...
    for arg in it:
        if title is None:
            # Only the very first fragment can be the title, and only if it contains a quote.
            if '"' not in arg:
                title = ''
            else:
                title = unquote(arg)
                continue
        if arg.isspace():
            continue
        if not arg.startswith('/'):
            start = unquote(arg)
            break
        if (flag := arg.upper()) in ('/NODE', '/AFFINITY', '/MACHINE'):
            # These switches take a value which is of no interest here.
            next(it, None)
        elif flag == '/D':
            cwd = next(it, cwd)
        elif flag == '/I':
            env = None
    if start and (batch := self.state.ingest_file(start)):
        state = self.clone_state(environment=env)
        state.cwd = cwd
        # Whatever follows the program name is the command line of the started script.
        state.command_line = _fuse(it).strip()
        shell = self.spawn(batch, state, std)
        yield from shell.trace()
def execute_cmd(self, cmd, std, *_)
Expand source code Browse git
@_command('CMD')
def execute_cmd(self, cmd: SynCommand, std: IO, *_):
    """
    Handler for CMD. The switches are scanned up to /C, /K or /R; everything after that
    switch is fused into the command line of a new emulator, which runs on a cloned state
    and is traced. Without such a switch, the generator returns 0 after emitting the command.
    """
    yield cmd
    it = iter(cmd.ast.fragments)
    command = None
    quiet = False
    strip = False
    codec = 'cp1252'
    delayexpand = None
    cmdextended = None

    for arg in it:
        # Whitespace and anything that is not a switch is skipped.
        if arg.isspace() or not arg.startswith('/'):
            continue
        name, _, flag = arg[1:].partition(':')
        flag = flag.upper()
        name = name.upper()
        # NOTE(review): this is a substring test, so an empty name or a name like "CK" also
        # satisfies it - confirm whether that is intended.
        if name in 'CKR':
            command = _fuse(it)
            break
        elif name == 'Q':
            quiet = True
        elif name == 'S':
            strip = True
        elif name == 'U':
            codec = 'utf-16le'
        elif name == 'E':
            cmdextended = _onoff(flag)
        elif name == 'V':
            delayexpand = _onoff(flag)
    else:
        # The loop ended without finding a command to run.
        return 0

    # The outer pair of quotes is removed from a command that starts with a quote when /S
    # was given, when there are not exactly two quotes, when the quoted text contains one of
    # the listed special characters, or when it contains no whitespace.
    if (stripped := re.search('^\\s*"(.*)"', command)) and (strip
        or command.count('"') != 2
        or re.search('[&<>()@^|]', stripped[1])
        or re.search('\\s', stripped[1]) is None
    ):
        command = stripped[1]

    state = self.clone_state(delayexpand=delayexpand, cmdextended=cmdextended)
    state.codec = codec
    state.echo = not quiet
    shell = self.spawn(command, state, std)
    yield from shell.trace()
def execute_unimplemented_program(self, cmd, *_)
Expand source code Browse git
@_command('ARP')
@_command('AT')
@_command('ATBROKER')
@_command('BGINFO')
@_command('BITSADMIN')
@_command('CERTUTIL')
@_command('CLIP')
@_command('CMSTP')
@_command('COMPACT')
@_command('CONTROL')
@_command('CSCRIPT')
@_command('CURL')
@_command('DEFRAG')
@_command('DISKSHADOW')
@_command('ESENTUTL')
@_command('EXPAND')
@_command('EXPLORER')
@_command('EXTRAC32')
@_command('FODHELPER')
@_command('FORFILES')
@_command('FTP')
@_command('HOSTNAME')
@_command('INSTALLUTIL')
@_command('IPCONFIG')
@_command('LOGOFF')
@_command('MAKECAB')
@_command('MAVINJECT')
@_command('MOUNTVOL')
@_command('MSBUILD')
@_command('MSHTA')
@_command('MSIEXEC')
@_command('MSTSC')
@_command('NET')
@_command('NET1')
@_command('NETSH')
@_command('NSLOOKUP')
@_command('ODBCCONF')
@_command('PATHPING')
@_command('PING')
@_command('POWERSHELL')
@_command('PRESENTATIONHOST')
@_command('PWSH')
@_command('REG')
@_command('REGSVR32')
@_command('ROUTE')
@_command('RUNDLL32')
@_command('SCP')
@_command('SDCLT')
@_command('SETX')
@_command('SFTP')
@_command('SHUTDOWN')
@_command('SSH')
@_command('SUBST')
@_command('SYSTEMINFO')
@_command('TAR')
@_command('TELNET')
@_command('TFTP')
@_command('TIMEOUT')
@_command('TRACERT')
@_command('VSSADMIN')
@_command('WBADMIN')
@_command('WHERE')
@_command('WHOAMI')
@_command('WINRM')
@_command('WINRS')
@_command('WSCRIPT')
def execute_unimplemented_program(self, cmd: SynCommand, *_):
    """
    Shared handler for the programs listed above, none of which are emulated: The command is
    emitted unchanged and 0 is returned as the result. Every name is registered exactly once.
    """
    yield cmd
    return 0
def execute_unimplemented_command_unmodified_ec(self, cmd, *_)
Expand source code Browse git
@_command('CLS')
def execute_unimplemented_command_unmodified_ec(self, cmd: SynCommand, *_):
    """
    Handler for commands that are not emulated and that produce no result value: The command is
    emitted and nothing is returned.
    """
    yield cmd
def execute_unimplemented_command(self, cmd, *_)
Expand source code Browse git
# NOTE(review): COMPACT, SHUTDOWN, SUBST and SYSTEMINFO are also registered on
# execute_unimplemented_program. Both handlers have the same body; which registration is
# effective depends on how _command treats repeated names -- confirm.
@_command('ASSOC')
@_command('ATTRIB')
@_command('BCDEDIT')
@_command('BREAK')
@_command('CACLS')
@_command('CHCP')
@_command('CHKDSK')
@_command('CHKNTFS')
@_command('COLOR')
@_command('COMP')
@_command('COMPACT')
@_command('CONVERT')
@_command('COPY')
@_command('DATE')
@_command('DIR')
@_command('DISKPART')
@_command('DOSKEY')
@_command('DRIVERQUERY')
@_command('FC')
@_command('FORMAT')
@_command('FSUTIL')
@_command('FTYPE')
@_command('GPRESULT')
@_command('ICACLS')
@_command('LABEL')
@_command('MD')
@_command('MKDIR')
@_command('MKLINK')
@_command('MODE')
@_command('MORE')
@_command('MOVE')
@_command('OPENFILES')
@_command('PATH')
@_command('PAUSE')
@_command('PRINT')
@_command('PROMPT')
@_command('RD')
@_command('RECOVER')
@_command('REN')
@_command('RENAME')
@_command('REPLACE')
@_command('RMDIR')
@_command('ROBOCOPY')
@_command('SC')
@_command('SCHTASKS')
@_command('SHIFT')
@_command('SHUTDOWN')
@_command('SORT')
@_command('SUBST')
@_command('SYSTEMINFO')
@_command('TASKKILL')
@_command('TASKLIST')
@_command('TIME')
@_command('TITLE')
@_command('TREE')
@_command('TYPE')
@_command('VER')
@_command('VERIFY')
@_command('VOL')
@_command('WMIC')
@_command('XCOPY')
def execute_unimplemented_command(self, cmd: SynCommand, *_):
    """
    Shared handler for the commands listed above, none of which are emulated: The command is
    emitted unchanged and 0 is returned as the result.
    """
    yield cmd
    return 0
def execute_rem(self, cmd, *_)
Expand source code Browse git
@_command('REM')
def execute_rem(self, cmd: SynCommand, *_):
    """
    Handler for REM: The comment is only emitted when the configuration asks for comments.
    """
    if not self.cfg.show_comments:
        return
    yield cmd
def execute_help(self, cmd, std, *_)
Expand source code Browse git
@_command('HELP')
def execute_help(self, cmd: SynCommand, std: IO, *_):
    """
    Handler for HELP: Emits the command, writes the stored help text for HELP to the output
    stream, and returns 0.
    """
    yield cmd
    std.o.write(HelpOutput['HELP'])
    return 0
def execute_command(self, cmd, std, in_group)
Expand source code Browse git
def execute_command(self, cmd: SynCommand, std: IO, in_group: bool):
    """
    Execute a single synthesized command: Look up the handler for its verb, apply the redirects
    for the three standard handles, run the handler, write redirected output back to the emulated
    file system, and store the handler result as the new exit code.
    """
    registry = self._command.handlers
    verb = cmd.verb.upper().strip()
    handler = registry.get(verb)

    if handler is None:
        # retry without the extension if the extension is listed in PATHEXT
        stem, suffix = ntpath.splitext(verb)
        executable = self.state.envar('PATHEXT', '').split(';')
        if any(suffix == candidate.upper() for candidate in executable):
            handler = registry.get(stem)

    if handler is None:
        # unknown verb: emit the command, possibly flagged as junk
        if self.state.exists_file(verb):
            self.state.ec = 0
        elif not indicators.winfpath.value.fullmatch(verb):
            if '\uFFFD' in verb or not verb.isprintable():
                self.state.ec = 9009
                cmd.junk = True
            else:
                cmd.junk = not self.cfg.show_junk
        yield cmd
        return

    # maps an output handle to the path that receives its buffer after execution
    written: dict[int, str] = {}

    for handle, redirect in cmd.ast.redirects.items():
        if handle < 0 or handle > 2:
            continue
        if (handle == 0) != redirect.is_input:
            # only handle 0 can be read from, and it cannot be written to
            continue
        target = redirect.target
        if not isinstance(target, str):
            # handle duplication between stdout and stderr
            if handle == 1 and target == 2:
                std.o = std.e
            elif handle == 2 and target == 1:
                std.e = std.o
            continue
        if target.upper() == 'NUL':
            std[handle] = DevNull()
            continue
        data = self.state.ingest_file(target)
        if handle == 0:
            if data is None:
                yield ErrorCannotFindFile
                return
            std.i = StringIO(data)
            continue
        if redirect.is_out_append:
            # start from the existing contents with the write position at the end
            sink = StringIO(data)
            sink.seek(0, 2)
        else:
            sink = StringIO()
        std[handle] = sink
        written[handle] = target

    if '/?' in cmd.args:
        # this path returns before any redirected buffer is written to a file
        std.o.write(HelpOutput[verb])
        self.state.ec = 0
        return

    outcome = handler(self, cmd, std, in_group)
    if outcome is not None and not isinstance(outcome, (int, ErrorZero)):
        # the handler is a generator; its return value is the result
        outcome = yield from outcome

    for handle, path in written.items():
        self.state.create_file(path, std[handle].getvalue())

    if outcome is not None:
        self.state.ec = outcome
def trace_pipeline(self, pipeline, std, in_group)
Expand source code Browse git
@_node(AstPipeline)
def trace_pipeline(self, pipeline: AstPipeline, std: IO, in_group: bool):
    """
    Trace all parts of a pipeline in order. The output of each part is buffered and becomes the
    input of the next part; only the last part writes to the output of `std`. A pipeline with
    more than one part is synthesized before its parts run.
    """
    parts = pipeline.parts
    total = len(parts)
    streams = IO(*std)
    if total > 1:
        yield synthesize(pipeline)
    for position, part in enumerate(parts):
        if position > 0:
            # feed the buffered output of the previous part into this one
            streams.i = streams.o
            streams.i.seek(0)
        streams.o = std.o if position == total - 1 else StringIO()
        if isinstance(part, AstGroup):
            yield from self.trace_group(part, streams, in_group)
            continue
        command = synthesize(self.expand_ast_node(part))
        yield from self.execute_command(command, streams, in_group)
def trace_sequence(self, sequence, std, in_group)
Expand source code Browse git
@_node(AstSequence)
def trace_sequence(self, sequence: AstSequence, std: IO, in_group: bool):
    """
    Trace the head of a sequence and then each conditional statement of its tail. A statement
    with the Failure condition is skipped while the exit code is zero, a statement with the
    Success condition is skipped while it is nonzero.
    """
    yield from self.trace_statement(sequence.head, std, in_group)
    for link in sequence.tail:
        requirement = link.condition
        if requirement == AstCondition.Failure and not self.state.ec:
            continue
        if requirement == AstCondition.Success and self.state.ec:
            continue
        yield from self.trace_statement(link.statement, std, in_group)
def trace_if(self, _if, std, in_group)
Expand source code Browse git
@_node(AstIf)
def trace_if(self, _if: AstIf, std: IO, in_group: bool):
    """
    Emulate an IF statement: The unexpanded statement is synthesized first, then the expanded
    condition is evaluated and either the THEN branch or, if present, the ELSE branch is traced.
    """
    yield synthesize(_if)
    node = self.expand_ast_node(_if)
    self.block_labels.clear()
    variant = node.variant

    if variant == AstIfVariant.ErrorLevel:
        outcome = node.var_int <= self.state.ec
    elif variant == AstIfVariant.CmdExtVersion:
        outcome = node.var_int <= self.state.extensions_version
    elif variant == AstIfVariant.Exist:
        outcome = self.state.exists_file(node.var_str)
    elif variant == AstIfVariant.Defined:
        outcome = node.var_str.upper() in self.state.environment
    else:
        # plain comparison of two operands
        left = node.lhs
        right = node.rhs
        operator = node.cmp
        assert left is not None
        assert right is not None
        if operator == AstIfCmp.STR:
            if node.casefold:
                if isinstance(left, str):
                    left = left.casefold()
                if isinstance(right, str):
                    right = right.casefold()
            outcome = left == right
        elif operator == AstIfCmp.GTR:
            outcome = left > right
        elif operator == AstIfCmp.GEQ:
            outcome = left >= right
        elif operator == AstIfCmp.NEQ:
            outcome = left != right
        elif operator == AstIfCmp.EQU:
            outcome = left == right
        elif operator == AstIfCmp.LSS:
            outcome = left < right
        elif operator == AstIfCmp.LEQ:
            outcome = left <= right
        else:
            raise RuntimeError(operator)

    taken = (not outcome) if node.negated else outcome

    if taken:
        yield from self.trace_sequence(node.then_do, std, in_group)
        return
    alternative = node.else_do
    if alternative:
        yield from self.trace_sequence(alternative, std, in_group)
def trace_for(self, _for, std, in_group)
Expand source code Browse git
@_node(AstFor)
def trace_for(self, _for: AstFor, std: IO, in_group: bool):
    """
    Emulate a FOR loop. For the file parsing variant, the input lines come from the output of
    a spawned command, from the literal spec, or from matching files in the emulated file system;
    each line is tokenized according to the loop options and the tokens are assigned to
    consecutive loop variables before the body is traced. For every other variant, the body is
    traced once for each entry of the spec.
    """
    state = self.state
    cwd = state.cwd
    vars = state.new_forloop()
    body = _for.body
    name = _for.variable
    vars[name] = ''

    # The loop itself is only left out of the trace when it is a numeric loop with a single
    # spec entry whose body does not use any loop variable.
    if (
        self.contains_for_variable(body, vars)
            or _for.variant != AstForVariant.NumericLoop
            or len(_for.spec) != 1
    ):
        yield synthesize(_for)

    if _for.variant == AstForVariant.FileParsing:
        if _for.mode == AstForParserMode.Command:
            # run the spec as a command in a child shell and parse what it wrote to stdout
            emulator = self.spawn(_for.specline, self.clone_state(filename=state.name))
            yield from emulator.trace()
            lines = emulator.std.o.getvalue().splitlines()
        elif _for.mode == AstForParserMode.Literal:
            lines = _for.spec
        else:
            def lines_from_files():
                fs = state.file_system
                for name in _for.spec:
                    for path, content in fs.items():
                        if not winfnmatch(path, name, cwd):
                            continue
                        yield from content.splitlines(False)
            lines = lines_from_files()
        opt = _for.options
        tokens = sorted(opt.tokens)
        delims = opt.delims
        # With an empty delimiter set, the pattern would be '[]+', which does not compile. In
        # that case there is nothing to split at and each line is used as a single token.
        split = re.compile('[{}]+'.format(re.escape(delims))) if delims else None
        count = tokens[-1] + 1 if tokens else 0
        first_variable = ord(name)
        if opt.asterisk:
            # one extra variable receives the unsplit remainder of the line
            tokens.append(count)
        for n, line in enumerate(lines):
            if n < opt.skip:
                continue
            if opt.comment and line.startswith(opt.comment):
                continue
            if count and split is not None:
                tokenized = split.split(line, maxsplit=count)
            else:
                tokenized = (line,)
            for k, tok in enumerate(tokens):
                # variables are assigned in alphabetical order, starting at the loop variable
                name = chr(first_variable + k)
                if not name.isalpha():
                    raise EmulatorException('Ran out of variables in FOR-Loop.')
                try:
                    vars[name] = tokenized[tok]
                except IndexError:
                    vars[name] = ''
            yield from self.trace_sequence(body, std, in_group)
    else:
        for entry in _for.spec:
            vars[name] = entry
            yield from self.trace_sequence(body, std, in_group)
    # NOTE(review): this call is skipped when the body raises -- confirm that this is intended.
    state.end_forloop()
def trace_group(self, group, std, in_group)
Expand source code Browse git
@_node(AstGroup)
def trace_group(self, group: AstGroup, std: IO, in_group: bool):
    """
    Trace every sequence inside a parenthesized group, then synthesize the group itself. The
    sequences are always traced with the in-group flag set, regardless of `in_group`.
    """
    for sequence in group.fragments:
        yield from self.trace_sequence(sequence, std, True)
    yield synthesize(group)
def trace_label(self, label, *_)
Expand source code Browse git
@_node(AstLabel)
def trace_label(self, label: AstLabel, *_):
    """
    Handle a label line. Comment labels are emitted when comments are shown; real labels are
    emitted when labels are shown and are always recorded, uppercased, in the block labels.
    """
    is_comment = label.comment
    visible = self.cfg.show_comments if is_comment else self.cfg.show_labels
    if visible:
        yield synthesize(label)
    if not is_comment:
        self.block_labels.add(label.label.upper())
def trace_statement(self, statement, std, in_group)
Expand source code Browse git
def trace_statement(self, statement: AstStatement, std: IO, in_group: bool):
    """
    Dispatch a statement to the handler that is registered for its exact class and forward the
    trace of that handler. A statement without a registered handler raises a RuntimeError.
    """
    registry = self._node.handlers
    kind = statement.__class__
    try:
        dispatch = registry[kind]
    except KeyError:
        raise RuntimeError(statement)
    yield from dispatch(self, statement, std, in_group)
def emulate_commands(self, allow_junk=False)
Expand source code Browse git
def emulate_commands(self, allow_junk=False):
    """
    Run the trace and generate the string form of every synthesized command. Commands that are
    flagged as junk are only included when `allow_junk` is set.
    """
    for syn in self.trace():
        if isinstance(syn, SynCommand) and (allow_junk or not syn.junk):
            yield str(syn)
def emulate_to_depth(self, depth=0)
Expand source code Browse git
def emulate_to_depth(self, depth: int = 0):
    """
    Run the trace and generate the string form of every synthesized node whose AST depth does
    not exceed `depth`.
    """
    for syn in self.trace():
        if isinstance(syn, SynNodeBase) and syn.ast.depth <= depth:
            yield str(syn)
def emulate(self, offset=0)
Expand source code Browse git
def emulate(self, offset: int = 0):
    """
    Run the trace from the given offset and generate the string form of the synthesized nodes,
    leaving out junk commands and nodes that are already covered by a previously emitted node.
    """
    previous: AstNode | None = None
    discarded: AstNode | None = None
    for syn in self.trace(offset):
        if not isinstance(syn, SynNodeBase):
            continue
        node = syn.ast
        if isinstance(syn, SynCommand) and syn.junk:
            # remember the junk command, but do not emit it
            discarded = node
            continue
        if (
            discarded is not None
            and discarded.is_descendant_of(node)
            and not (previous and previous.is_descendant_of(node))
        ):
            # this node wraps the junk command and nothing from inside it was emitted
            continue
        if previous is not None and (
            # a parent construct like a FOR loop or IF block was already synthesized
            node.is_descendant_of(previous)
            # a command was synthesized and the AST node that wraps it, like a group, is
            # no longer needed
            or previous.is_descendant_of(node)
        ):
            continue
        if isinstance(node, AstPipeline) and len(node.parts) == 1:
            continue
        if previous is node:
            raise RuntimeError('Emulator attempted to synthesize the same command twice.')
        previous = node
        yield str(syn)
def execute(self, offset=0, called=False)
Expand source code Browse git
def execute(self, offset: int = 0, called: bool = False):
    """
    Run the trace to completion and discard everything it generates.
    """
    for _ in self.trace(offset, called=called):
        pass
def trace(self, offset=0, called=False)
Expand source code Browse git
def trace(self, offset: int = 0, called: bool = False):
    """
    Parse and trace the script starting at `offset`. A GOTO restarts parsing at the offset of
    the target label, an EXIT or an aborted execution sets the exit code and ends the trace. When
    `called` is set, an EXIT whose `exit` attribute is set is raised again after the exit code
    was stored.
    """
    lexer = self.parser.lexer
    script_name = self.state.name
    if script_name:
        # make the script itself available in the emulated file system
        self.state.create_file(script_name, lexer.text)
    end = len(lexer.code)
    labels = lexer.labels

    while offset < end:
        try:
            for sequence in self.parser.parse(offset):
                if isinstance(sequence, AstError):
                    yield Error(sequence.error)
                else:
                    yield from self.trace_sequence(sequence, self.std, False)
        except Goto as jump:
            # resume parsing at the label; an unknown label is an error
            target = jump.label
            try:
                offset = labels[target.upper()]
            except KeyError:
                raise InvalidLabel(target) from jump
        except Exit as request:
            self.state.ec = request.code
            if request.exit and called:
                raise
            return
        except AbortExecution:
            self.state.ec = 1
            return
        else:
            # the parser ran to the end without any jump
            return
class BatchLexer (data, state=None)
Expand source code Browse git
class BatchLexer:

    labels: dict[str, int]
    code: memoryview

    pending_redirect: RedirectIO | None

    cursor: BatchLexerCursor
    resume: BatchLexerCursor | None

    class _register:
        """
        Decorator that registers a lexer method as the handler for one or more modes.
        """
        # Handlers receive the current mode and character. The boolean they return tells the
        # caller whether the character was processed and may be consumed. This mapping is a
        # class attribute and therefore shared by all instances.
        handlers: ClassVar[dict[Mode, Callable[
            [BatchLexer, Mode, int], Generator[Token, None, bool]
        ]]] = {}

        def __init__(self, *modes: Mode):
            self.modes = modes

        def __call__(self, handler):
            # later registrations for the same mode replace earlier ones
            self.handlers.update(dict.fromkeys(self.modes, handler))
            return handler

    def __init__(self, data: str | buf | BatchLexer, state: BatchState | None = None):
        """
        Create a lexer for the given script data. When another lexer is passed instead, its
        text, code, labels, and state are shared; passing a state in addition is not supported.
        """
        if not isinstance(data, BatchLexer):
            self.state = BatchState() if state is None else state
            self.preparse(data)
            return
        if state is not None:
            raise NotImplementedError
        self.state = data.state
        self.labels = data.labels
        self.code = data.code
        self.text = data.text

    def parse_group(self):
        """
        Increase the group nesting counter by one.
        """
        self.group += 1

    def parse_label(self):
        """
        Switch to label mode. This is only permitted while text mode is the only mode on the
        stack; otherwise an EmulatorException is raised.
        """
        current = self.mode
        if current != Mode.Text or len(self.modes) != 1:
            raise EmulatorException(F'Switching to LABEL while in mode {current.name}')
        self.mode_switch(Mode.Label)

    def parse_set(self):
        """
        Switch to the initial SET mode. This is only permitted while text mode is the only mode
        on the stack; otherwise an EmulatorException is raised.
        """
        current = self.mode
        if current != Mode.Text or len(self.modes) != 1:
            raise EmulatorException(F'Switching to SET while in mode {current.name}')
        self.mode_switch(Mode.SetStarted)

    @property
    def environment(self):
        """
        The environment of the lexer state.
        """
        return self.state.environment

    def parse_arg_variable(self, var: ArgVar):
        """
        Expand a batch parameter like %1 or %~dp0 and return the resulting string. The following
        is the relevant excerpt from the built-in help:

        %* in a batch script refers to all the arguments (e.g. %1 %2 %3
            %4 %5 ...)
        Substitution of batch parameters (%n) has been enhanced.  You can
        now use the following optional syntax:
            %~1         - expands %1 removing any surrounding quotes (")
            %~f1        - expands %1 to a fully qualified path name
            %~d1        - expands %1 to a drive letter only
            %~p1        - expands %1 to a path only
            %~n1        - expands %1 to a file name only
            %~x1        - expands %1 to a file extension only
            %~s1        - expanded path contains short names only
            %~a1        - expands %1 to file attributes
            %~t1        - expands %1 to date/time of file
            %~z1        - expands %1 to size of file
            %~$PATH:1   - searches the directories listed in the PATH
                           environment variable and expands %1 to the fully
                           qualified name of the first one found.  If the
                           environment variable name is not defined or the
                           file is not found by the search, then this
                           modifier expands to the empty string
        The modifiers can be combined to get compound results:
            %~dp1       - expands %1 to a drive letter and path only
            %~nx1       - expands %1 to a file name and extension only
            %~dp$PATH:1 - searches the directories listed in the PATH
                           environment variable for %1 and expands to the
                           drive letter and path of the first one found.
            %~ftza1     - expands %1 to a DIR like output line
        In the above examples %1 and PATH can be replaced by other
        valid values.  The %~ syntax is terminated by a valid argument
        number.  The %~ modifiers may not be used with %*
        """
        state = self.state
        flags = var.flags

        # an Ellipsis offset stands for %*, i.e. the entire command line
        if (k := var.offset) is (...):
            return state.command_line
        if (j := k - 1) < 0:
            # %0 is the name of the script
            argval = state.name
        elif j < len(args := state.args):
            argval = args[j]
        else:
            # arguments that were not passed expand to nothing
            return ''

        if not flags:
            return argval

        has_path = 0 != ArgVarFlags.FullPath & flags

        if flags.StripQuotes and argval.startswith('"') and argval.endswith('"'):
            argval = argval[1:-1]

        if flags.ShortName and not has_path:
            flags |= ArgVarFlags.FullPath
            has_path = True

        with io.StringIO() as out:
            if flags.Attributes:
                out.write('--a--------') # TODO: placeholder
            if flags.DateTime:
                dt = state.start_time.isoformat(' ', 'minutes')
                out.write(F' {dt}')
            if flags.FileSize:
                out.write(F' {state.sizeof_file(argval)}')
            if has_path:
                out.write(' ')
                full_path = state.resolve_path(argval)
                drv, rest = ntpath.splitdrive(full_path)
                *pp, name = ntpath.split(rest)
                name, ext = ntpath.splitext(name)
                if flags.DriveLetter:
                    out.write(drv)
                if flags.FilePath:
                    directory = ntpath.join(*pp)
                    # ntpath.split removes the separator between directory and file name, but
                    # the path component always ends with one: %~dp0 can be followed directly
                    # by a file name, and %~dpnx1 must not fuse directory and file name.
                    if directory and directory[-1] not in '\\/':
                        directory += '\\'
                    out.write(directory)
                if flags.FileName:
                    out.write(name)
                if flags.FileExtension:
                    out.write(ext)
            # the individual parts are separated by spaces; drop the leading one
            return out.getvalue().lstrip()

    @property
    def modes(self):
        """
        The mode stack, which is stored on the current cursor.
        """
        return self.cursor.modes

    def reset(self, offset: int):
        """
        Return the lexer to its initial state with a new cursor at the given offset. The mode
        stack of the new cursor starts out with text mode.
        """
        self.cursor = cursor = BatchLexerCursor(offset)
        cursor.modes.append(Mode.Text)
        self.resume = self.pending_redirect = None
        self.quote = self.caret = self.white = False
        self.first_after_gap = True
        self.group = 0

    def mode_reset(self):
        """
        Drop every mode except the one at the bottom of the mode stack.
        """
        del self.modes[1:]

    def mode_finish(self):
        """
        Leave the current mode by removing it from the mode stack. The bottom mode cannot be
        left; attempting to do so raises a RuntimeError.
        """
        stack = self.modes
        if len(stack) < 2:
            raise RuntimeError('Trying to exit base mode.')
        stack.pop()

    def mode_switch(self, mode: Mode):
        """
        Enter the given mode by pushing it onto the mode stack.
        """
        self.modes.append(mode)

    @property
    def mode(self):
        """
        The current mode, which is the top of the mode stack. Assigning to this property replaces
        the top of the stack without changing its depth.
        """
        return self.modes[-1]

    @mode.setter
    def mode(self, value: Mode):
        self.modes[-1] = value

    @property
    def substituting(self):
        """
        The substituting attribute of the current cursor.
        """
        return self.cursor.substituting

    @property
    def eof(self):
        """
        True when the cursor offset is at or beyond the end of the code and the substitution
        buffer is empty.
        """
        return (c := self.cursor).offset >= len(self.code) and not c.subst_buffer

    def quick_save(self):
        """
        Store a copy of the current cursor as the resume point.
        """
        self.resume = self.cursor.copy()

    def quick_load(self):
        """
        Replace the current cursor with the stored resume point and clear the resume point.
        Raises a RuntimeError when no resume point was stored.
        """
        saved = self.resume
        if saved is None:
            raise RuntimeError
        self.resume = None
        self.cursor = saved

    def current_char(self):
        """
        Return the character at the current position without consuming it. Characters from the
        substitution buffer take precedence over the code. A percent sign in the code is skipped
        and triggers a variable substitution, after which the lookup starts over. Raises
        UnexpectedEOF when the position is beyond the end of the code.
        """
        cursor = self.cursor
        if not (subst := cursor.subst_buffer):
            offset = cursor.offset
            # The bounds check keeps this lookup from raising an IndexError at the end of the
            # code; that case is reported as UnexpectedEOF by the lookup below.
            if offset < len(self.code) and self.code[offset] == PERCENT:
                cursor.offset += 1
                self.fill_substitution_buffer()
                return self.current_char()
        else:
            offset = cursor.subst_offset
            if offset >= (n := len(subst)):
                # positions beyond the buffer refer to the code that follows the variable
                offset -= n
                offset += cursor.offset
            else:
                return subst[offset]
        try:
            return self.code[offset]
        except IndexError:
            raise UnexpectedEOF

    def consume_char(self):
        """
        Advance by one character. While the substitution buffer is not empty, the position inside
        the buffer moves forward and the buffer is cleared once its last character was consumed.
        Otherwise the code offset moves forward; moving beyond the end raises an EOFError.
        """
        cursor = self.cursor
        pending = cursor.subst_buffer
        if not pending:
            advanced = cursor.offset + 1
            if advanced > len(self.code):
                raise EOFError('Consumed a character beyond EOF.')
            cursor.offset = advanced
            return
        advanced = cursor.subst_offset + 1
        if advanced < len(pending):
            cursor.subst_offset = advanced
        else:
            del pending[:]
            cursor.subst_offset = -1

    def next_char(self):
        """
        Consume the current character and return the one that follows it.
        """
        self.consume_char()
        return self.current_char()

    def parse_env_variable(self, var: str):
        """
        Expand the text between two percent signs. An empty string expands to a literal percent
        sign. Otherwise, the text is a variable name that is optionally followed by a colon and
        a modifier: `old=new` replaces text in the value, and `~offset,length` extracts a slice
        of the value. A modifier of any other form raises an EmulatorException.
        """
        if var == '':
            return '%'
        name, _, modifier = var.partition(':')
        base = self.state.envar(name)
        if not modifier or not base:
            return base
        if '=' in modifier:
            old, _, new = modifier.partition('=')
            # a negative count replaces all occurrences
            count = -1
            # NOTE(review): cmd.exe documents a leading '*' for replacing everything up to the
            # first match; confirm that the '~' prefix is what is intended here.
            if old.startswith('~'):
                old = old[1:]
                count = 1
            # The count is passed by position because str.replace did not accept it as a
            # keyword argument before Python 3.13.
            return base.replace(old, new, count)
        else:
            if not modifier.startswith('~'):
                raise EmulatorException
            offset, _, length = modifier[1:].partition(',')
            offset = batchint(offset)
            if offset < 0:
                # negative offsets count from the end of the value
                offset = max(0, len(base) + offset)
            if length:
                end = offset + batchint(length)
            else:
                end = len(base)
            return base[offset:end]

    def emit_token(self):
        """
        Emit the contents of the token buffer, if any, and clear the buffer. While a redirect is
        pending, the unquoted token becomes its target, the redirect is emitted in place of a
        word, and the lexer switches to gap mode. The return value tells whether this happened.
        """
        redirected = False
        buffer = self.cursor.token
        text = u16(buffer) if buffer else None
        if text:
            pending = self.pending_redirect
            if pending:
                pending.target = unquote(text)
                self.pending_redirect = None
                self.mode_switch(Mode.Gap)
                yield pending
                redirected = True
            else:
                yield Word(text)
        del buffer[:]
        self.first_after_gap = False
        return redirected

    def tokens(self, offset: int) -> Generator[Token]:
        """
        Generate the tokens of the code, starting at the given offset. Each character is passed
        to the handler that is registered for the current mode and is consumed when the handler
        says so. A token that is still buffered at the end is emitted last.
        """
        self.reset(offset)
        dispatch = self._register.handlers
        peek = self.current_char
        advance = self.consume_char
        limit = len(self.code)

        while not self.cursor.eof(limit):
            char = peek()
            mode = self.mode
            if (yield from dispatch[mode](self, mode, char)):
                advance()

        if not self.first_after_gap:
            yield from self.emit_token()

    def check_line_break(self, mode: Mode, char: int):
        """
        Handle a line break. Returns False for any other character. An unescaped line break ends
        the current token, resets the mode stack, and emits a newline token; a line break after
        a caret is consumed without any of this. The line break is consumed in both cases.
        """
        if char != LINEBREAK:
            return False
        if self.caret:
            # the caret flag is not reset here, only when the next character is handled
            self.consume_char()
            return True
        yield from self.emit_token()
        self.white = True
        self.quote = False
        self.mode_reset()
        yield Ctrl.NewLine
        self.consume_char()
        return True

    def check_caret(self, char: int):
        """
        Handle caret escapes. The character after a caret is appended to the token as is; a caret
        itself only sets the escape flag. The character is consumed in both cases and True is
        returned; for any other character, nothing happens and False is returned.
        """
        if not self.caret:
            if char != CARET:
                return False
            self.caret = True
        else:
            self.cursor.token.append(char)
            self.caret = False
        self.consume_char()
        return True

    def check_command_separators(self, char: int):
        """
        Handle characters that end a command: a closing parenthesis while inside a group, an
        ampersand, or a pipe. The buffered token is emitted, the mode stack is reset, the
        character is consumed, and the matching control token is emitted. Returns False for any
        other character.
        """
        depth = self.group
        if char == PAREN_CLOSE and depth > 0:
            yield from self.emit_token()
            yield Ctrl.EndGroup
            self.mode_reset()
            self.consume_char()
            self.group = depth - 1
            return True
        if char == AMPERSAND:
            separator = Ctrl.Ampersand
        elif char == PIPE:
            separator = Ctrl.Pipe
        else:
            return False
        yield from self.emit_token()
        self.mode_reset()
        self.consume_char()
        yield separator
        return True

    def check_quote_start(self, char: int):
        """
        Handle an opening quote: It is appended to the token, the lexer enters quote mode, the
        caret flag is cleared, and the quote is consumed. Returns False for any other character.
        """
        if char == QUOTE:
            self.cursor.token.append(char)
            self.mode_switch(Mode.Quote)
            self.caret = False
            self.consume_char()
            return True
        return False

    def check_redirect_io(self, char: int):
        """
        Handle a redirection operator that starts with the given character. Returns False when
        the character is not an angle bracket. Otherwise the operator is consumed and True is
        returned: A redirect to another handle is emitted right away, and a redirect to a file
        is stored as pending until its target token is emitted.
        """
        if char not in ANGLES:
            return False

        # every angle bracket other than the opening one is an output redirect
        output = char != ANGLE_OPEN
        token = self.cursor.token

        # a single digit directly in front of the operator is the source handle
        if len(token) == 1 and (src := token[0] - ZERO) in range(10):
            del token[:]
            source = src
        else:
            # the default source is 1 for output and 0 for input
            source = int(output)

        char = self.next_char()

        if not output:
            how = Redirect.In
        elif char == ANGLE_CLOSE:
            # a second closing angle bracket means append
            how = Redirect.OutAppend
            char = self.next_char()
        else:
            how = Redirect.OutCreate

        # whatever was buffered before the operator is a token of its own
        yield from self.emit_token()

        if char != AMPERSAND:
            # the target is the next token; skip the gap in front of it first
            self.pending_redirect = RedirectIO(how, source)
            self.mode_switch(Mode.Gap)
        else:
            # an ampersand must be followed by the single digit of the target handle
            char = self.next_char()
            if char not in range(ZERO, NINE + 1):
                raise UnexpectedToken(char)
            self.consume_char()
            yield RedirectIO(how, source, char - ZERO)

        return True

    def fill_substitution_buffer(self):
        """
        Scan the rest of the current line, starting at the cursor offset, for a variable
        expression. The only visible call site invokes this right after a percent sign was
        skipped. When an expression expands, the expansion is stored in the substitution buffer
        of the cursor. The cursor offset is moved to the resume position that was recorded during
        the scan, if any. Nothing happens while the cursor is already substituting.
        """
        if (cursor := self.cursor).substituting:
            return

        code = self.code
        # position at which lexing continues after the scan; negative means unchanged
        var_resume = -1
        # position right after a dollar sign, where a search path variable name starts
        var_dollar = -1
        # collects the modifiers of an argument variable; None once that is ruled out
        var_cmdarg = ArgVar()
        variable = None
        phase = EV.New
        q = ArgVarFlags.StripQuotes

        for current in range((current := cursor.offset), len(code)):
            char = code[current]
            if char == LINEBREAK:
                # variable expressions do not extend beyond the current line
                break
            elif char == PERCENT:
                # closing percent sign: everything since the start is the variable expression
                try:
                    var_name = u16(self.code[cursor.offset:current])
                    variable = u16(self.parse_env_variable(var_name))
                except MissingVariable:
                    if var_resume < 0:
                        var_resume = current + 1
                    break
            elif var_cmdarg:
                if ZERO <= char <= NINE:
                    # a digit completes an argument variable
                    var_cmdarg.offset = char - ZERO
                    variable = u16(self.parse_arg_variable(var_cmdarg))
                elif char == ASTERIX and cursor.offset == current:
                    # an asterisk in first position refers to all arguments
                    var_cmdarg.offset = (...)
                    variable = u16(self.parse_arg_variable(var_cmdarg))
            if variable is not None:
                cursor.subst_offset = 0
                cursor.subst_buffer.extend(variable)
                var_resume = current + 1
                break
            if phase == EV.Mod:
                # this phase is entered at a colon; argument variables are ruled out from here
                var_cmdarg = None
            elif phase == EV.Env:
                # this phase is entered at a dollar sign; the text up to the next colon is
                # stored as the path attribute of the argument variable
                if char == COLON:
                    if var_cmdarg:
                        assert var_dollar > 0
                        var_cmdarg.path = u16(self.code[var_dollar:current])
                    var_resume = current + 1
            else:
                if char == DOLLAR:
                    var_dollar = current + 1
                    phase = EV.Env
                    continue
                if char == COLON:
                    var_cmdarg = None
                    var_resume = current + 1
                    phase = EV.Mod
                    continue
                if not var_cmdarg:
                    continue
                try:
                    var_cmdarg.flags |= ArgVarFlags.FromToken(char)
                except KeyError:
                    # not a modifier character, so this is not an argument variable
                    var_cmdarg = None
                    continue
                if q not in var_cmdarg.flags:
                    # modifiers are only accepted once the strip-quotes flag is set
                    var_cmdarg = None
        if var_resume >= 0:
            cursor.offset = var_resume

    @_register(Mode.Label)
    def gobble_label(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
        """
        Handler for label mode: Every character up to the line break is added to the token.
        """
        handled = yield from self.check_line_break(mode, char)
        if handled:
            # the line break was already consumed
            return False
        self.cursor.token.append(char)
        return True

    @_register(Mode.Quote)
    def gobble_quote(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
        """
        Handler for quote mode: Every character is added to the token, including the closing
        quote, which also ends quote mode. A line break is handled like everywhere else.
        """
        handled = yield from self.check_line_break(mode, char)
        if handled:
            # the line break was already consumed
            return False
        self.cursor.token.append(char)
        if char == QUOTE:
            self.mode_finish()
        return True

    @_register(Mode.Whitespace)
    def gobble_whitespace(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
        """
        Handler for whitespace mode: Whitespace is collected in the token. The first character
        that is not whitespace ends the mode and emits the collected whitespace as a word; that
        character is not consumed here.
        """
        token = self.cursor.token
        if char in WHITESPACE:
            token.append(char)
            return True
        self.mode_finish()
        yield Word(u16(token))
        del token[:]
        return False

    @_register(Mode.SetQuoted)
    def gobble_quoted_set(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
        """
        Handler for the quoted SET mode. The token is extended up to the end of the command, but
        a save point is stored after every quote. When the command ends, the lexer returns to the
        most recent save point and emits the token from there, which discards everything that
        was read after the last quote.
        """
        if char == QUOTE:
            # consume the quote first so that the save point lies behind it
            self.consume_char()
            self.cursor.token.append(QUOTE)
            self.quick_save()
            return False

        if char == LINEBREAK:
            if self.resume is None:
                # no quote since the opening one: the token ends at the line break
                yield from self.emit_token()
                self.mode_reset()
                yield Ctrl.NewLine
                return True
            elif self.caret:
                # escaped line break
                self.caret = False
                return True
            else:
                # the line break is not consumed; it is handled again from the save point
                self.quick_load()
                yield from self.emit_token()
                return False

        if self.resume is not None:
            if char == CARET:
                self.caret = not self.caret
            elif not self.caret:
                if (char == PAREN_CLOSE and self.group > 0) or char in (PIPE, AMPERSAND):
                    self.quick_load()
                    yield from self.emit_token()
                    self.mode_finish()
                    # after a quick load, the ending quote was already consumed.
                    return False

        self.cursor.token.append(char)
        return True

    @_register(Mode.SetStarted)
    def gobble_set(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
        """
        Handler for `Mode.SetStarted`, the mode pushed by `parse_set`. It lexes the part of a
        SET command before the assignment: whitespace, slash-prefixed tokens, and the variable
        name. A quote at the start of a token switches to `Mode.SetQuoted`, an equals sign
        switches to `Mode.SetRegular`. The return value tells the caller whether `char` still
        has to be consumed.
        """
        token = self.cursor.token
        if (yield from self.check_line_break(mode, char)):
            return False
        if char in WHITESPACE:
            # Flush the current token; the whitespace run is collected in its own mode.
            yield from self.emit_token()
            token.append(char)
            self.mode_switch(Mode.Whitespace)
            return True
        if char == SLASH and not self.pending_redirect:
            # A slash starts a new token (presumably a switch such as /A or /P - TODO confirm).
            yield from self.emit_token()
            token.append(char)
            return True
        if not token and char == QUOTE:
            # A quote at the very start of a token selects the quoted SET form.
            self.caret = False
            token.append(char)
            self.mode = Mode.SetQuoted
            return True
        if self.check_caret(char):
            return False
        if char == EQUALS:
            # The variable name ends here; everything after is lexed as the value.
            yield from self.emit_token()
            yield Ctrl.Equals
            self.mode = Mode.SetRegular
            return True
        if self.check_quote_start(char):
            return False
        if (yield from self.check_command_separators(char)):
            return False
        if (yield from self.check_redirect_io(char)):
            return False
        token.append(char)
        return True

    def common_token_checks(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
        """
        Run the checks shared by several handlers in a fixed order: line break, caret, quote
        start, command separators, I/O redirection. The first check that handles `char` ends
        the sequence; the result is True when any of them did.
        """
        if (yield from self.check_line_break(mode, char)):
            return True
        if self.check_caret(char):
            return True
        if self.check_quote_start(char):
            return True
        if (yield from self.check_command_separators(char)):
            return True
        return (yield from self.check_redirect_io(char))

    @_register(Mode.SetRegular)
    def gobble_set_regular(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
        """
        Handler for `Mode.SetRegular`, the value part of an unquoted SET. After the common
        checks, a whitespace character completes a pending redirect: the buffered token becomes
        its target and the redirect is emitted. The character itself is always appended to the
        token buffer afterwards.
        """
        handled = yield from self.common_token_checks(mode, char)
        if handled:
            return False
        redirect = self.pending_redirect
        if redirect and char in WHITESPACE:
            buffer = self.cursor.token
            self.pending_redirect = None
            redirect.target = unquote(u16(buffer))
            del buffer[:]
            yield redirect
        self.cursor.token.append(char)
        return True

    @_register(Mode.Gap)
    def gobble_gap(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
        """
        Handler for `Mode.Gap`: separator characters are skipped. The first other character
        ends the mode, sets `first_after_gap`, and is left unconsumed.
        """
        # Nothing is ever emitted here; the empty delegation only makes this a generator.
        yield from ()
        if char not in SEPARATORS:
            self.mode_finish()
            self.first_after_gap = True
            return False
        return True

    @_register(Mode.Text)
    def gobble_txt(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
        """
        Handler for `Mode.Text`, the base mode installed by `reset`. It splits ordinary command
        text into words, whitespace runs and control tokens. The return value tells the caller
        whether `char` still has to be consumed.
        """
        if (yield from self.common_token_checks(mode, char)):
            return False
        if char in WHITESPACE:
            # Flush the current word; the whitespace run is collected in its own mode.
            yield from self.emit_token()
            self.cursor.token.append(char)
            self.mode_switch(Mode.Whitespace)
            return True
        if char == SLASH and not self.pending_redirect:
            # A slash starts a new token; the slash itself is appended further below.
            yield from self.emit_token()
        if char == COLON:
            if (yield from self.emit_token()):
                # emit_token completed a pending redirect and switched modes; look at the
                # colon again from there.
                return False
            elif self.next_char() == COLON:
                # Two colons in a row are reported as a comment.
                yield Ctrl.Comment
                return True
            else:
                # The colon was consumed by next_char above; the following character is left
                # for the next handler.
                yield Ctrl.Label
                return False
        try:
            token = SeparatorMap[char]
        except KeyError:
            pass
        else:
            if (yield from self.emit_token()):
                return False
            else:
                yield token
                return True
        self.cursor.token.append(char)
        return True

    @staticmethod
    def label(text: str, uppercase=True):
        """
        Turn the text that follows a label colon into the label name. Leading whitespace is
        dropped and the text is split at spaces, tabs and vertical tabs, keeping the separators.
        Each word is passed through `uncaret`; the name ends after the first word for which
        `uncaret` reports a false first result. The name is upper-cased unless `uppercase` is
        false.
        """
        parts = re.split('([\x20\t\v])', text.lstrip())
        # Even indices hold the words, odd indices hold the captured separators.
        for k, part in itertools.islice(enumerate(parts), 0, None, 2):
            # NOTE(review): uncaret is defined elsewhere; tq presumably signals a trailing
            # caret that escapes the following separator - confirm against its definition.
            tq, part = uncaret(part, True)
            if not tq:
                parts[k] = part
                del parts[k + 1:]
                break
            # Drop the last character of this word and keep the separator and the next word.
            parts[k] = part[:-1]
        label = ''.join(parts)
        if uppercase:
            label = label.upper()
        return label

    def preparse(self, text: str | buf):
        """
        Prepare a script for lexing. Input that is not a string is decoded as UTF-8 with
        replacement of invalid sequences. The text is split into lines and re-assembled as an
        array of UTF-16 code units in which lines are joined by a single `LINEBREAK` unit. While
        doing so, every line that starts with an optional `@`, optional whitespace and a colon
        is recorded in `self.labels`. The results are stored in `self.text` and `self.code`.
        """
        self.labels = {}

        if not isinstance(text, str):
            text = codecs.decode(text, 'utf8', errors='replace')

        _tail = text[-10:]
        lines = text.splitlines(keepends=False)
        utf16 = array.array('H')

        if _tail.splitlines() != F'{_tail}\n'.splitlines():
            # the text had a trailing line break, which is swallowed by the splitlines method
            lines.append('')

        for k, line in enumerate(lines):
            if k > 0:
                utf16.append(LINEBREAK)
            encoded = line.encode('utf-16le')
            if not encoded:
                continue
            # View the encoded bytes as 16-bit units so they can be sliced per code unit.
            encoded = memoryview(encoded).cast('H')
            offset = len(utf16)
            prefix = re.search('^@?[\\s]*:', line)
            if prefix:
                p = prefix.end()
                if lb := self.label(u16(encoded[p:])):
                    # The match ends right after the colon, so offset + p - 1 is the position
                    # of the colon itself. Only the first occurrence of a name is kept.
                    self.labels.setdefault(lb, offset + p - 1)
            utf16.extend(encoded)

        self.text = text
        self.code = memoryview(utf16)

    # Sanity check: the set of registered handlers must match the members of Mode exactly.
    if set(_register.handlers) != set(Mode):
        raise NotImplementedError('Not all handlers were implemented.')

Class variables

var labels

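Initially None. After preparse has run, a dictionary that maps each (upper-cased) label name to the offset of the colon introducing it in code; only the first occurrence of a name is recorded.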

var code

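Initially None. After preparse has run, a memoryview over the script as UTF-16 code units, with lines joined by a single line break unit.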

var pending_redirect

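None unless a redirection operator has been lexed whose target is still missing; in that case it holds the RedirectIO that receives the next emitted token as its target.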

var cursor

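Initially None. reset assigns a fresh BatchLexerCursor here; it carries the current offset, the token buffer, the mode stack and the substitution buffer.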

var resume

None unless quick_save has stored a copy of the cursor; quick_load restores that copy and sets this back to None.

Static methods

def label(text, uppercase=True)
Expand source code Browse git
@staticmethod
def label(text: str, uppercase=True):
    """
    Turn the text that follows a label colon into the label name. Leading whitespace is
    dropped and the text is split at spaces, tabs and vertical tabs, keeping the separators.
    Each word is passed through `uncaret`; the name ends after the first word for which
    `uncaret` reports a false first result. The name is upper-cased unless `uppercase` is
    false.
    """
    parts = re.split('([\x20\t\v])', text.lstrip())
    # Even indices hold the words, odd indices hold the captured separators.
    for k, part in itertools.islice(enumerate(parts), 0, None, 2):
        # NOTE(review): uncaret is defined elsewhere; tq presumably signals a trailing
        # caret that escapes the following separator - confirm against its definition.
        tq, part = uncaret(part, True)
        if not tq:
            parts[k] = part
            del parts[k + 1:]
            break
        # Drop the last character of this word and keep the separator and the next word.
        parts[k] = part[:-1]
    label = ''.join(parts)
    if uppercase:
        label = label.upper()
    return label

Instance variables

var environment
Expand source code Browse git
@property
def environment(self):
    """
    The `environment` attribute of the attached state object.
    """
    state = self.state
    return state.environment
var modes
Expand source code Browse git
@property
def modes(self):
    """
    The mode stack, which is stored on the current cursor.
    """
    cursor = self.cursor
    return cursor.modes
var mode
Expand source code Browse git
@property
def mode(self):
    """
    The active mode, i.e. the last entry of the mode stack.
    """
    stack = self.modes
    return stack[-1]
var substituting
Expand source code Browse git
@property
def substituting(self):
    """
    The `substituting` attribute of the current cursor.
    """
    cursor = self.cursor
    return cursor.substituting
var eof
Expand source code Browse git
@property
def eof(self):
    """
    True when the cursor offset has reached the end of the code and no substituted characters
    are left in the substitution buffer.
    """
    cursor = self.cursor
    if cursor.offset < len(self.code):
        return False
    return not cursor.subst_buffer

Methods

def parse_group(self)
Expand source code Browse git
def parse_group(self):
    """
    Record that one more group has been opened by raising the group counter by one.
    """
    self.group = self.group + 1
def parse_label(self)
Expand source code Browse git
def parse_label(self):
    """
    Switch the lexer to `Mode.Label`. This is only permitted while `Mode.Text` is the single
    entry on the mode stack; otherwise an `EmulatorException` is raised.
    """
    current = self.mode
    at_base = current == Mode.Text and len(self.modes) == 1
    if not at_base:
        raise EmulatorException(F'Switching to LABEL while in mode {current.name}')
    self.mode_switch(Mode.Label)
def parse_set(self)
Expand source code Browse git
def parse_set(self):
    """
    Switch the lexer to `Mode.SetStarted`. This is only permitted while `Mode.Text` is the
    single entry on the mode stack; otherwise an `EmulatorException` is raised.
    """
    current = self.mode
    at_base = current == Mode.Text and len(self.modes) == 1
    if not at_base:
        raise EmulatorException(F'Switching to SET while in mode {current.name}')
    self.mode_switch(Mode.SetStarted)
def parse_arg_variable(self, var)

%* in a batch script refers to all the arguments (e.g. %1 %2 %3 %4 %5 …) Substitution of batch parameters (%n) has been enhanced. You can now use the following optional syntax: %~1 - expands %1 removing any surrounding quotes (") %~f1 - expands %1 to a fully qualified path name %~d1 - expands %1 to a drive letter only %~p1 - expands %1 to a path only %~n1 - expands %1 to a file name only %~x1 - expands %1 to a file extension only %~s1 - expanded path contains short names only %~a1 - expands %1 to file attributes %~t1 - expands %1 to date/time of file %~z1 - expands %1 to size of file %~$PATH:1 - searches the directories listed in the PATH environment variable and expands %1 to the fully qualified name of the first one found. If the environment variable name is not defined or the file is not found by the search, then this modifier expands to the empty string. The modifiers can be combined to get compound results: %~dp1 - expands %1 to a drive letter and path only %~nx1 - expands %1 to a file name and extension only %~dp$PATH:1 - searches the directories listed in the PATH environment variable for %1 and expands to the drive letter and path of the first one found. %~ftza1 - expands %1 to a DIR like output line In the above examples %1 and PATH can be replaced by other valid values. The %~ syntax is terminated by a valid argument number. The %~ modifiers may not be used with %*

Expand source code Browse git
def parse_arg_variable(self, var: ArgVar):
    """
    Expand a batch parameter reference such as `%1`, `%*` or `%~dp0` to its string value.

    An offset of `...` selects the whole command line, offset 0 selects the script name, and
    offset n selects the n-th argument; arguments that were not passed expand to the empty
    string. Without modifier flags the raw value is returned. When only the tilde is given,
    the value is returned with surrounding quotes removed. The reference text below is the
    help text of the command interpreter that this method follows:

    %* in a batch script refers to all the arguments (e.g. %1 %2 %3
        %4 %5 ...)
    Substitution of batch parameters (%n) has been enhanced.  You can
    now use the following optional syntax:
        %~1         - expands %1 removing any surrounding quotes (")
        %~f1        - expands %1 to a fully qualified path name
        %~d1        - expands %1 to a drive letter only
        %~p1        - expands %1 to a path only
        %~n1        - expands %1 to a file name only
        %~x1        - expands %1 to a file extension only
        %~s1        - expanded path contains short names only
        %~a1        - expands %1 to file attributes
        %~t1        - expands %1 to date/time of file
        %~z1        - expands %1 to size of file
        %~$PATH:1   - searches the directories listed in the PATH
                       environment variable and expands %1 to the fully
                       qualified name of the first one found.  If the
                       environment variable name is not defined or the
                       file is not found by the search, then this
                       modifier expands to the empty string
    The modifiers can be combined to get compound results:
        %~dp1       - expands %1 to a drive letter and path only
        %~nx1       - expands %1 to a file name and extension only
        %~dp$PATH:1 - searches the directories listed in the PATH
                       environment variable for %1 and expands to the
                       drive letter and path of the first one found.
        %~ftza1     - expands %1 to a DIR like output line
    In the above examples %1 and PATH can be replaced by other
    valid values.  The %~ syntax is terminated by a valid argument
    number.  The %~ modifiers may not be used with %*
    """
    state = self.state
    flags = var.flags

    if (k := var.offset) is (...):
        return state.command_line
    if (j := k - 1) < 0:
        argval = state.name
    elif j < len(args := state.args):
        argval = args[j]
    else:
        return ''

    if not flags:
        return argval

    has_path = 0 != ArgVarFlags.FullPath & flags

    if flags.StripQuotes and argval.startswith('"') and argval.endswith('"'):
        argval = argval[1:-1]

    if flags.ShortName and not has_path:
        # Short names are reported as a full path when no path component was requested.
        flags |= ArgVarFlags.FullPath
        has_path = True

    if not (has_path or flags.Attributes or flags.DateTime or flags.FileSize):
        # Only the tilde was given (as in %~1): the expansion is the argument itself with
        # its quotes removed. Without this return nothing would be written below and the
        # reference would expand to the empty string.
        return argval

    with io.StringIO() as out:
        if flags.Attributes:
            out.write('--a--------') # TODO: placeholder
        if flags.DateTime:
            dt = state.start_time.isoformat(' ', 'minutes')
            out.write(F' {dt}')
        if flags.FileSize:
            out.write(F' {state.sizeof_file(argval)}')
        if has_path:
            out.write(' ')
            full_path = state.resolve_path(argval)
            drv, rest = ntpath.splitdrive(full_path)
            *pp, name = ntpath.split(rest)
            name, ext = ntpath.splitext(name)
            if flags.DriveLetter:
                out.write(drv)
            if flags.FilePath:
                out.write(ntpath.join(*pp))
            if flags.FileName:
                out.write(name)
            if flags.FileExtension:
                out.write(ext)
        # Every field is written with a leading blank; strip the one before the first field.
        return out.getvalue().lstrip()
def reset(self, offset)
Expand source code Browse git
def reset(self, offset: int):
    """
    Bring the lexer into its initial state for lexing from the given offset: all flags and
    the group counter are cleared, a fresh cursor is created at `offset` with `Mode.Text` as
    its base mode, and any saved snapshot or pending redirect is discarded.
    """
    self.resume = None
    self.pending_redirect = None
    self.group = 0
    self.first_after_gap = True
    self.quote = self.caret = self.white = False
    # The mode stack is reached through the cursor, so the cursor has to exist first.
    self.cursor = BatchLexerCursor(offset)
    self.modes.append(Mode.Text)
def mode_reset(self)
Expand source code Browse git
def mode_reset(self):
    """
    Drop every mode except the first one from the mode stack.
    """
    stack = self.modes
    del stack[1:]
def mode_finish(self)
Expand source code Browse git
def mode_finish(self):
    """
    Leave the active mode by removing it from the mode stack. The first mode on the stack
    can not be left; attempting to do so raises a `RuntimeError`.
    """
    stack = self.modes
    if len(stack) > 1:
        stack.pop()
        return
    raise RuntimeError('Trying to exit base mode.')
def mode_switch(self, mode)
Expand source code Browse git
def mode_switch(self, mode: Mode):
    """
    Enter the given mode by pushing it onto the mode stack.
    """
    stack = self.modes
    stack.append(mode)
def quick_save(self)
Expand source code Browse git
def quick_save(self):
    """
    Store a copy of the current cursor in `resume` so that `quick_load` can return to it.
    """
    snapshot = self.cursor.copy()
    self.resume = snapshot
def quick_load(self)
Expand source code Browse git
def quick_load(self):
    """
    Make the snapshot stored by `quick_save` the current cursor and clear the snapshot.
    Raises a `RuntimeError` when there is no snapshot.
    """
    snapshot = self.resume
    if snapshot is None:
        raise RuntimeError
    self.resume = None
    self.cursor = snapshot
def current_char(self)
Expand source code Browse git
def current_char(self):
    """
    Return the character at the current lexer position without consuming it.

    While the substitution buffer holds characters, they are served from there. Otherwise
    the character is read from the code; a percent sign at that position is skipped and
    triggers `fill_substitution_buffer`, after which the lookup is repeated.

    Raises `UnexpectedEOF` when the position lies beyond the end of the code. This includes
    the lookup that checks for a percent sign, which would otherwise surface as a bare
    `IndexError`.
    """
    cursor = self.cursor
    code = self.code
    if not (subst := cursor.subst_buffer):
        offset = cursor.offset
        try:
            char = code[offset]
        except IndexError:
            raise UnexpectedEOF
        if char == PERCENT:
            cursor.offset += 1
            self.fill_substitution_buffer()
            return self.current_char()
        return char
    offset = cursor.subst_offset
    if offset < (n := len(subst)):
        return subst[offset]
    # Positions past the end of the buffer continue in the code at the cursor offset.
    offset -= n
    offset += cursor.offset
    try:
        return code[offset]
    except IndexError:
        raise UnexpectedEOF
def consume_char(self)
Expand source code Browse git
def consume_char(self):
    """
    Advance the lexer by one character. While the substitution buffer holds characters, the
    position inside that buffer is advanced; once its end is reached, the buffer is emptied
    and its position is set to -1. Otherwise the code offset is advanced, and an `EOFError`
    is raised if that would move it past the end of the code.
    """
    cursor = self.cursor
    pending = cursor.subst_buffer
    if not pending:
        advanced = cursor.offset + 1
        if advanced > len(self.code):
            raise EOFError('Consumed a character beyond EOF.')
        cursor.offset = advanced
        return
    advanced = cursor.subst_offset + 1
    if advanced < len(pending):
        cursor.subst_offset = advanced
    else:
        del pending[:]
        cursor.subst_offset = -1
def next_char(self)
Expand source code Browse git
def next_char(self):
    """
    Consume the current character and return the one that follows it.
    """
    self.consume_char()
    following = self.current_char()
    return following
def parse_env_variable(self, var)
Expand source code Browse git
def parse_env_variable(self, var: str):
    """
    Expand the text found between two percent signs.

    - The empty string (from `%%`) expands to a single percent sign.
    - `name` expands to the value that `self.state.envar` returns for it.
    - `name:old=new` replaces every occurrence of `old` with `new`; when `old` starts with a
      tilde, the tilde is removed and only the first occurrence is replaced.
    - `name:~offset` and `name:~offset,length` extract a substring. A negative offset counts
      from the end of the value. A negative length excludes that many characters from the
      end of the value instead of being added to the offset.

    An empty value is returned as it is, regardless of the modifier. Any other modifier
    raises an `EmulatorException`.
    """
    if var == '':
        return '%'
    name, _, modifier = var.partition(':')
    base = self.state.envar(name)
    if not modifier or not base:
        return base
    if '=' in modifier:
        old, _, new = modifier.partition('=')
        if old.startswith('~'):
            # The count is passed by position: str.replace accepts it as a keyword
            # argument only on recent Python versions.
            return base.replace(old[1:], new, 1)
        return base.replace(old, new)
    if not modifier.startswith('~'):
        raise EmulatorException
    offset, _, length = modifier[1:].partition(',')
    offset = batchint(offset)
    if offset < 0:
        offset = max(0, len(base) + offset)
    if not length:
        end = len(base)
    elif (count := batchint(length)) < 0:
        # A negative length is measured from the end of the value.
        end = max(0, len(base) + count)
    else:
        end = offset + count
    return base[offset:end]
def emit_token(self)
Expand source code Browse git
def emit_token(self):
    """
    Flush the token buffer. A non-empty token is either emitted as a `Word` or, when a
    redirect is waiting for its target, stored unquoted as that target; in the latter case
    the redirect is emitted and `Mode.Gap` is entered. The buffer is cleared and
    `first_after_gap` is reset in every case. The generator returns True exactly when a
    pending redirect was completed.
    """
    completed_redirect = False
    buffer = self.cursor.token
    text = u16(buffer) if buffer else None
    if text:
        redirect = self.pending_redirect
        if not redirect:
            yield Word(text)
        else:
            redirect.target = unquote(text)
            self.pending_redirect = None
            self.mode_switch(Mode.Gap)
            yield redirect
            completed_redirect = True
    del buffer[:]
    self.first_after_gap = False
    return completed_redirect
def tokens(self, offset)
Expand source code Browse git
def tokens(self, offset: int) -> Generator[Token]:
    """
    Lex the code starting at `offset` and generate the resulting tokens. Each character is
    passed to the handler registered for the active mode; the handler's result decides
    whether the character is consumed afterwards. When the input is exhausted, a token that
    is still buffered gets emitted unless `first_after_gap` is set.
    """
    self.reset(offset)
    dispatch = self._register.handlers
    peek = self.current_char
    advance = self.consume_char
    total = len(self.code)

    while not self.cursor.eof(total):
        char = peek()
        active = self.mode
        consumed = yield from dispatch[active](self, active, char)
        if consumed:
            advance()

    if self.first_after_gap:
        return
    yield from self.emit_token()
def check_line_break(self, mode, char)
Expand source code Browse git
def check_line_break(self, mode: Mode, char: int):
    """
    Handle a line break. Without an active caret the buffered token is emitted, the mode
    stack is reset to its base and `Ctrl.NewLine` is generated; with an active caret the
    break is merely skipped. The line break is consumed in both cases. The generator returns
    True when `char` was a line break and False otherwise.
    """
    if char == LINEBREAK:
        if not self.caret:
            # caret is not reset until the next char!
            yield from self.emit_token()
            self.white = True
            self.quote = False
            self.mode_reset()
            yield Ctrl.NewLine
        self.consume_char()
        return True
    return False
def check_caret(self, char)
Expand source code Browse git
def check_caret(self, char: int):
    """
    Handle caret escaping. If the caret flag is set, `char` is taken literally: it is added
    to the token, the flag is cleared and the character is consumed. Otherwise a caret sets
    the flag and is consumed without being added to the token. Returns True when `char` was
    handled here and False when it is an ordinary character.
    """
    if not self.caret:
        if char != CARET:
            return False
        self.caret = True
    else:
        self.cursor.token.append(char)
        self.caret = False
    self.consume_char()
    return True
def check_command_separators(self, char)
Expand source code Browse git
def check_command_separators(self, char: int):
    """
    Handle characters that end a command: a closing parenthesis while at least one group is
    open, an ampersand, or a pipe. The buffered token is emitted, the mode stack is reset to
    its base, the character is consumed and the matching control token is generated. The
    generator returns True when `char` was handled and False otherwise.
    """
    if char == PAREN_CLOSE and (g := self.group) > 0:
        yield from self.emit_token()
        yield Ctrl.EndGroup
        self.mode_reset()
        self.consume_char()
        # The counter is derived from the value read before the tokens above were generated.
        self.group = g - 1
        return True
    elif char == AMPERSAND:
        tok = Ctrl.Ampersand
    elif char == PIPE:
        tok = Ctrl.Pipe
    else:
        return False
    yield from self.emit_token()
    self.mode_reset()
    self.consume_char()
    yield tok
    return True
def check_quote_start(self, char)
Expand source code Browse git
def check_quote_start(self, char: int):
    """
    Handle an opening quote: it is added to the token, `Mode.Quote` is entered, the caret
    flag is cleared and the quote is consumed. Returns True when `char` was a quote and
    False otherwise.
    """
    if char == QUOTE:
        self.cursor.token.append(char)
        self.mode_switch(Mode.Quote)
        self.caret = False
        self.consume_char()
        return True
    return False
def check_redirect_io(self, char)
Expand source code Browse git
def check_redirect_io(self, char: int):
    """
    Handle an I/O redirection operator that starts with one of the characters in `ANGLES`.
    The operator is consumed including an optional second closing angle (append) and an
    optional `&digit` target. When the target is a handle, a complete `RedirectIO` is
    generated; otherwise the redirect is stored in `pending_redirect` and `Mode.Gap` is
    entered so that the next emitted token becomes its target. The generator returns True
    when `char` started a redirection and False otherwise. An `UnexpectedToken` is raised
    when an ampersand is not followed by a digit.
    """
    if char not in ANGLES:
        return False

    output = char != ANGLE_OPEN
    token = self.cursor.token

    if len(token) == 1 and (src := token[0] - ZERO) in range(10):
        # A single digit directly in front of the operator is the source handle; it is
        # removed from the token buffer so that it is not emitted as a word.
        del token[:]
        source = src
    else:
        # Default source: int(True) == 1 for output and int(False) == 0 for input.
        source = int(output)

    char = self.next_char()

    if not output:
        how = Redirect.In
    elif char == ANGLE_CLOSE:
        how = Redirect.OutAppend
        char = self.next_char()
    else:
        how = Redirect.OutCreate

    yield from self.emit_token()

    if char != AMPERSAND:
        self.pending_redirect = RedirectIO(how, source)
        self.mode_switch(Mode.Gap)
    else:
        char = self.next_char()
        if char not in range(ZERO, NINE + 1):
            raise UnexpectedToken(char)
        self.consume_char()
        yield RedirectIO(how, source, char - ZERO)

    return True
def fill_substitution_buffer(self)
Expand source code Browse git
def fill_substitution_buffer(self):
    """
    Resolve a percent expression that starts at the cursor offset and place its expansion in
    the cursor's substitution buffer. The scan runs from the cursor offset to the end of the
    line and recognises environment variables (terminated by a second percent sign) as well
    as argument variables (terminated by a digit, or an asterisk in first position). When
    the scan ends, the cursor offset is moved to the recorded resume position, if any.
    Nothing is done while the cursor reports that a substitution is in progress.
    """
    if (cursor := self.cursor).substituting:
        return

    code = self.code
    var_resume = -1
    var_dollar = -1
    var_cmdarg = ArgVar()
    variable = None
    phase = EV.New
    q = ArgVarFlags.StripQuotes

    for current in range((current := cursor.offset), len(code)):
        char = code[current]
        if char == LINEBREAK:
            break
        elif char == PERCENT:
            # Everything between the cursor offset and this percent sign is the expression.
            try:
                var_name = u16(self.code[cursor.offset:current])
                variable = u16(self.parse_env_variable(var_name))
            except MissingVariable:
                if var_resume < 0:
                    var_resume = current + 1
                break
        elif var_cmdarg:
            # Still a candidate for an argument variable: a digit or a leading asterisk
            # completes it.
            if ZERO <= char <= NINE:
                var_cmdarg.offset = char - ZERO
                variable = u16(self.parse_arg_variable(var_cmdarg))
            elif char == ASTERIX and cursor.offset == current:
                var_cmdarg.offset = (...)
                variable = u16(self.parse_arg_variable(var_cmdarg))
        if variable is not None:
            # Expansion found: queue it and continue lexing after the terminating character.
            cursor.subst_offset = 0
            cursor.subst_buffer.extend(variable)
            var_resume = current + 1
            break
        if phase == EV.Mod:
            var_cmdarg = None
        elif phase == EV.Env:
            if char == COLON:
                if var_cmdarg:
                    assert var_dollar > 0
                    # The text between the dollar sign and this colon is stored as the path.
                    var_cmdarg.path = u16(self.code[var_dollar:current])
                var_resume = current + 1
        else:
            if char == DOLLAR:
                var_dollar = current + 1
                phase = EV.Env
                continue
            if char == COLON:
                # A colon in the initial phase rules out an argument variable.
                var_cmdarg = None
                var_resume = current + 1
                phase = EV.Mod
                continue
            if not var_cmdarg:
                continue
            try:
                var_cmdarg.flags |= ArgVarFlags.FromToken(char)
            except KeyError:
                var_cmdarg = None
                continue
            # Modifier flags are only accepted when the strip-quotes flag is among them.
            if q not in var_cmdarg.flags:
                var_cmdarg = None
    if var_resume >= 0:
        cursor.offset = var_resume
def gobble_label(self, mode, char)
Expand source code Browse git
@_register(Mode.Label)
def gobble_label(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
    """
    Handler for `Mode.Label`: every character that is not handled as a line break is added
    to the token buffer. The return value tells the caller whether `char` still has to be
    consumed.
    """
    handled = yield from self.check_line_break(mode, char)
    if not handled:
        self.cursor.token.append(char)
    return not handled
def gobble_quote(self, mode, char)
Expand source code Browse git
@_register(Mode.Quote)
def gobble_quote(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
    """
    Handler for `Mode.Quote`: characters are copied into the token buffer as they are. The
    closing quote is copied too and then ends the mode. The return value tells the caller
    whether `char` still has to be consumed.
    """
    handled = yield from self.check_line_break(mode, char)
    if handled:
        return False
    self.cursor.token.append(char)
    if char != QUOTE:
        return True
    # The closing quote belongs to the token; only afterwards is the mode left.
    self.mode_finish()
    return True
def gobble_whitespace(self, mode, char)
Expand source code Browse git
@_register(Mode.Whitespace)
def gobble_whitespace(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
    """
    Handler for `Mode.Whitespace`: collects a run of whitespace characters. The first
    character that is not whitespace ends the mode, emits the run as a single `Word`, and is
    left unconsumed for the handler of the previous mode.
    """
    if char not in WHITESPACE:
        self.mode_finish()
        buffer = self.cursor.token
        yield Word(u16(buffer))
        del buffer[:]
        return False
    self.cursor.token.append(char)
    return True
def gobble_quoted_set(self, mode, char)
Expand source code Browse git
@_register(Mode.SetQuoted)
def gobble_quoted_set(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
    """
    Handler for `Mode.SetQuoted`, which `gobble_set` selects when a SET expression starts
    with a quote. Every quote seen here is appended to the token and followed by a cursor
    snapshot (`quick_save`). Scanning then continues speculatively; when the line ends or an
    unescaped command separator shows up, the lexer rewinds to the latest snapshot so that
    the emitted token stops at that quote. The return value tells the caller whether `char`
    still has to be consumed.
    """
    if char == QUOTE:
        # The quote is consumed here so that the snapshot below already points past it;
        # this is why False is returned even though the character was processed.
        self.consume_char()
        self.cursor.token.append(QUOTE)
        self.quick_save()
        return False

    if char == LINEBREAK:
        if self.resume is None:
            # No snapshot to return to: emit what was collected and end the line.
            yield from self.emit_token()
            self.mode_reset()
            yield Ctrl.NewLine
            return True
        elif self.caret:
            # Caret flag is set while scanning past a quote: clear it and skip the break.
            self.caret = False
            return True
        else:
            # Rewind to the snapshot taken at the last quote and emit the token collected
            # up to there; lexing resumes from the restored position.
            self.quick_load()
            yield from self.emit_token()
            return False

    if self.resume is not None:
        # Only while scanning beyond a quote: track carets and look for separators.
        if char == CARET:
            self.caret = not self.caret
        elif not self.caret:
            if (char == PAREN_CLOSE and self.group > 0) or char in (PIPE, AMPERSAND):
                self.quick_load()
                yield from self.emit_token()
                self.mode_finish()
                # after a quick load, the ending quote was already consumed.
                return False

    self.cursor.token.append(char)
    return True
def gobble_set(self, mode, char)
Expand source code Browse git
@_register(Mode.SetStarted)
def gobble_set(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
    """
    Handler for `Mode.SetStarted`, the mode pushed by `parse_set`. It lexes the part of a
    SET command before the assignment: whitespace, slash-prefixed tokens, and the variable
    name. A quote at the start of a token switches to `Mode.SetQuoted`, an equals sign
    switches to `Mode.SetRegular`. The return value tells the caller whether `char` still
    has to be consumed.
    """
    token = self.cursor.token
    if (yield from self.check_line_break(mode, char)):
        return False
    if char in WHITESPACE:
        # Flush the current token; the whitespace run is collected in its own mode.
        yield from self.emit_token()
        token.append(char)
        self.mode_switch(Mode.Whitespace)
        return True
    if char == SLASH and not self.pending_redirect:
        # A slash starts a new token (presumably a switch such as /A or /P - TODO confirm).
        yield from self.emit_token()
        token.append(char)
        return True
    if not token and char == QUOTE:
        # A quote at the very start of a token selects the quoted SET form.
        self.caret = False
        token.append(char)
        self.mode = Mode.SetQuoted
        return True
    if self.check_caret(char):
        return False
    if char == EQUALS:
        # The variable name ends here; everything after is lexed as the value.
        yield from self.emit_token()
        yield Ctrl.Equals
        self.mode = Mode.SetRegular
        return True
    if self.check_quote_start(char):
        return False
    if (yield from self.check_command_separators(char)):
        return False
    if (yield from self.check_redirect_io(char)):
        return False
    token.append(char)
    return True
def common_token_checks(self, mode, char)
Expand source code Browse git
def common_token_checks(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
    """
    Run the checks shared by several handlers in a fixed order: line break, caret, quote
    start, command separators, I/O redirection. The first check that handles `char` ends
    the sequence; the result is True when any of them did.
    """
    if (yield from self.check_line_break(mode, char)):
        return True
    if self.check_caret(char):
        return True
    if self.check_quote_start(char):
        return True
    if (yield from self.check_command_separators(char)):
        return True
    return (yield from self.check_redirect_io(char))
def gobble_set_regular(self, mode, char)
Expand source code Browse git
@_register(Mode.SetRegular)
def gobble_set_regular(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
    """
    Handler for `Mode.SetRegular`, the value part of an unquoted SET. After the common
    checks, a whitespace character completes a pending redirect: the buffered token becomes
    its target and the redirect is emitted. The character itself is always appended to the
    token buffer afterwards.
    """
    handled = yield from self.common_token_checks(mode, char)
    if handled:
        return False
    redirect = self.pending_redirect
    if redirect and char in WHITESPACE:
        buffer = self.cursor.token
        self.pending_redirect = None
        redirect.target = unquote(u16(buffer))
        del buffer[:]
        yield redirect
    self.cursor.token.append(char)
    return True
def gobble_gap(self, mode, char)
Expand source code Browse git
@_register(Mode.Gap)
def gobble_gap(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
    """
    Handler for `Mode.Gap`: separator characters are skipped. The first other character
    ends the mode, sets `first_after_gap`, and is left unconsumed.
    """
    # Nothing is ever emitted here; the empty delegation only makes this a generator.
    yield from ()
    if char not in SEPARATORS:
        self.mode_finish()
        self.first_after_gap = True
        return False
    return True
def gobble_txt(self, mode, char)
Expand source code Browse git
@_register(Mode.Text)
def gobble_txt(self, mode: Mode, char: int) -> Generator[Token, None, bool]:
    """
    Handler for `Mode.Text`, the base mode installed by `reset`. It splits ordinary command
    text into words, whitespace runs and control tokens. The return value tells the caller
    whether `char` still has to be consumed.
    """
    if (yield from self.common_token_checks(mode, char)):
        return False
    if char in WHITESPACE:
        # Flush the current word; the whitespace run is collected in its own mode.
        yield from self.emit_token()
        self.cursor.token.append(char)
        self.mode_switch(Mode.Whitespace)
        return True
    if char == SLASH and not self.pending_redirect:
        # A slash starts a new token; the slash itself is appended further below.
        yield from self.emit_token()
    if char == COLON:
        if (yield from self.emit_token()):
            # emit_token completed a pending redirect and switched modes; look at the
            # colon again from there.
            return False
        elif self.next_char() == COLON:
            # Two colons in a row are reported as a comment.
            yield Ctrl.Comment
            return True
        else:
            # The colon was consumed by next_char above; the following character is left
            # for the next handler.
            yield Ctrl.Label
            return False
    try:
        token = SeparatorMap[char]
    except KeyError:
        pass
    else:
        if (yield from self.emit_token()):
            return False
        else:
            yield token
            return True
    self.cursor.token.append(char)
    return True
def preparse(self, text)
Expand source code Browse git
def preparse(self, text: str | buf):
    """
    Prepare a script for lexing. Input that is not a string is decoded as UTF-8 with
    replacement of invalid sequences. The text is split into lines and re-assembled as an
    array of UTF-16 code units in which lines are joined by a single `LINEBREAK` unit. While
    doing so, every line that starts with an optional `@`, optional whitespace and a colon
    is recorded in `self.labels`. The results are stored in `self.text` and `self.code`.
    """
    self.labels = {}

    if not isinstance(text, str):
        text = codecs.decode(text, 'utf8', errors='replace')

    _tail = text[-10:]
    lines = text.splitlines(keepends=False)
    utf16 = array.array('H')

    if _tail.splitlines() != F'{_tail}\n'.splitlines():
        # the text had a trailing line break, which is swallowed by the splitlines method
        lines.append('')

    for k, line in enumerate(lines):
        if k > 0:
            utf16.append(LINEBREAK)
        encoded = line.encode('utf-16le')
        if not encoded:
            continue
        # View the encoded bytes as 16-bit units so they can be sliced per code unit.
        encoded = memoryview(encoded).cast('H')
        offset = len(utf16)
        prefix = re.search('^@?[\\s]*:', line)
        if prefix:
            p = prefix.end()
            if lb := self.label(u16(encoded[p:])):
                # The match ends right after the colon, so offset + p - 1 is the position
                # of the colon itself. Only the first occurrence of a name is kept.
                self.labels.setdefault(lb, offset + p - 1)
        utf16.extend(encoded)

    self.text = text
    self.code = memoryview(utf16)
class BatchParser (data, state=None)
Expand source code Browse git
class BatchParser:

    def __init__(self, data: str | buf | BatchParser, state: BatchState | None = None):
        if isinstance(data, BatchParser):
            if state is not None:
                raise NotImplementedError
            src = data.lexer
        else:
            src = data
        self.lexer = BatchLexer(src, state)

    @property
    def state(self):
        return self.lexer.state

    @property
    def environment(self):
        return self.state.environment

    def skip_prefix(self, tokens: LookAhead) -> tuple[int, list[Token]]:
        """
        Drop the tokens that may precede a statement and report what was dropped.

        Whitespace words and the control tokens `@`, `;`, `,` and `=` are skipped. The result
        is a pair: the number of `@` tokens seen since the last other skipped control token,
        and the list of all skipped tokens in order. Whitespace does not reset the counter.
        """
        skippable = {
            Ctrl.Equals,
            Ctrl.Comma,
            Ctrl.Semicolon,
            Ctrl.At,
        }
        collected = []
        at_count = 0
        current = tokens.peek()
        while True:
            if isinstance(current, Word):
                if not current.isspace():
                    return at_count, collected
            elif current in skippable:
                at_count = at_count + 1 if current == Ctrl.At else 0
            else:
                return at_count, collected
            collected.append(current)
            current = tokens.drop_and_peek()

    def command(
        self,
        parent: AstNode,
        tokens: LookAhead,
        redirects: dict[int, RedirectIO],
        in_group: bool,
        silenced: bool,
    ) -> AstCommand | None:
        """
        Collect the tokens of a single command into an `AstCommand` node.

        Tokens are read until a control token for ampersand, pipe, newline, or end of file is
        reached; when `in_group` is set, an end-of-group token stops the scan as well. Every
        `RedirectIO` token met on the way is stored in the redirect table of the node under its
        `source` instead of being added to the fragments. The node is returned only when it is
        truthy, otherwise the result is `None`.
        """
        ast = AstCommand(tokens.offset(), parent, silenced, redirects)
        tok = tokens.peek()
        cmd = ast.fragments

        # eat_token: remove the first character of the token that follows the ECHO token.
        # add_space: the next token is the one that follows ECHO and gets special treatment.
        # nsp_merge: consecutive non-space tokens are joined into a single fragment.
        eat_token = False
        add_space = False
        nsp_merge = True

        # Buffer for the non-space tokens that are being joined while nsp_merge is set.
        nonspace = io.StringIO()

        # NOTE(review): the ECHO/SET/GOTO special cases below are skipped entirely when the
        # node already has redirects - confirm that this is intended.
        if not ast.redirects:
            assert not isinstance(tok, RedirectIO)
            tok_upper = tok.upper()
            if tok_upper.startswith('ECHO'):
                if len(tok_upper) > 4 and tok_upper[4] == '.':
                    # "ECHO." form: emit the keyword and one blank, keep the text after the dot.
                    cmd.append(tok[:4])
                    cmd.append(' ')
                    tok = tok[5:]
                    add_space = False
                else:
                    add_space = True
                    eat_token = True
            elif tok_upper == 'SET':
                # Tell the lexer that a SET statement follows; fragments are kept separate.
                self.lexer.parse_set()
                nsp_merge = False
            elif tok_upper == 'GOTO':
                nsp_merge = False
            if add_space or not nsp_merge:
                cmd.append(tok)
            else:
                nonspace.write(tok)
            tok = tokens.drop_and_peek()

        # The loop ends only for a Ctrl instance that equals one of the four terminators.
        while tok not in (
            Ctrl.Ampersand,
            Ctrl.Pipe,
            Ctrl.NewLine,
            Ctrl.EndOfFile,
        ) or not isinstance(tok, Ctrl):
            if in_group and tok == Ctrl.EndGroup:
                break
            if isinstance(tok, RedirectIO):
                ast.redirects[tok.source] = tok
            elif add_space:
                # Token right after ECHO: when it is not whitespace, a blank is inserted and,
                # if eat_token is set, the first character of the token is removed.
                add_space = False
                if not tok.isspace():
                    cmd.append(' ')
                    if eat_token:
                        tok = tok[1:]
                if tok:
                    cmd.append(tok)
            elif not nsp_merge:
                cmd.append(tok)
            elif tok.isspace() or tok.startswith('/'):
                # Whitespace and slash-prefixed tokens end the run of merged tokens.
                if nsp := nonspace.getvalue():
                    nonspace.seek(0)
                    nonspace.truncate(0)
                    cmd.append(nsp)
                cmd.append(tok)
            else:
                nonspace.write(tok)
            tok = tokens.drop_and_peek()
        # Flush the merged tokens that are still buffered.
        if nsp := nonspace.getvalue():
            cmd.append(nsp)
        if ast:
            return ast

    def redirects(self, tokens: LookAhead):
        """
        Consume all `RedirectIO` tokens at the current position and return them in a dictionary
        keyed by their `source`. Space is skipped after each consumed redirect; a later redirect
        for the same source replaces an earlier one.
        """
        collected = {}
        while True:
            head = tokens.peek()
            if not isinstance(head, RedirectIO):
                return collected
            collected[head.source] = head
            tokens.pop()
            tokens.skip_space()

    def pipeline(self, parent: AstNode | None, tokens: LookAhead, in_group: bool, silenced: bool) -> AstPipeline | None:
        """
        Parse groups and commands that are chained with single pipe tokens.

        Each part is parsed as a group first and as a command if that yields nothing. Parsing
        stops when no part could be parsed, when no pipe follows, or when the consumed pipe is
        directly followed by another pipe. Returns the `AstPipeline` node if it received at
        least one part and `None` otherwise.
        """
        node = AstPipeline(tokens.offset(), parent, silenced)
        while True:
            io_map = self.redirects(tokens)
            part = self.group(node, tokens, io_map, silenced)
            if not part:
                part = self.command(node, tokens, io_map, in_group, silenced)
            if not part:
                break
            node.parts.append(part)
            if not tokens.pop(Ctrl.Pipe):
                break
            if tokens.peek() == Ctrl.Pipe:
                # one pipe was consumed and a second one follows; it is left in the stream
                break
            at_count, _ = self.skip_prefix(tokens)
            silenced = at_count > 0
        return node if node.parts else None

    def ifthen(self, parent: AstNode | None, tokens: LookAhead, in_group: bool, silenced: bool) -> AstIf | None:
        """
        Parse an IF statement into an `AstIf` node.

        Returns `None` when `tokens.pop_string('IF')` does not succeed. Two forms are handled:
        a keyword form, where the word after the optional `/I` and `NOT` is accepted by
        `AstIfVariant` and is followed by one operand, and a comparison form with a left-hand
        side, a comparator, and a right-hand side. `UnexpectedToken` is raised for an invalid
        comparator and for a missing body.
        """
        offset = tokens.offset()

        if not tokens.pop_string('IF'):
            return None

        casefold = False
        negated = False
        lhs = None
        rhs = None

        token = tokens.word()

        # Optional modifiers, accepted in this order only: /I first, then NOT.
        if token.upper() == '/I':
            casefold = True
            token = tokens.word()
        if token.upper() == 'NOT':
            negated = True
            token = tokens.word()
        try:
            variant = AstIfVariant(token.upper())
        except Exception:
            # Not a variant keyword: the word is the left-hand side of a comparison.
            tokens.skip_space()
            variant = None
            lhs = token
            cmp = next(tokens)

            if cmp == Ctrl.Equals:
                # The == comparator arrives as two consecutive Ctrl.Equals tokens.
                if tokens.pop(Ctrl.Equals):
                    cmp = AstIfCmp('==')
                else:
                    raise UnexpectedToken(offset, tokens.peek())
            else:
                try:
                    cmp = AstIfCmp(cmp.upper())
                except Exception:
                    raise UnexpectedToken(offset, cmp)
                # Comparators other than AstIfCmp.STR require an extensions version of 1 or more.
                if cmp != AstIfCmp.STR and self.state.extensions_version < 1:
                    raise UnexpectedToken(offset, cmp)

            rhs = tokens.consume_nonspace_words()

            # Both operands become integers only if both of them convert; otherwise neither does.
            try:
                ilh = batchint(lhs)
                irh = batchint(rhs)
            except ValueError:
                pass
            else:
                lhs = ilh
                rhs = irh
        else:
            # Keyword form: a single unquoted operand, converted to an integer when possible.
            lhs = unquote(tokens.consume_nonspace_words())
            rhs = None
            cmp = None
            try:
                lhs = batchint(lhs)
            except ValueError:
                pass

        # The body is parsed without a parent because the AstIf node does not exist yet.
        then_do = self.sequence(None, tokens, in_group)

        if then_do is None:
            raise UnexpectedToken(offset, tokens.peek())

        tokens.skip_space()

        ast = AstIf(
            offset,
            parent,
            silenced,
            then_do,
            None,
            variant,
            casefold,
            negated,
            cmp,
            lhs, rhs # type:ignore
        )
        then_do.parent = ast

        # An ELSE branch is optional; when present it is parsed with the new node as parent.
        if tokens.peek().upper() == 'ELSE':
            tokens.pop()
            ast.else_do = self.sequence(ast, tokens, in_group)

        return ast

    def forloop_options(self, options: str) -> AstForOptions:
        """
        Parse the option string of a FOR loop into an `AstForOptions` object.

        If `options` contains a double-quoted section, only the text inside the first pair of
        quotes is parsed. The text is split at whitespace and each part is read as `key=value`
        with a case-insensitive key:

        - `usebackq` takes no value and sets `usebackq`.
        - `eol=c` takes exactly one character and sets `comment`.
        - `skip=n` sets `skip` to the integer value of `n`.
        - `delims=xyz` sets `delims`.
        - `tokens=a,b-c*` sets `tokens` to a sorted tuple of zero-based indices; a trailing
          asterisk sets `asterisk`.

        An empty option string yields the default options. `ValueError` is raised for unknown
        keys, missing or malformed values.
        """
        result = AstForOptions()

        if not options:
            return result
        elif (quote := re.search('"(.*?)"', options)):
            options = quote[1]

        parts = options.strip().split()
        count = len(parts)

        for k, part in enumerate(parts, 1):
            key, eq, value = part.partition('=')
            key = key.lower()
            if key == 'usebackq':
                if eq or value:
                    raise ValueError
                result.usebackq = True
            elif not eq:
                raise ValueError
            elif key == 'eol':
                if len(value) != 1:
                    raise ValueError
                result.comment = value
            elif key == 'skip':
                try:
                    result.skip = batchint(value)
                except Exception:
                    raise ValueError
            elif key == 'delims':
                if k == count:
                    # When delims is the last option, its value is re-read from the unsplit
                    # string so that whitespace delimiters are not lost to the split above.
                    # The lookup ignores case because the key was matched ignoring case too;
                    # a case-sensitive search would discard the value of e.g. "DELIMS=, ".
                    if (found := re.search('delims=', options, flags=re.IGNORECASE)):
                        value = options[found.end():]
                result.delims = value
            elif key == 'tokens':
                tokens: set[int] = set()
                if value.endswith('*'):
                    result.asterisk = True
                    value = value[:-1]
                    if not value:
                        result.tokens = ()
                        continue
                for x in value.split(','):
                    # "a-b" selects the one-based tokens a through b, a lone "a" just that one.
                    x, _, y = x.partition('-')
                    x = batchint(x) - 1
                    if x < 0:
                        raise ValueError
                    y = batchint(y) if y else x + 1
                    for t in range(x, y):
                        tokens.add(t)
                result.tokens = tuple(sorted(tokens))
            else:
                raise ValueError

        return result

    def forloop(self, parent: AstNode | None, tokens: LookAhead, in_group: bool, silenced: bool) -> AstFor | None:
        """
        Parse a FOR loop into an `AstFor` node.

        Returns `None` when `tokens.pop_string('FOR')` does not succeed. The accepted shape is
        `FOR [/X [options-or-path]] %v IN (spec) DO body`, where `/X` is a one-letter switch
        accepted by `AstForVariant`. Any deviation raises `UnexpectedToken` with the offset of
        the statement.
        """
        offset = tokens.offset()

        if not tokens.pop_string('FOR'):
            return None

        def isvar(token: str):
            # A loop variable is exactly two characters long and starts with a percent sign.
            return len(token) == 2 and token.startswith('%')

        path = None
        mode = AstForParserMode.FileSet
        spec = []
        options = ''

        if isvar(variable := tokens.word()):
            variant = AstForVariant.Default
        elif len(variable) != 2 or not variable.startswith('/'):
            raise UnexpectedToken(offset, variable)
        else:
            try:
                variant = AstForVariant(variable[1].upper())
            except ValueError:
                raise UnexpectedToken(offset, variable)
            variable = tokens.word()
            if not isvar(variable):
                # Only FileParsing accepts an option string and only DescendRecursively
                # accepts a path between the switch and the loop variable.
                if variant == AstForVariant.FileParsing:
                    options = variable
                elif variant == AstForVariant.DescendRecursively:
                    path = unquote(variable)
                else:
                    raise UnexpectedToken(offset, variable)
                variable = tokens.word()
                if not isvar(variable):
                    raise UnexpectedToken(offset, variable)

        if (t := tokens.word()).upper() != 'IN':
            raise UnexpectedToken(offset, t)

        tokens.skip_space()

        if not tokens.pop(Ctrl.NewGroup):
            raise UnexpectedToken(offset, tokens.peek())

        # Concatenate every token up to the end-of-group token; redirects are rejected here.
        with io.StringIO() as _spec:
            while not tokens.pop(Ctrl.EndGroup):
                if isinstance((t := next(tokens)), RedirectIO):
                    raise UnexpectedToken(offset, t)
                _spec.write(t)
            spec_string = _spec.getvalue().strip()

        tokens.skip_space()

        if not tokens.pop_string('DO'):
            raise UnexpectedToken(offset, tokens.peek())

        # The body is parsed without a parent because the AstFor node does not exist yet.
        if not (body := self.sequence(None, tokens, in_group)):
            raise UnexpectedToken(offset, tokens.peek())

        # From here on, options is the parsed options object and no longer the raw string.
        # NOTE(review): forloop_options raises ValueError for bad input, which is not
        # converted to UnexpectedToken here - confirm that callers handle it.
        options = self.forloop_options(options)

        if variant == AstForVariant.FileParsing:
            # With usebackq, literals use single quotes and commands use backticks;
            # without it, literals use double quotes and commands use single quotes.
            quote_literal = "'" if options.usebackq else '"'
            quote_command = '`' if options.usebackq else "'"
            for q, m in (
                (quote_literal, AstForParserMode.Literal),
                (quote_command, AstForParserMode.Command),
            ):
                if spec_string.startswith(q):
                    if not spec_string.endswith(q):
                        raise UnexpectedToken(offset, spec_string)
                    mode = m
                    # The text between the quotes becomes the only entry of the spec.
                    spec = [spec_string[1:-1]]
                    break

        # Without a quoted spec, the entries are separated by whitespace, commas, semicolons.
        if not spec:
            spec = re.split('[\\s,;]+', spec_string)

        if variant == AstForVariant.NumericLoop:
            # Up to three values are converted and passed to batchrange; missing ones stay 0.
            # NOTE(review): more than three entries would raise IndexError on init[k] -
            # confirm whether that input can occur.
            init = [0, 0, 0]
            for k, v in enumerate(spec):
                init[k] = batchint(v, 0)
            spec = batchrange(*init)

        ast = AstFor(
            offset,
            parent,
            silenced,
            variant,
            variable[1],
            options,
            body,
            spec,
            spec_string,
            path,
            mode,
        )
        ast.body.parent = ast
        return ast

    def block(self, parent: AstNode | None, tokens: LookAhead, in_group: bool):
        """
        Generate the sequences of a block one after another.

        Newline tokens in front of each sequence are skipped. The generator ends at an
        end-of-group token (only when `in_group` is set), at the end-of-file token, or as soon
        as no further sequence can be parsed. The terminating token is consumed.
        """
        while True:
            while tokens.pop(Ctrl.NewLine):
                pass
            if in_group and tokens.pop(Ctrl.EndGroup):
                return
            if tokens.pop(Ctrl.EndOfFile):
                return
            parsed = self.sequence(parent, tokens, in_group)
            if not parsed:
                return
            yield parsed

    def group(
        self,
        parent: AstNode | None,
        tokens: LookAhead,
        redirects: dict[int, RedirectIO],
        silenced: bool,
    ) -> AstGroup | None:
        """
        Parse a group into an `AstGroup` node.

        Returns `None` unless the upcoming token is `Ctrl.NewGroup`. Otherwise the lexer is
        notified through `parse_group`, the content of the group is parsed as a block, and the
        redirects that follow the group are merged into the redirect table of the node.
        """
        start = tokens.offset()
        if not (tokens.peek() == Ctrl.NewGroup):
            return None
        self.lexer.parse_group()
        tokens.pop()
        node = AstGroup(start, parent, silenced, redirects)
        content = self.block(node, tokens, True)
        node.fragments.extend(content)
        tokens.skip_space()
        trailing = self.redirects(tokens)
        node.redirects.update(trailing)
        return node

    def label(self, tokens: LookAhead, silenced: bool) -> AstLabel | None:
        """
        Parse a label or a comment into an `AstLabel` node.

        Returns `None` when the upcoming token is neither `Ctrl.Comment` nor `Ctrl.Label`. For
        a comment, the label text is the parsed line itself. For a label, the name is derived
        from the line by the lexer, and the offset that the lexer recorded for this name must
        be one less than the offset of this statement; `RuntimeError` is raised otherwise. The
        node is created without a parent.
        """
        comment = False
        if (t := tokens.peek()) == Ctrl.Comment:
            comment = True
        elif t != Ctrl.Label:
            return None
        offset = tokens.offset()
        lexer = self.lexer
        lexer.parse_label()
        tokens.pop()
        line = tokens.word()
        if comment:
            label = line
        else:
            label = lexer.label(line, uppercase=False)
            # The error message must report the value that is actually compared against;
            # reporting the unadjusted offset could produce "expected N, got N".
            expected = offset - 1
            if (x := lexer.labels[label.upper()]) != expected:
                raise RuntimeError(F'Expected offset for label {label} to be {expected}, got {x} instead.')
        return AstLabel(offset, None, silenced, line, label, comment)

    def statement(self, parent: AstNode | None, tokens: LookAhead, in_group: bool):
        """
        Parse one statement after skipping its prefix.

        The statement is silenced when the prefix ends in at least one `@`. The parsers are
        tried in this order: label (only with at most one `@`), IF, FOR, and finally pipeline,
        whose result is returned as it is.
        """
        at_count, _ = self.skip_prefix(tokens)
        silenced = at_count > 0
        if at_count <= 1:
            found = self.label(tokens, silenced)
            if found:
                return found
        for parser in (self.ifthen, self.forloop):
            found = parser(parent, tokens, in_group, silenced)
            if found:
                return found
        return self.pipeline(parent, tokens, in_group, silenced)

    def sequence(self, parent: AstNode | None, tokens: LookAhead, in_group: bool) -> AstSequence | None:
        """
        Parse a statement and all statements that are chained to it.

        Returns `None` when no leading statement can be parsed. A single ampersand links the
        next statement with `AstCondition.NoCheck`, two ampersands with `AstCondition.Success`,
        and a pipe token with `AstCondition.Failure`. An `EmulatorException` is raised when a
        link is not followed by a statement.
        """
        tokens.skip_space()
        first = self.statement(parent, tokens, in_group)
        if first is None:
            return None
        result = AstSequence(first.offset, parent, first)
        first.parent = result
        tokens.skip_space()
        while True:
            if tokens.pop(Ctrl.Ampersand):
                doubled = tokens.pop(Ctrl.Ampersand)
                link = AstCondition.Success if doubled else AstCondition.NoCheck
            elif tokens.pop(Ctrl.Pipe):
                link = AstCondition.Failure
            else:
                return result
            tokens.skip_space()
            follower = self.statement(result, tokens, in_group)
            if not follower:
                raise EmulatorException('Failed to parse conditional statement.')
            entry = AstConditionalStatement(follower.offset, result, link, follower)
            result.tail.append(entry)
            tokens.skip_space()

    def parse(self, offset: int):
        """
        Generate the top-level sequences of the script, reading tokens from `offset` onwards.

        When an `UnexpectedToken` is raised while parsing, an `AstError` node built from it is
        generated and parsing resumes with the same token stream.
        """
        stream = LookAhead(self.lexer, offset)
        while True:
            try:
                yield from self.block(None, stream, False)
            except UnexpectedToken as problem:
                yield AstError(problem.offset, None, problem.token, problem.error)
                continue
            return

Instance variables

var state
Expand source code Browse git
@property
def state(self):
    """
    The `state` attribute of the lexer that this parser reads from.
    """
    lexer = self.lexer
    return lexer.state
var environment
Expand source code Browse git
@property
def environment(self):
    """
    The `environment` attribute of the parser state.
    """
    state = self.state
    return state.environment

Methods

def skip_prefix(self, tokens)
Expand source code Browse git
def skip_prefix(self, tokens: LookAhead) -> tuple[int, list[Token]]:
    """
    Drop the tokens that may precede a statement and report what was dropped.

    Whitespace words and the control tokens `@`, `;`, `,` and `=` are skipped. The result
    is a pair: the number of `@` tokens seen since the last other skipped control token,
    and the list of all skipped tokens in order. Whitespace does not reset the counter.
    """
    skippable = {
        Ctrl.Equals,
        Ctrl.Comma,
        Ctrl.Semicolon,
        Ctrl.At,
    }
    collected = []
    at_count = 0
    current = tokens.peek()
    while True:
        if isinstance(current, Word):
            if not current.isspace():
                return at_count, collected
        elif current in skippable:
            at_count = at_count + 1 if current == Ctrl.At else 0
        else:
            return at_count, collected
        collected.append(current)
        current = tokens.drop_and_peek()
def command(self, parent, tokens, redirects, in_group, silenced)
Expand source code Browse git
def command(
    self,
    parent: AstNode,
    tokens: LookAhead,
    redirects: dict[int, RedirectIO],
    in_group: bool,
    silenced: bool,
) -> AstCommand | None:
    """
    Collect the tokens of a single command into an `AstCommand` node.

    Tokens are read until a control token for ampersand, pipe, newline, or end of file is
    reached; when `in_group` is set, an end-of-group token stops the scan as well. Every
    `RedirectIO` token met on the way is stored in the redirect table of the node under its
    `source` instead of being added to the fragments. The node is returned only when it is
    truthy, otherwise the result is `None`.
    """
    ast = AstCommand(tokens.offset(), parent, silenced, redirects)
    tok = tokens.peek()
    cmd = ast.fragments

    # eat_token: remove the first character of the token that follows the ECHO token.
    # add_space: the next token is the one that follows ECHO and gets special treatment.
    # nsp_merge: consecutive non-space tokens are joined into a single fragment.
    eat_token = False
    add_space = False
    nsp_merge = True

    # Buffer for the non-space tokens that are being joined while nsp_merge is set.
    nonspace = io.StringIO()

    # NOTE(review): the ECHO/SET/GOTO special cases below are skipped entirely when the
    # node already has redirects - confirm that this is intended.
    if not ast.redirects:
        assert not isinstance(tok, RedirectIO)
        tok_upper = tok.upper()
        if tok_upper.startswith('ECHO'):
            if len(tok_upper) > 4 and tok_upper[4] == '.':
                # "ECHO." form: emit the keyword and one blank, keep the text after the dot.
                cmd.append(tok[:4])
                cmd.append(' ')
                tok = tok[5:]
                add_space = False
            else:
                add_space = True
                eat_token = True
        elif tok_upper == 'SET':
            # Tell the lexer that a SET statement follows; fragments are kept separate.
            self.lexer.parse_set()
            nsp_merge = False
        elif tok_upper == 'GOTO':
            nsp_merge = False
        if add_space or not nsp_merge:
            cmd.append(tok)
        else:
            nonspace.write(tok)
        tok = tokens.drop_and_peek()

    # The loop ends only for a Ctrl instance that equals one of the four terminators.
    while tok not in (
        Ctrl.Ampersand,
        Ctrl.Pipe,
        Ctrl.NewLine,
        Ctrl.EndOfFile,
    ) or not isinstance(tok, Ctrl):
        if in_group and tok == Ctrl.EndGroup:
            break
        if isinstance(tok, RedirectIO):
            ast.redirects[tok.source] = tok
        elif add_space:
            # Token right after ECHO: when it is not whitespace, a blank is inserted and,
            # if eat_token is set, the first character of the token is removed.
            add_space = False
            if not tok.isspace():
                cmd.append(' ')
                if eat_token:
                    tok = tok[1:]
            if tok:
                cmd.append(tok)
        elif not nsp_merge:
            cmd.append(tok)
        elif tok.isspace() or tok.startswith('/'):
            # Whitespace and slash-prefixed tokens end the run of merged tokens.
            if nsp := nonspace.getvalue():
                nonspace.seek(0)
                nonspace.truncate(0)
                cmd.append(nsp)
            cmd.append(tok)
        else:
            nonspace.write(tok)
        tok = tokens.drop_and_peek()
    # Flush the merged tokens that are still buffered.
    if nsp := nonspace.getvalue():
        cmd.append(nsp)
    if ast:
        return ast
def redirects(self, tokens)
Expand source code Browse git
def redirects(self, tokens: LookAhead):
    """
    Consume all `RedirectIO` tokens at the current position and return them in a dictionary
    keyed by their `source`. Space is skipped after each consumed redirect; a later redirect
    for the same source replaces an earlier one.
    """
    collected = {}
    while True:
        head = tokens.peek()
        if not isinstance(head, RedirectIO):
            return collected
        collected[head.source] = head
        tokens.pop()
        tokens.skip_space()
def pipeline(self, parent, tokens, in_group, silenced)
Expand source code Browse git
def pipeline(self, parent: AstNode | None, tokens: LookAhead, in_group: bool, silenced: bool) -> AstPipeline | None:
    """
    Parse groups and commands that are chained with single pipe tokens.

    Each part is parsed as a group first and as a command if that yields nothing. Parsing
    stops when no part could be parsed, when no pipe follows, or when the consumed pipe is
    directly followed by another pipe. Returns the `AstPipeline` node if it received at
    least one part and `None` otherwise.
    """
    node = AstPipeline(tokens.offset(), parent, silenced)
    while True:
        io_map = self.redirects(tokens)
        part = self.group(node, tokens, io_map, silenced)
        if not part:
            part = self.command(node, tokens, io_map, in_group, silenced)
        if not part:
            break
        node.parts.append(part)
        if not tokens.pop(Ctrl.Pipe):
            break
        if tokens.peek() == Ctrl.Pipe:
            # one pipe was consumed and a second one follows; it is left in the stream
            break
        at_count, _ = self.skip_prefix(tokens)
        silenced = at_count > 0
    return node if node.parts else None
def ifthen(self, parent, tokens, in_group, silenced)
Expand source code Browse git
def ifthen(self, parent: AstNode | None, tokens: LookAhead, in_group: bool, silenced: bool) -> AstIf | None:
    """
    Parse an IF statement into an `AstIf` node.

    Returns `None` when `tokens.pop_string('IF')` does not succeed. Two forms are handled:
    a keyword form, where the word after the optional `/I` and `NOT` is accepted by
    `AstIfVariant` and is followed by one operand, and a comparison form with a left-hand
    side, a comparator, and a right-hand side. `UnexpectedToken` is raised for an invalid
    comparator and for a missing body.
    """
    offset = tokens.offset()

    if not tokens.pop_string('IF'):
        return None

    casefold = False
    negated = False
    lhs = None
    rhs = None

    token = tokens.word()

    # Optional modifiers, accepted in this order only: /I first, then NOT.
    if token.upper() == '/I':
        casefold = True
        token = tokens.word()
    if token.upper() == 'NOT':
        negated = True
        token = tokens.word()
    try:
        variant = AstIfVariant(token.upper())
    except Exception:
        # Not a variant keyword: the word is the left-hand side of a comparison.
        tokens.skip_space()
        variant = None
        lhs = token
        cmp = next(tokens)

        if cmp == Ctrl.Equals:
            # The == comparator arrives as two consecutive Ctrl.Equals tokens.
            if tokens.pop(Ctrl.Equals):
                cmp = AstIfCmp('==')
            else:
                raise UnexpectedToken(offset, tokens.peek())
        else:
            try:
                cmp = AstIfCmp(cmp.upper())
            except Exception:
                raise UnexpectedToken(offset, cmp)
            # Comparators other than AstIfCmp.STR require an extensions version of 1 or more.
            if cmp != AstIfCmp.STR and self.state.extensions_version < 1:
                raise UnexpectedToken(offset, cmp)

        rhs = tokens.consume_nonspace_words()

        # Both operands become integers only if both of them convert; otherwise neither does.
        try:
            ilh = batchint(lhs)
            irh = batchint(rhs)
        except ValueError:
            pass
        else:
            lhs = ilh
            rhs = irh
    else:
        # Keyword form: a single unquoted operand, converted to an integer when possible.
        lhs = unquote(tokens.consume_nonspace_words())
        rhs = None
        cmp = None
        try:
            lhs = batchint(lhs)
        except ValueError:
            pass

    # The body is parsed without a parent because the AstIf node does not exist yet.
    then_do = self.sequence(None, tokens, in_group)

    if then_do is None:
        raise UnexpectedToken(offset, tokens.peek())

    tokens.skip_space()

    ast = AstIf(
        offset,
        parent,
        silenced,
        then_do,
        None,
        variant,
        casefold,
        negated,
        cmp,
        lhs, rhs # type:ignore
    )
    then_do.parent = ast

    # An ELSE branch is optional; when present it is parsed with the new node as parent.
    if tokens.peek().upper() == 'ELSE':
        tokens.pop()
        ast.else_do = self.sequence(ast, tokens, in_group)

    return ast
def forloop_options(self, options)
Expand source code Browse git
def forloop_options(self, options: str) -> AstForOptions:
    """
    Parse the option string of a FOR loop into an `AstForOptions` object.

    If `options` contains a double-quoted section, only the text inside the first pair of
    quotes is parsed. The text is split at whitespace and each part is read as `key=value`
    with a case-insensitive key:

    - `usebackq` takes no value and sets `usebackq`.
    - `eol=c` takes exactly one character and sets `comment`.
    - `skip=n` sets `skip` to the integer value of `n`.
    - `delims=xyz` sets `delims`.
    - `tokens=a,b-c*` sets `tokens` to a sorted tuple of zero-based indices; a trailing
      asterisk sets `asterisk`.

    An empty option string yields the default options. `ValueError` is raised for unknown
    keys, missing or malformed values.
    """
    result = AstForOptions()

    if not options:
        return result
    elif (quote := re.search('"(.*?)"', options)):
        options = quote[1]

    parts = options.strip().split()
    count = len(parts)

    for k, part in enumerate(parts, 1):
        key, eq, value = part.partition('=')
        key = key.lower()
        if key == 'usebackq':
            if eq or value:
                raise ValueError
            result.usebackq = True
        elif not eq:
            raise ValueError
        elif key == 'eol':
            if len(value) != 1:
                raise ValueError
            result.comment = value
        elif key == 'skip':
            try:
                result.skip = batchint(value)
            except Exception:
                raise ValueError
        elif key == 'delims':
            if k == count:
                # When delims is the last option, its value is re-read from the unsplit
                # string so that whitespace delimiters are not lost to the split above.
                # The lookup ignores case because the key was matched ignoring case too;
                # a case-sensitive search would discard the value of e.g. "DELIMS=, ".
                if (found := re.search('delims=', options, flags=re.IGNORECASE)):
                    value = options[found.end():]
            result.delims = value
        elif key == 'tokens':
            tokens: set[int] = set()
            if value.endswith('*'):
                result.asterisk = True
                value = value[:-1]
                if not value:
                    result.tokens = ()
                    continue
            for x in value.split(','):
                # "a-b" selects the one-based tokens a through b, a lone "a" just that one.
                x, _, y = x.partition('-')
                x = batchint(x) - 1
                if x < 0:
                    raise ValueError
                y = batchint(y) if y else x + 1
                for t in range(x, y):
                    tokens.add(t)
            result.tokens = tuple(sorted(tokens))
        else:
            raise ValueError

    return result
def forloop(self, parent, tokens, in_group, silenced)
Expand source code Browse git
def forloop(self, parent: AstNode | None, tokens: LookAhead, in_group: bool, silenced: bool) -> AstFor | None:
    """
    Parse a FOR loop into an `AstFor` node.

    Returns `None` when `tokens.pop_string('FOR')` does not succeed. The accepted shape is
    `FOR [/X [options-or-path]] %v IN (spec) DO body`, where `/X` is a one-letter switch
    accepted by `AstForVariant`. Any deviation raises `UnexpectedToken` with the offset of
    the statement.
    """
    offset = tokens.offset()

    if not tokens.pop_string('FOR'):
        return None

    def isvar(token: str):
        # A loop variable is exactly two characters long and starts with a percent sign.
        return len(token) == 2 and token.startswith('%')

    path = None
    mode = AstForParserMode.FileSet
    spec = []
    options = ''

    if isvar(variable := tokens.word()):
        variant = AstForVariant.Default
    elif len(variable) != 2 or not variable.startswith('/'):
        raise UnexpectedToken(offset, variable)
    else:
        try:
            variant = AstForVariant(variable[1].upper())
        except ValueError:
            raise UnexpectedToken(offset, variable)
        variable = tokens.word()
        if not isvar(variable):
            # Only FileParsing accepts an option string and only DescendRecursively
            # accepts a path between the switch and the loop variable.
            if variant == AstForVariant.FileParsing:
                options = variable
            elif variant == AstForVariant.DescendRecursively:
                path = unquote(variable)
            else:
                raise UnexpectedToken(offset, variable)
            variable = tokens.word()
            if not isvar(variable):
                raise UnexpectedToken(offset, variable)

    if (t := tokens.word()).upper() != 'IN':
        raise UnexpectedToken(offset, t)

    tokens.skip_space()

    if not tokens.pop(Ctrl.NewGroup):
        raise UnexpectedToken(offset, tokens.peek())

    # Concatenate every token up to the end-of-group token; redirects are rejected here.
    with io.StringIO() as _spec:
        while not tokens.pop(Ctrl.EndGroup):
            if isinstance((t := next(tokens)), RedirectIO):
                raise UnexpectedToken(offset, t)
            _spec.write(t)
        spec_string = _spec.getvalue().strip()

    tokens.skip_space()

    if not tokens.pop_string('DO'):
        raise UnexpectedToken(offset, tokens.peek())

    # The body is parsed without a parent because the AstFor node does not exist yet.
    if not (body := self.sequence(None, tokens, in_group)):
        raise UnexpectedToken(offset, tokens.peek())

    # From here on, options is the parsed options object and no longer the raw string.
    # NOTE(review): forloop_options raises ValueError for bad input, which is not
    # converted to UnexpectedToken here - confirm that callers handle it.
    options = self.forloop_options(options)

    if variant == AstForVariant.FileParsing:
        # With usebackq, literals use single quotes and commands use backticks;
        # without it, literals use double quotes and commands use single quotes.
        quote_literal = "'" if options.usebackq else '"'
        quote_command = '`' if options.usebackq else "'"
        for q, m in (
            (quote_literal, AstForParserMode.Literal),
            (quote_command, AstForParserMode.Command),
        ):
            if spec_string.startswith(q):
                if not spec_string.endswith(q):
                    raise UnexpectedToken(offset, spec_string)
                mode = m
                # The text between the quotes becomes the only entry of the spec.
                spec = [spec_string[1:-1]]
                break

    # Without a quoted spec, the entries are separated by whitespace, commas, semicolons.
    if not spec:
        spec = re.split('[\\s,;]+', spec_string)

    if variant == AstForVariant.NumericLoop:
        # Up to three values are converted and passed to batchrange; missing ones stay 0.
        # NOTE(review): more than three entries would raise IndexError on init[k] -
        # confirm whether that input can occur.
        init = [0, 0, 0]
        for k, v in enumerate(spec):
            init[k] = batchint(v, 0)
        spec = batchrange(*init)

    ast = AstFor(
        offset,
        parent,
        silenced,
        variant,
        variable[1],
        options,
        body,
        spec,
        spec_string,
        path,
        mode,
    )
    ast.body.parent = ast
    return ast
def block(self, parent, tokens, in_group)
Expand source code Browse git
def block(self, parent: AstNode | None, tokens: LookAhead, in_group: bool):
    """
    Generate the sequences of a block one after another.

    Newline tokens in front of each sequence are skipped. The generator ends at an
    end-of-group token (only when `in_group` is set), at the end-of-file token, or as soon
    as no further sequence can be parsed. The terminating token is consumed.
    """
    while True:
        while tokens.pop(Ctrl.NewLine):
            pass
        if in_group and tokens.pop(Ctrl.EndGroup):
            return
        if tokens.pop(Ctrl.EndOfFile):
            return
        parsed = self.sequence(parent, tokens, in_group)
        if not parsed:
            return
        yield parsed
def group(self, parent, tokens, redirects, silenced)
Expand source code Browse git
def group(
    self,
    parent: AstNode | None,
    tokens: LookAhead,
    redirects: dict[int, RedirectIO],
    silenced: bool,
) -> AstGroup | None:
    """
    Parse a group into an `AstGroup` node.

    Returns `None` unless the upcoming token is `Ctrl.NewGroup`. Otherwise the lexer is
    notified through `parse_group`, the content of the group is parsed as a block, and the
    redirects that follow the group are merged into the redirect table of the node.
    """
    start = tokens.offset()
    if not (tokens.peek() == Ctrl.NewGroup):
        return None
    self.lexer.parse_group()
    tokens.pop()
    node = AstGroup(start, parent, silenced, redirects)
    content = self.block(node, tokens, True)
    node.fragments.extend(content)
    tokens.skip_space()
    trailing = self.redirects(tokens)
    node.redirects.update(trailing)
    return node
def label(self, tokens, silenced)
Expand source code Browse git
def label(self, tokens: LookAhead, silenced: bool) -> AstLabel | None:
    """
    Parse a label or a comment into an `AstLabel` node.

    Returns `None` when the upcoming token is neither `Ctrl.Comment` nor `Ctrl.Label`. For
    a comment, the label text is the parsed line itself. For a label, the name is derived
    from the line by the lexer, and the offset that the lexer recorded for this name must
    be one less than the offset of this statement; `RuntimeError` is raised otherwise. The
    node is created without a parent.
    """
    comment = False
    if (t := tokens.peek()) == Ctrl.Comment:
        comment = True
    elif t != Ctrl.Label:
        return None
    offset = tokens.offset()
    lexer = self.lexer
    lexer.parse_label()
    tokens.pop()
    line = tokens.word()
    if comment:
        label = line
    else:
        label = lexer.label(line, uppercase=False)
        # The error message must report the value that is actually compared against;
        # reporting the unadjusted offset could produce "expected N, got N".
        expected = offset - 1
        if (x := lexer.labels[label.upper()]) != expected:
            raise RuntimeError(F'Expected offset for label {label} to be {expected}, got {x} instead.')
    return AstLabel(offset, None, silenced, line, label, comment)
def statement(self, parent, tokens, in_group)
Expand source code Browse git
def statement(self, parent: AstNode | None, tokens: LookAhead, in_group: bool):
    """
    Parse one statement after skipping its prefix.

    The statement is silenced when the prefix ends in at least one `@`. The parsers are
    tried in this order: label (only with at most one `@`), IF, FOR, and finally pipeline,
    whose result is returned as it is.
    """
    at_count, _ = self.skip_prefix(tokens)
    silenced = at_count > 0
    if at_count <= 1:
        found = self.label(tokens, silenced)
        if found:
            return found
    for parser in (self.ifthen, self.forloop):
        found = parser(parent, tokens, in_group, silenced)
        if found:
            return found
    return self.pipeline(parent, tokens, in_group, silenced)
def sequence(self, parent, tokens, in_group)
Expand source code Browse git
def sequence(self, parent: AstNode | None, tokens: LookAhead, in_group: bool) -> AstSequence | None:
    """
    Parse a statement and all statements that are chained to it.

    Returns `None` when no leading statement can be parsed. A single ampersand links the
    next statement with `AstCondition.NoCheck`, two ampersands with `AstCondition.Success`,
    and a pipe token with `AstCondition.Failure`. An `EmulatorException` is raised when a
    link is not followed by a statement.
    """
    tokens.skip_space()
    first = self.statement(parent, tokens, in_group)
    if first is None:
        return None
    result = AstSequence(first.offset, parent, first)
    first.parent = result
    tokens.skip_space()
    while True:
        if tokens.pop(Ctrl.Ampersand):
            doubled = tokens.pop(Ctrl.Ampersand)
            link = AstCondition.Success if doubled else AstCondition.NoCheck
        elif tokens.pop(Ctrl.Pipe):
            link = AstCondition.Failure
        else:
            return result
        tokens.skip_space()
        follower = self.statement(result, tokens, in_group)
        if not follower:
            raise EmulatorException('Failed to parse conditional statement.')
        entry = AstConditionalStatement(follower.offset, result, link, follower)
        result.tail.append(entry)
        tokens.skip_space()
def parse(self, offset)
Expand source code Browse git
def parse(self, offset: int):
    """
    Generate the top-level sequences of the script, reading tokens from `offset` onwards.

    When an `UnexpectedToken` is raised while parsing, an `AstError` node built from it is
    generated and parsing resumes with the same token stream.
    """
    stream = LookAhead(self.lexer, offset)
    while True:
        try:
            yield from self.block(None, stream, False)
        except UnexpectedToken as problem:
            yield AstError(problem.offset, None, problem.token, problem.error)
            continue
        return
class BatchState (delayexpand=False, extensions_enabled=True, extensions_version=2, environment=None, file_system=None, username='Administrator', hostname=None, now=None, cwd='C:\\', filename='', echo=True, codec='cp1252')
Expand source code Browse git
class BatchState:
    """
    Mutable state used during batch script emulation: environment variables, the current working
    directory, an in-memory file system, for-loop variables, and the stacks that hold the
    delayed-expansion and command-extension settings.
    """

    # Name of the batch file and the whitespace-split command line.
    name: str
    args: list[str]

    # The time reported by the DATE and TIME variables, and its value at construction.
    now: datetime
    start_time: datetime

    # For each of these stacks, the last entry is the active one.
    environment_stack: list[dict[str, str | RetainVariable]]
    delayexpand_stack: list[bool]
    cmdextended_stack: list[bool]

    _for_loops: list[dict[str, str]]
    # Maps resolved (normalized, absolute) paths to file contents.
    file_system: dict[str, str]

    def __init__(
        self,
        delayexpand: bool = False,
        extensions_enabled: bool = True,
        extensions_version: int = 2,
        environment: dict | None = None,
        file_system: dict | None = None,
        username: str = 'Administrator',
        hostname: str | None = None,
        now: int | float | str | datetime | None = None,
        cwd: str = 'C:\\',
        filename: str | None = '',
        echo: bool = True,
        codec: str = 'cp1252',
    ):
        """
        Initialize the state.

        - `delayexpand`, `extensions_enabled`: initial entries of the respective setting stacks.
        - `extensions_version`: the value reported for CMDEXTVERSION.
        - `environment`: initial environment; missing default variables are added to it.
        - `file_system`: initial mapping of paths to file contents.
        - `username`, `hostname`: formatted into the default environment; a random UUID is used
          as the host name if none is given.
        - `now`: a datetime, a timestamp, or a date string; defaults to the current time.
        - `cwd`: the initial working directory; must be an absolute path.
        - `filename`: name of the batch file; a random name ending in `.bat` is used if empty.
        - `echo`, `codec`: stored as given.

        Raises a ValueError if `cwd` is not an absolute path.
        """
        self.extensions_version = extensions_version
        file_system = file_system or {}
        environment = environment or {}
        if hostname is None:
            hostname = str(uuid4())
        # NOTE: a non-empty dictionary passed by the caller is filled in place here.
        for key, value in _DEFAULT_ENVIRONMENT.items():
            environment.setdefault(
                key.upper(),
                value.format(h=hostname, u=username)
            )
        if isinstance(now, str):
            now = isodate(now)
        if isinstance(now, (int, float)):
            now = date_from_timestamp(now)
        if now is None:
            now = datetime.now()
        self.cwd = cwd
        self.now = now
        self.start_time = now
        # Seeding the random number generator with the time makes RANDOM reproducible
        # whenever an explicit time is given.
        seed(self.now.timestamp())
        self.hostname = hostname
        self.username = username
        self.labels = {}
        self._for_loops = []
        self.environment_stack = [environment]
        self.delayexpand_stack = [delayexpand]
        self.cmdextended_stack = [extensions_enabled]
        self.file_system = file_system
        self.dirstack = []
        self.linebreaks = []
        self.name = filename or F'{uuid4()}.bat'
        self.args = []
        self._cmd = ''
        # This also writes the ERRORLEVEL variable, so the environment stack must exist already.
        self.ec = None
        self.echo = echo
        self.codec = codec

    @property
    def cwd(self):
        """
        The current working directory as a normalized, lower-case absolute path. It only ends
        in a backslash when it is a drive root.
        """
        return self._cwd

    @cwd.setter
    def cwd(self, new: str):
        new = new.replace('/', '\\')
        if not new.endswith('\\'):
            new = F'{new}\\'
        if not ntpath.isabs(new):
            # A relative path can only be resolved against an existing working directory.
            # During construction there is none yet; the path then remains relative and is
            # rejected below instead of failing with an AttributeError.
            base = getattr(self, '_cwd', None)
            if base is not None:
                new = ntpath.join(base, new)
        if not ntpath.isabs(new):
            raise ValueError(F'Invalid absolute path: {new}')
        self._cwd = ntpath.normcase(ntpath.normpath(new))

    @property
    def ec(self) -> int | ErrorZero:
        """
        The current error level. Assigning to this property also updates the ERRORLEVEL
        variable in the active environment; assigning None is the same as assigning zero.
        """
        return self.errorlevel

    @ec.setter
    def ec(self, value: int | ErrorZero | None):
        ec = value or 0
        self.environment['ERRORLEVEL'] = str(ec)
        self.errorlevel = ec

    @property
    def command_line(self):
        """
        The command line string. Assigning to this property also updates `args` with the
        whitespace-split components of the new value.
        """
        return self._cmd

    @command_line.setter
    def command_line(self, value: str):
        self._cmd = value
        self.args = value.split()

    def envar(self, name: str, default: str | None = None) -> str | RetainVariable:
        """
        Look up an environment variable by its case-insensitive name. Variables in the active
        environment take precedence over the dynamically computed ones (DATE, TIME, RANDOM,
        ERRORLEVEL, CD, CMDCMDLINE, CMDEXTVERSION, HIGHESTNUMANODENUMBER). If the variable is
        unknown, the default is returned when one is given; otherwise, MissingVariable is raised.
        """
        name = name.upper()
        if name in (e := self.environment):
            return e[name]
        elif name == 'DATE':
            return self.now.strftime('%Y-%m-%d')
        elif name == 'TIME':
            # Minutes, seconds, and the first two digits of the fractional seconds.
            time = self.now.strftime('%M:%S,%f')
            return F'{self.now.hour:2d}:{time:.8}'
        elif name == 'RANDOM':
            # NOTE(review): randrange excludes the upper bound, so 32767 is never produced,
            # while cmd.exe documents %RANDOM% as ranging from 0 through 32767. Confirm
            # whether this is intended before changing it; it would alter seeded output.
            return str(randrange(0, 32767))
        elif name == 'ERRORLEVEL':
            return str(self.ec)
        elif name == 'CD':
            return self.cwd
        elif name == 'CMDCMDLINE':
            line = self.envar('COMSPEC', 'cmd.exe')
            if args := self.args:
                args = ' '.join(args)
                line = F'{line} /c "{args}"'
            return line
        elif name == 'CMDEXTVERSION':
            return str(self.extensions_version)
        elif name == 'HIGHESTNUMANODENUMBER':
            return '0'
        elif default is not None:
            return default
        else:
            raise MissingVariable

    def resolve_path(self, path: str) -> str:
        """
        Convert a path to a normalized, lower-case path. Paths that are not absolute are
        interpreted relative to the current working directory.
        """
        if not ntpath.isabs(path):
            # The working directory carries no trailing backslash unless it is a drive root,
            # so the two parts have to be joined rather than concatenated.
            path = ntpath.join(self.cwd, path)
        return ntpath.normcase(ntpath.normpath(path))

    def create_file(self, path: str, data: str = ''):
        """
        Create or overwrite the file at the given path with the given contents.
        """
        self.file_system[self.resolve_path(path)] = data

    def append_file(self, path: str, data: str):
        """
        Append data to the file at the given path; the file is created if it does not exist.
        """
        path = self.resolve_path(path)
        if left := self.file_system.get(path, None):
            data = F'{left}{data}'
        self.file_system[path] = data

    def remove_file(self, path: str):
        """
        Remove the file at the given path; nothing happens if it does not exist.
        """
        self.file_system.pop(self.resolve_path(path), None)

    def ingest_file(self, path: str) -> str | None:
        """
        Return the contents of the file at the given path, or None if it does not exist.
        """
        return self.file_system.get(self.resolve_path(path))

    def exists_file(self, path: str) -> bool:
        """
        Return whether a file exists at the given path.
        """
        return self.resolve_path(path) in self.file_system

    def sizeof_file(self, path: str) -> int:
        """
        Return the length of the contents of the file at the given path, or -1 if there is no
        such file. An existing empty file has size 0.
        """
        data = self.ingest_file(path)
        # Compare against None explicitly: an empty file is falsy but does exist.
        if data is not None:
            return len(data)
        return -1

    def new_forloop(self) -> dict[str, str]:
        """
        Push and return a new dictionary of for-loop variables. It starts out as a copy of the
        variables of the enclosing for-loop, if there is one.
        """
        new = {}
        old = self.for_loop_variables
        if old is not None:
            new.update(old)
        self._for_loops.append(new)
        return new

    def end_forloop(self):
        """
        Discard the for-loop variables that were pushed last.
        """
        self._for_loops.pop()

    @property
    def environment(self):
        """
        The active environment: the last entry of the environment stack.
        """
        return self.environment_stack[-1]

    @property
    def delayexpand(self):
        """
        The active delayed-expansion setting: the last entry of the corresponding stack.
        Assigning to this property replaces that entry.
        """
        return self.delayexpand_stack[-1]

    @delayexpand.setter
    def delayexpand(self, v):
        self.delayexpand_stack[-1] = v

    @property
    def cmdextended(self):
        """
        The active command-extension setting: the last entry of the corresponding stack.
        Assigning to this property replaces that entry.
        """
        return self.cmdextended_stack[-1]

    @cmdextended.setter
    def cmdextended(self, v):
        self.cmdextended_stack[-1] = v

    @property
    def for_loop_variables(self):
        """
        The variables of the innermost for-loop, or None outside of any for-loop.
        """
        if vars := self._for_loops:
            return vars[-1]
        else:
            return None

Class variables

var name

Name of the batch file (`str`). A random name ending in `.bat` is used when no file name is given.

var args

The whitespace-split components of the command line (`list[str]`); updated whenever `command_line` is assigned.

var now

The point in time (`datetime`) from which the DATE and TIME variables are computed.

var start_time

The value of `now` at construction time (`datetime`).

var environment_stack

Stack of environment dictionaries (`list[dict[str, str | RetainVariable]]`); the last entry is the active environment.

var delayexpand_stack

Stack of delayed-expansion settings (`list[bool]`); the last entry is the active setting.

var cmdextended_stack

Stack of command-extension settings (`list[bool]`); the last entry is the active setting.

var file_system

The in-memory file system (`dict[str, str]`), mapping resolved paths to file contents.

Instance variables

var cwd
Expand source code Browse git
@property
def cwd(self):
    """
    The current working directory, as stored in `_cwd`.
    """
    return self._cwd
var ec
Expand source code Browse git
@property
def ec(self) -> int | ErrorZero:
    """
    The current error level, as stored in the `errorlevel` attribute.
    """
    return self.errorlevel
var command_line
Expand source code Browse git
@property
def command_line(self):
    """
    The command line string, as stored in `_cmd`.
    """
    return self._cmd
var environment
Expand source code Browse git
@property
def environment(self):
    """
    The active environment: the last entry of `environment_stack`.
    """
    return self.environment_stack[-1]
var delayexpand
Expand source code Browse git
@property
def delayexpand(self):
    """
    The active delayed-expansion setting: the last entry of `delayexpand_stack`.
    """
    return self.delayexpand_stack[-1]
var cmdextended
Expand source code Browse git
@property
def cmdextended(self):
    """
    The active command-extension setting: the last entry of `cmdextended_stack`.
    """
    return self.cmdextended_stack[-1]
var for_loop_variables
Expand source code Browse git
@property
def for_loop_variables(self):
    """
    The variables of the innermost for-loop, or None when no for-loop is active.
    """
    stack = self._for_loops
    return stack[-1] if stack else None

Methods

def envar(self, name, default=None)
Expand source code Browse git
def envar(self, name: str, default: str | None = None) -> str | RetainVariable:
    """
    Look up an environment variable by its case-insensitive name. Variables in the active
    environment take precedence over the dynamically computed ones (DATE, TIME, RANDOM,
    ERRORLEVEL, CD, CMDCMDLINE, CMDEXTVERSION, HIGHESTNUMANODENUMBER). If the variable is
    unknown, the default is returned when one is given; otherwise, MissingVariable is raised.
    """
    key = name.upper()
    env = self.environment
    if key in env:
        return env[key]
    if key == 'DATE':
        return self.now.strftime('%Y-%m-%d')
    if key == 'TIME':
        # Minutes, seconds, and the first two digits of the fractional seconds.
        tail = self.now.strftime('%M:%S,%f')
        return F'{self.now.hour:2d}:{tail:.8}'
    if key == 'RANDOM':
        return str(randrange(0, 32767))
    if key == 'ERRORLEVEL':
        return str(self.ec)
    if key == 'CD':
        return self.cwd
    if key == 'CMDCMDLINE':
        line = self.envar('COMSPEC', 'cmd.exe')
        if args := self.args:
            args = ' '.join(args)
            line = F'{line} /c "{args}"'
        return line
    if key == 'CMDEXTVERSION':
        return str(self.extensions_version)
    if key == 'HIGHESTNUMANODENUMBER':
        return '0'
    if default is None:
        raise MissingVariable
    return default
def resolve_path(self, path)
Expand source code Browse git
def resolve_path(self, path: str) -> str:
    """
    Convert a path to a normalized, lower-case path. Paths that are not absolute are
    interpreted relative to the current working directory.
    """
    if not ntpath.isabs(path):
        # Join instead of concatenating: a working directory that was normalized with
        # normpath has no trailing backslash unless it is a drive root, and concatenation
        # would then fuse the directory name with the file name.
        path = ntpath.join(self.cwd, path)
    return ntpath.normcase(ntpath.normpath(path))
def create_file(self, path, data='')
Expand source code Browse git
def create_file(self, path: str, data: str = ''):
    """
    Create or overwrite the file at the given path with the given contents.
    """
    key = self.resolve_path(path)
    self.file_system[key] = data
def append_file(self, path, data)
Expand source code Browse git
def append_file(self, path: str, data: str):
    """
    Append data to the file at the given path; the file is created if it does not exist.
    """
    key = self.resolve_path(path)
    existing = self.file_system.get(key, None)
    self.file_system[key] = F'{existing}{data}' if existing else data
def remove_file(self, path)
Expand source code Browse git
def remove_file(self, path: str):
    """
    Remove the file at the given path; nothing happens if it does not exist.
    """
    key = self.resolve_path(path)
    self.file_system.pop(key, None)
def ingest_file(self, path)
Expand source code Browse git
def ingest_file(self, path: str) -> str | None:
    """
    Return the contents of the file at the given path, or None if it does not exist.
    """
    key = self.resolve_path(path)
    return self.file_system.get(key)
def exists_file(self, path)
Expand source code Browse git
def exists_file(self, path: str) -> bool:
    """
    Return whether a file exists at the given path.
    """
    key = self.resolve_path(path)
    return key in self.file_system
def sizeof_file(self, path)
Expand source code Browse git
def sizeof_file(self, path: str) -> int:
    """
    Return the length of the contents of the file at the given path, or -1 if there is no
    such file. An existing empty file has size 0.
    """
    data = self.ingest_file(path)
    # Compare against None explicitly: the contents of an empty file are falsy, but the
    # file exists and must not be reported as missing.
    if data is not None:
        return len(data)
    return -1
def new_forloop(self)
Expand source code Browse git
def new_forloop(self) -> dict[str, str]:
    """
    Push and return a new dictionary of for-loop variables. It starts out as a copy of the
    variables of the enclosing for-loop, if there is one.
    """
    inherited = self.for_loop_variables
    scope = {} if inherited is None else dict(inherited)
    self._for_loops.append(scope)
    return scope
def end_forloop(self)
Expand source code Browse git
def end_forloop(self):
    """
    Discard the most recently pushed entry of `_for_loops`.
    """
    self._for_loops.pop()