Module refinery.lib.batch
Expand source code Browse git
from __future__ import annotations
import array
import codecs
import enum
import itertools
import ntpath
import re
from abc import ABC, abstractmethod
from typing import Generator, Generic, List, TypeVar, Union, overload
from refinery.lib.deobfuscation import cautious_eval_or_default
from refinery.lib.patterns import formats
from refinery.lib.types import buf
# A piece of parsed batch code: either one line of text or a nested list of such pieces
# (the parser in this module creates a nested list for every parenthesized block).
BatchCode = Union[str, List['BatchCode']]
# A sequence of lines and nested blocks, as stored in BatchFileEmulator.code.
Block = List[BatchCode]
# Type variable restricting both operands of an IfEq comparison to int or to str.
IntOrStr = TypeVar('IntOrStr', int, str)
def batchint(expr: str):
    """
    Convert an integer literal to an int. After an optional leading minus sign, a prefix
    of 0x or 0X selects base 16, any other leading zero selects base 8, and everything
    else is parsed as decimal. A ValueError from int() propagates for invalid input.
    """
    digits = expr[1:] if expr.startswith('-') else expr
    if digits[:2] in ('0x', '0X'):
        return int(expr, 16)
    if digits[:1] == '0':
        return int(expr, 8)
    return int(expr, 10)
class IfEq(Generic[IntOrStr]):
    """
    Holds the two operands of a comparison and exposes the three primitive comparisons
    (equal, less-or-equal, less-than) as methods.
    """

    def __init__(self, lhs: IntOrStr, rhs: IntOrStr):
        self.lhs, self.rhs = lhs, rhs

    def EQU(self):
        """Return whether the left operand equals the right operand."""
        left, right = self.lhs, self.rhs
        return left == right

    def LEQ(self):
        """Return whether the left operand is less than or equal to the right operand."""
        left, right = self.lhs, self.rhs
        return left <= right

    def LSS(self):
        """Return whether the left operand is strictly less than the right operand."""
        left, right = self.lhs, self.rhs
        return left < right
# Character codes that the parsers in this module compare against when scanning
# text as a sequence of 16-bit code units.
_PAREN_OPEN = 0x28  # '('
_PAREN_CLOSE = 0x29  # ')'
_CARET = 0x5E  # '^'
_QUOTE = 0x22  # '"'
_LINEBREAK = 0x0A  # '\n'
class EmulatorError(Exception):
    """Base class of the errors raised by the batch emulator in this module."""
class UnexpectedToken(EmulatorError):
    """Raised when a token is encountered in a position where it is not valid."""

    def __init__(self, token: str) -> None:
        message = F'Unexpected token: {token}'
        super().__init__(message)
class ExecutionResult(ABC):
    """
    Abstract result of executing a block of batch code. Subclasses report through
    longjump whether enclosing blocks must stop executing as well.
    """

    @abstractmethod
    def longjump(self) -> bool:
        """Return whether this result has to propagate out of enclosing blocks."""
        ...
class Exit(ExecutionResult):
    """
    Result representing the end of execution with an exit code. The script_only flag
    records whether only the script (rather than everything) is being left.
    """

    def __init__(self, code: int = 0, script_only: bool = True):
        self.script_only = script_only
        self.code = code

    def longjump(self) -> bool:
        """An exit propagates through enclosing blocks unless it is script-only."""
        if self.script_only:
            return False
        return True
class Goto(ExecutionResult):
    """Result representing a jump to the label with the given name."""

    def __init__(self, label: str):
        self.label = label

    def longjump(self) -> bool:
        """A jump always propagates out of enclosing blocks."""
        return True
class InvalidLabel(EmulatorError):
    """Raised when a jump targets a label that does not exist."""

    def __init__(self, label: str):
        message = F'The following label was not found: {label}'
        super().__init__(message)
class EmulatedCommand(str):
    """String subclass used to mark the command lines that the emulator yields."""
class Condition(str, enum.Enum):
    """
    The operators that chain commands on one line; each member's value is the operator
    text itself.
    """
    # Run the next command regardless of the previous outcome.
    Always = '&'
    # Run the next command only if the error level is zero.
    IfOk = '&&'
    # Run the next command only if the error level is nonzero.
    IfNotOk = '||'
class If(enum.IntFlag):
    """Bit flags tracking the state of an if/else statement that spans several blocks."""
    Inactive = 0b0000
    Active = 0b0001
    Block = 0b0010
    Then = 0b0100
    Else = 0b1000

    def skip_block(self):
        """
        Return whether the next block must be skipped: without the Else flag the block is
        skipped when Then is unset, and with the Else flag it is skipped when Then is set.
        """
        then_unset = If.Then not in self
        else_set = If.Else in self
        return then_unset != else_set
class BatchFileEmulator:
    # Stack of variable dictionaries; the last entry is the active environment.
    environments: list[dict[str, str]]
    # The parsed script: lines of text and nested lists for parenthesized blocks.
    code: Block
    # Maps each upper-cased label name to the index path of its line inside `code`.
    labels: dict[str, list[int]]
    # The arguments passed to the most recent call of `emulate`.
    args: list[str]
def __init__(
self,
data: str | buf,
delayed_expansion: bool = False,
extensions_enabled: bool = True,
extensions_version: int = 2,
file_system: dict | None = None,
cwd: str = 'C:\\'
):
self.delayed_expansion = delayed_expansion
self.extensions_version = extensions_version
self.extensions_enabled = extensions_enabled
self.file_sytem_seed = file_system or {}
self.cwd = cwd
self.parse(data)
@property
def cwd(self):
return self._cwd
@cwd.setter
def cwd(self, new: str):
new = new.replace('/', '\\')
if not new.endswith('\\'):
new = F'{new}\\'
if not ntpath.isabs(new):
raise ValueError(F'Invalid absolute path: {new}')
self._cwd = ntpath.normcase(ntpath.normpath(new))
@property
def ec(self) -> int:
return self.errorlevel
@ec.setter
def ec(self, value: int | None):
ec = value or 0
self.environment['ERRORLEVEL'] = str(ec)
self.errorlevel = ec
def reset(self):
self.labels = {}
self.environments = [{}]
self.delayexpands = [self.delayed_expansion]
self.ext_settings = [self.extensions_enabled]
self.file_system = dict(self.file_sytem_seed)
self.dirstack = []
self.args = []
self.ec = None
def _resolved(self, path: str) -> str:
if not ntpath.isabs(path):
path = F'{self.cwd}{path}'
return ntpath.normcase(ntpath.normpath(path))
def create_file(self, path: str, data: str = ''):
self.file_system[self._resolved(path)] = data
def append_file(self, path: str, data: str):
path = self._resolved(path)
if left := self.file_system.get(path, None):
data = F'{left}{data}'
self.file_system[path] = data
def remove_file(self, path: str):
self.file_system.pop(self._resolved(path), None)
def ingest_file(self, path: str) -> str | None:
return self.file_system.get(self._resolved(path))
def exists_file(self, path: str) -> bool:
return self._resolved(path) in self.file_system
@property
def environment(self):
return self.environments[-1]
@property
def delayexpand(self):
return self.delayexpands[-1]
@property
def ext_setting(self):
return self.ext_settings[-1]
@staticmethod
def split_head(
expression: str,
toupper: bool = False,
uncaret: bool = True,
unquote: bool = False,
terminator_letters: bytes = B'\x20\x09\x0B',
terminator_strings: tuple[bytes, ...] = (),
):
quote = False
caret = False
token = array.array("H")
utf16 = expression.encode('utf-16le')
utf16 = memoryview(utf16).cast('H')
t1 = terminator_letters
t2 = terminator_strings
for k, char in enumerate(utf16):
if not quote and not caret:
if char in t1 or any(utf16[k:k + len(t)] == t for t in t2):
tail = expression[k:]
break
if char == _QUOTE:
quote = not quote
if unquote:
continue
elif quote:
pass
elif caret:
caret = False
elif char == _CARET:
caret = True
if uncaret:
continue
token.append(char)
else:
tail = ''
head = token.tobytes().decode('utf-16le')
if toupper:
head = head.upper()
return head, tail.lstrip()
@overload
def expand(self, block: str, delay: bool = False) -> str:
...
@overload
def expand(self, block: list, delay: bool = False) -> list:
...
def expand(self, block: BatchCode, delay: bool = False):
def expansion(match: re.Match[str]):
name = match.group(1)
base = self.environment.get(name.upper(), '')
if not (modifier := match.group(2)):
return base
if '=' in modifier:
old, _, new = modifier.partition('=')
kwargs = {}
if old.startswith('~'):
old = old[1:]
kwargs.update(count=1)
return base.replace(old, new, **kwargs)
else:
if not modifier.startswith(':~'):
raise EmulatorError
offset, _, length = modifier[2:].partition(',')
offset = batchint(offset)
if offset < 0:
offset = max(0, len(base) + offset)
if length:
end = offset + batchint(length)
else:
end = len(base)
return base[offset:end]
if delay:
pattern = r'!([^!:\n]*)()!'
else:
pattern = rf'%([^%:\n]*)(:(?:~{formats.integer}(?:,{formats.integer})?|[^=%\n]+=[^%\r\n]*))?%'
if isinstance(block, str):
return re.sub(pattern, expansion, block)
else:
return [self.expand(child) for child in block]
def execute_set(self, command: str):
    """
    Execute the argument string of a SET command against the active environment.

    This is a generator: SET /P is not emulated; it is yielded as an EmulatedCommand and
    the value sent back is stored as the error level. SET /A evaluates a comma-separated
    list of assignments with cautious_eval_or_default and stores every result that could
    be computed. Any other form assigns the text after the first equals sign to the
    upper-cased variable name, or deletes the variable when that text is empty.
    """
    check, rest = self.split_head(command, toupper=True)
    if check == '/P':
        self.ec = yield EmulatedCommand(F'set {command}')
        return
    arithmetic = check == '/A'
    if arithmetic:
        command = rest
    if not command:
        return
    command, _ = self.split_head(command, terminator_letters=B'')
    if command.startswith('"'):
        # This is how it works based on testing, even if it seems insane.
        command, _, what = command[1:].rpartition('"')
        command = command or what
    if arithmetic:
        integers = {}
        updated = {}
        for name, value in self.environment.items():
            try:
                integers[name] = batchint(value)
            except ValueError:
                pass
        for assignment in command.split(','):
            name, _, expression = assignment.strip().partition('=')
            # Variables are looked up in upper case during expansion, so the name has to
            # be stored in upper case and without the blanks that surround the operator.
            name = name.strip().upper()
            value = cautious_eval_or_default(expression, environment=integers)
            if value is not None:
                integers[name] = value
                updated[name] = str(value)
        self.environment.update(updated)
    else:
        name, separator, content = command.partition('=')
        if not separator:
            # Without an equals sign the command only displays variables; it must not
            # delete the variable of that name.
            return
        name = name.upper()
        content, _ = self.split_head(content, terminator_letters=B'')
        if not content:
            self.environment.pop(name, None)
        else:
            self.environment[name] = content
def execute_if(self, command: str):
    """
    Evaluate the condition at the start of the argument string of an IF command and
    return a tuple of the boolean outcome and the remaining text after the condition.

    Supported forms are the optional /I and NOT prefixes followed by one of ERRORLEVEL n,
    CMDEXTVERSION n, EXIST path, DEFINED name, lhs==rhs, or lhs OP rhs where OP is one
    of EQU, NEQ, LSS, LEQ, GTR and GEQ. For the OP forms, both operands are compared as
    integers when batchint accepts both of them and as strings otherwise. All keywords
    are matched regardless of case. An UnexpectedToken error is raised for an unknown
    operator, or for any operator when the extensions version is below 1.
    """
    casefold = False
    negate = False
    check, rest = self.split_head(command, toupper=True)
    if check == '/I':
        casefold = True
        command = rest
        check, rest = self.split_head(rest, toupper=True)
    if check == 'NOT':
        negate = True
        command = rest
        check, rest = self.split_head(rest, toupper=True)
    if check == 'ERRORLEVEL':
        limit, rest = self.split_head(rest)
        limit = int(limit.strip(), 10)
        condition = limit <= self.ec
    elif check == 'CMDEXTVERSION':
        limit, rest = self.split_head(rest)
        limit = int(limit.strip(), 10)
        condition = limit <= self.extensions_version
    elif check == 'EXIST':
        path, rest = self.split_head(rest, unquote=True)
        condition = self.exists_file(path)
    elif check == 'DEFINED':
        name, rest = self.split_head(rest)
        condition = name.upper() in self.environment
    else:
        # Not a keyword: parse the left operand again from the text that follows the
        # prefixes, this time preserving its case.
        lhs, rest = self.split_head(
            command,
            toupper=False,
            unquote=True,
            terminator_strings=(B'==',)
        )
        if rest.startswith('=='):
            rest = rest[2:].lstrip()
            rhs, rest = self.split_head(rest, toupper=False, unquote=True)
            if casefold:
                lhs = lhs.casefold()
                rhs = rhs.casefold()
            condition = lhs == rhs
        else:
            cmp, rest = self.split_head(rest, toupper=True)
            if self.extensions_version < 1:
                raise UnexpectedToken(cmp)
            # Quotes were removed from the left operand, so they have to be removed from
            # the right operand as well for the two to be comparable.
            rhs, rest = self.split_head(rest, unquote=True)
            if cmp == 'GTR':
                rhs, lhs, cmp = lhs, rhs, 'LSS'
            if cmp == 'GEQ':
                rhs, lhs, cmp = lhs, rhs, 'LEQ'
            if cmp == 'NEQ':
                negate, cmp = not negate, 'EQU'
            try:
                ilh = batchint(lhs)
                irh = batchint(rhs)
            except ValueError:
                if casefold:
                    lhs = lhs.casefold()
                    rhs = rhs.casefold()
                pair = IfEq(lhs, rhs)
            else:
                pair = IfEq(ilh, irh)
            if cmp == 'EQU':
                condition = pair.EQU()
            elif cmp == 'LSS':
                condition = pair.LSS()
            elif cmp == 'LEQ':
                condition = pair.LEQ()
            else:
                raise UnexpectedToken(cmp)
    if negate:
        condition = not condition
    return condition, rest
def _commands(self, line: str):
    """
    Split a line at the unquoted, unescaped operators &, && and || and generate pairs of
    a command (stripped of leading whitespace) and the Condition that follows it. The
    last command of a line is paired with Condition.Always. A single pipe character is
    not treated as a separator and remains part of the command. A ValueError is raised
    if the line contains a line break.
    """
    quote = False
    caret = False
    check = 0
    again = None
    for k, char in enumerate(line):
        if again:
            pending, again = again, None
            if pending == char:
                # A doubled operator; the current character is its second half.
                yield line[check:k - 1].lstrip(), Condition(2 * pending)
                check = k + 1
                continue
            if pending == Condition.Always:
                yield line[check:k - 1].lstrip(), Condition.Always
                check = k
            # The current character is not part of an operator and still has to be
            # inspected below: it can be a quote, a caret, or another operator.
        if char == '"':
            quote = not quote
            continue
        if char == '\n':
            raise ValueError
        if quote:
            continue
        if caret:
            caret = False
            continue
        if char == '^':
            caret = True
            continue
        if char in '|&':
            again = char
    if (rest := line[check:]) and rest.strip():
        yield rest.lstrip(), Condition.Always
def _check_condition(self, condition: Condition):
    """
    Return whether a command that is chained with the given condition has to run, based
    on the current error level. A TypeError is raised for an unknown condition.
    """
    if condition == Condition.IfOk:
        return self.ec == 0
    if condition == Condition.IfNotOk:
        return self.ec != 0
    if condition == Condition.Always:
        return True
    raise TypeError(condition)
def goto(self, index: list[int]) -> tuple[Block, BatchCode, int]:
if not index:
index = [0]
line = 0
code = cursor = self.code
for line in index:
code, cursor = cursor, cursor[line]
assert isinstance(code, list)
return code, cursor, line
def emulate(self, *args: str) -> Generator[EmulatedCommand, int | None, ExecutionResult]:
    """
    Run the parsed script from its first line with the given arguments. This generator
    yields what emulate_block yields and follows every Goto result to its label until
    an Exit result is produced, which is then returned after its code has been stored as
    the error level. A jump to the label EOF ends the emulation. An InvalidLabel error is
    raised for a jump to an unknown label.
    """
    self.args[:] = args
    index = [0]
    while True:
        block, _, offset = self.goto(index)
        state = yield from self.emulate_block(block, offset=offset, expand=True)
        if isinstance(state, Exit):
            self.ec = state.code
            return state
        if not isinstance(state, Goto):
            raise TypeError(state)
        label = state.label.upper()
        if label == 'EOF':
            return Exit()
        try:
            index = self.labels[label]
        except KeyError as KE:
            raise InvalidLabel(label) from KE
def emulate_block(
    self,
    block: Block,
    offset: int = 0,
    expand: bool = False,
) -> Generator[EmulatedCommand, int | None, ExecutionResult]:
    """
    Execute the items of a block, starting at the given offset, and return the result
    that ended the execution: a Goto, an Exit produced by an EXIT command, or a default
    Exit when the end of the block is reached.

    SET, SETLOCAL, ENDLOCAL, IF, ELSE, EXIT, CD, CHDIR, PUSHD, POPD and GOTO are handled
    by the emulator. Every other command is yielded as an EmulatedCommand and the value
    that is sent back into the generator becomes the new error level. With expand set,
    percent-expansion is applied to every item before it is executed; nested blocks are
    expanded as part of their parent and executed without a second expansion.
    """
    it = block if offset <= 0 else itertools.islice(block, offset, None)
    ifelse = If.Inactive
    for code in it:
        if expand:
            code = self.expand(code)
        if If.Block in ifelse:
            # The previous line was an IF or ELSE that announced a block.
            if not isinstance(code, list):
                raise EmulatorError(F'Expected a block while parsing If/Else; {ifelse!r}')
            if not ifelse.skip_block():
                result = yield from self.emulate_block(code)
                if result.longjump():
                    return result
            if If.Else in ifelse:
                ifelse = If.Inactive
            else:
                # The THEN block is done; an ELSE may follow.
                ifelse |= If.Else
                ifelse &= ~If.Block
            continue
        if isinstance(code, list):
            if ifelse != If.Inactive:
                raise EmulatorError('Unexpected block in the middle of if/else statement.')
            result = yield from self.emulate_block(code)
            if result.longjump():
                return result
            continue
        condition = Condition.Always
        for command, next_condition in self._commands(code):
            if not self._check_condition(condition):
                break
            condition = next_condition
            if self.delayexpand:
                command = self.expand(command, True)
            head, tail = self.split_head(
                command, toupper=True, uncaret=False)
            head = head.lstrip('@')
            if head == 'SET':
                yield from self.execute_set(tail)
            elif head == 'SETLOCAL':
                setting = tail.strip().upper()
                delay = {
                    'DISABLEDELAYEDEXPANSION': False,
                    'ENABLEDELAYEDEXPANSION' : True,
                }.get(setting, self.delayexpand)
                cmdxt = {
                    'DISABLEEXTENSIONS': False,
                    'ENABLEEXTENSIONS' : True,
                }.get(setting, self.ext_setting)
                self.delayexpands.append(delay)
                self.ext_settings.append(cmdxt)
                self.environments.append(dict(self.environment))
            elif head == 'ENDLOCAL' and len(self.environments) > 1:
                # Undo everything that SETLOCAL pushed, including the extension setting.
                self.environments.pop()
                self.delayexpands.pop()
                self.ext_settings.pop()
            elif head == 'IF':
                then, cmd = self.execute_if(tail)
                if not cmd:
                    # Nothing follows the condition on this line: the body is the block
                    # that comes next. Skip the state reset at the end of the loop body.
                    ifelse = If.Active | If.Block
                    if then:
                        ifelse |= If.Then
                    continue
                elif then:
                    self.ec = yield EmulatedCommand(cmd)
            elif head == 'ELSE':
                if If.Else not in ifelse:
                    raise UnexpectedToken(head)
                cmd = tail.lstrip()
                if not cmd:
                    # The ELSE body is the next block. It has to be announced even when
                    # the THEN branch was taken so that it is skipped rather than run as
                    # an unconditional block.
                    ifelse |= If.Block
                    continue
                if If.Then not in ifelse:
                    self.ec = yield EmulatedCommand(cmd)
            elif head == 'EXIT':
                token, tail = self.split_head(tail, toupper=True)
                script_only = False
                if token == '/B':
                    script_only = True
                    token, tail = self.split_head(tail)
                try:
                    exit_code = int(token, 10)
                except ValueError:
                    exit_code = 0
                return Exit(exit_code, script_only)
            elif head == 'CD' or head == 'CHDIR':
                directory, _ = self.split_head(tail, unquote=True, terminator_letters=B'')
                self.cwd = directory.rstrip()
            elif head == 'PUSHD':
                directory, _ = self.split_head(tail, unquote=True, terminator_letters=B'')
                self.dirstack.append(self.cwd)
                self.cwd = directory.rstrip()
            elif head == 'POPD':
                try:
                    self.cwd = self.dirstack.pop()
                except IndexError:
                    pass
            elif head == 'GOTO':
                label, tail = self.split_head(tail)
                if label.startswith(':'):
                    label = label[1:]
                self.ec = yield EmulatedCommand(command)
                return Goto(label)
            else:
                self.ec = yield EmulatedCommand(command)
            ifelse = If.Inactive
    return Exit()
def _decode(self, data: buf):
if data[:3] == B'\xEF\xBB\xBF':
return codecs.decode(data[3:], 'utf8')
elif data[:2] == B'\xFF\xFE':
return codecs.decode(data[2:], 'utf-16le')
elif data[:2] == B'\xFE\xFF':
return codecs.decode(data[2:], 'utf-16be')
else:
return codecs.decode(data, 'cp1252')
def parse(self, text: str | buf):
self.reset()
if not isinstance(text, str):
text = self._decode(text)
text = '\n'.join(
line.rstrip() for line in re.split(r'[\r\n]+', text.strip()))
utf16 = text.encode('utf-16le')
utf16 = memoryview(utf16).cast('H')
quote = False
caret = False
check = 0
lines = self.code = []
path_to_root = []
def linebreak(k: int):
nonlocal check
line = text[check:k]
check = k + 1
strip = line.strip()
if not strip:
return
lines.append(line)
if strip[0] != ':':
return
label = strip[1:].strip()
if not label:
return
if label[0] == ':':
return
label = label.upper()
index = [len(n) - 1 for n in path_to_root]
index.append(len(lines) - 1)
self.labels[label] = index
for k, char in enumerate(utf16):
if char == _QUOTE:
if quote := not quote:
caret = False
continue
if char == _LINEBREAK:
if caret:
caret = False
else:
linebreak(k)
quote = False
continue
if quote:
continue
if caret:
caret = False
continue
if char == _CARET:
caret = True
continue
if char == _PAREN_OPEN:
linebreak(k)
path_to_root.append(lines)
block = []
lines.append(block)
lines = block
if char == _PAREN_CLOSE:
if not path_to_root:
continue
linebreak(k)
lines = path_to_root.pop()
linebreak(len(text))
Functions
def batchint(expr)
-
Expand source code Browse git
def batchint(expr: str): m = int(expr.startswith('-')) if expr[m:m + 2] in ('0x', '0X'): base = 16 elif expr[m:m + 1] == '0': base = 8 else: base = 10 return int(expr, base)
Classes
class IfEq (lhs, rhs)
-
Abstract base class for generic types.
On Python 3.12 and newer, generic classes implicitly inherit from Generic when they declare a parameter list after the class's name::
class Mapping[KT, VT]: def __getitem__(self, key: KT) -> VT: ... # Etc.
On older versions of Python, however, generic classes have to explicitly inherit from Generic.
After a class has been declared to be generic, it can then be used as follows::
def lookup_name[KT, VT](mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
Expand source code Browse git
class IfEq(Generic[IntOrStr]): def __init__( self, lhs: IntOrStr, rhs: IntOrStr, ): self.lhs = lhs self.rhs = rhs def EQU(self): return self.lhs == self.rhs def LEQ(self): return self.lhs <= self.rhs def LSS(self): return self.lhs < self.rhs
Ancestors
- typing.Generic
Methods
def EQU(self)
-
Expand source code Browse git
def EQU(self): return self.lhs == self.rhs
def LEQ(self)
-
Expand source code Browse git
def LEQ(self): return self.lhs <= self.rhs
def LSS(self)
-
Expand source code Browse git
def LSS(self): return self.lhs < self.rhs
class EmulatorError (*args, **kwargs)
-
Common base class for all non-exit exceptions.
Expand source code Browse git
class EmulatorError(Exception): pass
Ancestors
- builtins.Exception
- builtins.BaseException
Subclasses
class UnexpectedToken (token)
-
Common base class for all non-exit exceptions.
Expand source code Browse git
class UnexpectedToken(EmulatorError): def __init__(self, token: str) -> None: super().__init__(F'Unexpected token: {token}')
Ancestors
- EmulatorError
- builtins.Exception
- builtins.BaseException
class ExecutionResult
-
Helper class that provides a standard way to create an ABC using inheritance.
Expand source code Browse git
class ExecutionResult(ABC): @abstractmethod def longjump(self) -> bool: pass
Ancestors
- abc.ABC
Subclasses
Methods
def longjump(self)
-
Expand source code Browse git
@abstractmethod def longjump(self) -> bool: pass
class Exit (code=0, script_only=True)
-
Helper class that provides a standard way to create an ABC using inheritance.
Expand source code Browse git
class Exit(ExecutionResult): def __init__(self, code: int = 0, script_only: bool = True): self.code = code self.script_only = script_only def longjump(self) -> bool: return not self.script_only
Ancestors
- ExecutionResult
- abc.ABC
Methods
def longjump(self)
-
Expand source code Browse git
def longjump(self) -> bool: return not self.script_only
class Goto (label)
-
Helper class that provides a standard way to create an ABC using inheritance.
Expand source code Browse git
class Goto(ExecutionResult): def __init__(self, label: str): self.label = label def longjump(self) -> bool: return True
Ancestors
- ExecutionResult
- abc.ABC
Methods
def longjump(self)
-
Expand source code Browse git
def longjump(self) -> bool: return True
class InvalidLabel (label)
-
Common base class for all non-exit exceptions.
Expand source code Browse git
class InvalidLabel(EmulatorError): def __init__(self, label: str): super().__init__(F'The following label was not found: {label}')
Ancestors
- EmulatorError
- builtins.Exception
- builtins.BaseException
class EmulatedCommand (...)
-
str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str
Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.str() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.
Expand source code Browse git
class EmulatedCommand(str): pass
Ancestors
- builtins.str
class Condition (*args, **kwds)
-
str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str
Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.str() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.
Expand source code Browse git
class Condition(str, enum.Enum): Always = '&' IfOk = '&&' IfNotOk = '||'
Ancestors
- builtins.str
- enum.Enum
Class variables
var Always
var IfOk
var IfNotOk
class If (*args, **kwds)
-
Support for integer-based Flags
Expand source code Browse git
class If(enum.IntFlag): Inactive = 0b0000 Active = 0b0001 Block = 0b0010 Then = 0b0100 Else = 0b1000 def skip_block(self): skip = If.Then not in self if If.Else in self: skip = not skip return skip
Ancestors
- enum.IntFlag
- builtins.int
- enum.ReprEnum
- enum.Flag
- enum.Enum
Class variables
var Inactive
var Active
var Block
var Then
var Else
Methods
def skip_block(self)
-
Expand source code Browse git
def skip_block(self): skip = If.Then not in self if If.Else in self: skip = not skip return skip
class BatchFileEmulator (data, delayed_expansion=False, extensions_enabled=True, extensions_version=2, file_system=None, cwd='C:\\')
-
Expand source code Browse git
class BatchFileEmulator: environments: list[dict[str, str]] code: Block labels: dict[str, list[int]] args: list[str] def __init__( self, data: str | buf, delayed_expansion: bool = False, extensions_enabled: bool = True, extensions_version: int = 2, file_system: dict | None = None, cwd: str = 'C:\\' ): self.delayed_expansion = delayed_expansion self.extensions_version = extensions_version self.extensions_enabled = extensions_enabled self.file_sytem_seed = file_system or {} self.cwd = cwd self.parse(data) @property def cwd(self): return self._cwd @cwd.setter def cwd(self, new: str): new = new.replace('/', '\\') if not new.endswith('\\'): new = F'{new}\\' if not ntpath.isabs(new): raise ValueError(F'Invalid absolute path: {new}') self._cwd = ntpath.normcase(ntpath.normpath(new)) @property def ec(self) -> int: return self.errorlevel @ec.setter def ec(self, value: int | None): ec = value or 0 self.environment['ERRORLEVEL'] = str(ec) self.errorlevel = ec def reset(self): self.labels = {} self.environments = [{}] self.delayexpands = [self.delayed_expansion] self.ext_settings = [self.extensions_enabled] self.file_system = dict(self.file_sytem_seed) self.dirstack = [] self.args = [] self.ec = None def _resolved(self, path: str) -> str: if not ntpath.isabs(path): path = F'{self.cwd}{path}' return ntpath.normcase(ntpath.normpath(path)) def create_file(self, path: str, data: str = ''): self.file_system[self._resolved(path)] = data def append_file(self, path: str, data: str): path = self._resolved(path) if left := self.file_system.get(path, None): data = F'{left}{data}' self.file_system[path] = data def remove_file(self, path: str): self.file_system.pop(self._resolved(path), None) def ingest_file(self, path: str) -> str | None: return self.file_system.get(self._resolved(path)) def exists_file(self, path: str) -> bool: return self._resolved(path) in self.file_system @property def environment(self): return self.environments[-1] @property def delayexpand(self): return 
self.delayexpands[-1] @property def ext_setting(self): return self.ext_settings[-1] @staticmethod def split_head( expression: str, toupper: bool = False, uncaret: bool = True, unquote: bool = False, terminator_letters: bytes = B'\x20\x09\x0B', terminator_strings: tuple[bytes, ...] = (), ): quote = False caret = False token = array.array("H") utf16 = expression.encode('utf-16le') utf16 = memoryview(utf16).cast('H') t1 = terminator_letters t2 = terminator_strings for k, char in enumerate(utf16): if not quote and not caret: if char in t1 or any(utf16[k:k + len(t)] == t for t in t2): tail = expression[k:] break if char == _QUOTE: quote = not quote if unquote: continue elif quote: pass elif caret: caret = False elif char == _CARET: caret = True if uncaret: continue token.append(char) else: tail = '' head = token.tobytes().decode('utf-16le') if toupper: head = head.upper() return head, tail.lstrip() @overload def expand(self, block: str, delay: bool = False) -> str: ... @overload def expand(self, block: list, delay: bool = False) -> list: ... def expand(self, block: BatchCode, delay: bool = False): def expansion(match: re.Match[str]): name = match.group(1) base = self.environment.get(name.upper(), '') if not (modifier := match.group(2)): return base if '=' in modifier: old, _, new = modifier.partition('=') kwargs = {} if old.startswith('~'): old = old[1:] kwargs.update(count=1) return base.replace(old, new, **kwargs) else: if not modifier.startswith(':~'): raise EmulatorError offset, _, length = modifier[2:].partition(',') offset = batchint(offset) if offset < 0: offset = max(0, len(base) + offset) if length: end = offset + batchint(length) else: end = len(base) return base[offset:end] if delay: pattern = r'!([^!:\n]*)()!' 
else: pattern = rf'%([^%:\n]*)(:(?:~{formats.integer}(?:,{formats.integer})?|[^=%\n]+=[^%\r\n]*))?%' if isinstance(block, str): return re.sub(pattern, expansion, block) else: return [self.expand(child) for child in block] def execute_set(self, command: str): check, rest = self.split_head(command, toupper=True) if check == '/P': self.ec = yield EmulatedCommand(F'set {command}') return if check == '/A': arithmetic = True command = rest else: arithmetic = False if not command: return command, _ = self.split_head(command, terminator_letters=B'') if command.startswith('"'): # This is how it works based on testing, even if it seems insane. command, _, what = command[1:].rpartition('"') command = command or what if arithmetic: integers = {} updated = {} for name, value in self.environment.items(): try: integers[name] = batchint(value) except ValueError: pass for assignment in command.split(','): assignment = assignment.strip() name, _, expression = assignment.partition('=') expression = cautious_eval_or_default(expression, environment=integers) if expression is not None: integers[name] = expression updated[name] = str(expression) self.environment.update(updated) else: name, _, content = command.partition('=') name = name.upper() content, _ = self.split_head(content, terminator_letters=B'') if not content: self.environment.pop(name, None) else: self.environment[name] = content def execute_if(self, command: str): casefold = False negate = False check, rest = self.split_head(command, toupper=True) if check == '/I': casefold = True command, check, rest = rest, *self.split_head(rest) if check == 'NOT': negate = True command, check, rest = rest, *self.split_head(rest) if check == 'ERRORLEVEL': limit, rest = self.split_head(rest) limit = int(limit.strip(), 10) condition = limit <= self.ec elif check == 'CMDEXTVERSION': limit, rest = self.split_head(rest) limit = int(limit.strip(), 10) condition = limit <= self.extensions_version elif check == 'EXIST': path, rest = 
self.split_head(rest, unquote=True) condition = self.exists_file(path) elif check == 'DEFINED': name, rest = self.split_head(rest) condition = name.upper() in self.environment else: lhs, rest = self.split_head( command, toupper=False, unquote=True, terminator_strings=(B'==',) ) if rest.startswith('=='): rest = rest[2:].lstrip() rhs, rest = self.split_head(rest, toupper=False, unquote=True) if casefold: lhs = lhs.casefold() rhs = rhs.casefold() condition = lhs == rhs else: cmp, rest = self.split_head(rest) if self.extensions_version < 1: raise UnexpectedToken(cmp) rhs, rest = self.split_head(rest) if cmp == 'GTR': rhs, lhs, cmp = lhs, rhs, 'LSS' if cmp == 'GEQ': rhs, lhs, cmp = lhs, rhs, 'LEQ' if cmp == 'NEQ': negate, cmp = not negate, 'EQU' try: ilh = batchint(lhs) irh = batchint(rhs) except ValueError: pair = IfEq(lhs, rhs) else: pair = IfEq(ilh, irh) if cmp == 'EQU': condition = pair.EQU() elif cmp == 'LSS': condition = pair.LSS() elif cmp == 'LEQ': condition = pair.LEQ() else: raise UnexpectedToken(cmp) if negate: condition = not condition return condition, rest def _commands(self, line: str): quote = False caret = False check = 0 again = None for k, char in enumerate(line): if again: if quote or caret: raise EmulatorError how = None end = None if again == char: how = 2 * again end = k + 1 elif again == Condition.Always: how = again end = k again = None if end is not None and how is not None: cmd = line[check:k - 1] yield cmd.lstrip(), Condition(how) check = end continue if char == '"': quote = not quote continue if char == '\n': raise ValueError if quote: continue if caret: caret = False continue if char == '^': caret = True continue if char in '|&': again = char if (rest := line[check:]) and rest.strip(): yield rest.lstrip(), Condition.Always def _check_condition(self, condition: Condition): if condition == Condition.Always: return True if condition == Condition.IfNotOk: return self.ec != 0 if condition == Condition.IfOk: return self.ec == 0 raise 
TypeError(condition) def goto(self, index: list[int]) -> tuple[Block, BatchCode, int]: if not index: index = [0] line = 0 code = cursor = self.code for line in index: code, cursor = cursor, cursor[line] assert isinstance(code, list) return code, cursor, line def emulate(self, *args: str) -> Generator[EmulatedCommand, int | None, ExecutionResult]: index = [0] self.args[:] = args while True: block, _, offset = self.goto(index) state = yield from self.emulate_block( block, offset=offset, expand=True, ) if isinstance(state, Goto): label = state.label.upper() if label == 'EOF': return Exit() try: index = self.labels[label] except KeyError as KE: raise InvalidLabel(label) from KE else: continue if isinstance(state, Exit): self.ec = state.code return state raise TypeError(state) def emulate_block( self, block: Block, offset: int = 0, expand: bool = False, ) -> Generator[EmulatedCommand, int | None, ExecutionResult]: it = block if offset <= 0 else itertools.islice(block, offset, None) ifelse = If.Inactive for code in it: if expand: code = self.expand(code) if If.Block in ifelse: if not isinstance(code, list): raise EmulatorError(F'Expected a block while parsing If/Else; {ifelse!r}') if not ifelse.skip_block(): exit = (yield from self.emulate_block(code)) if exit.longjump(): return exit if If.Else in ifelse: ifelse = If.Inactive else: ifelse |= If.Else ifelse &= ~If.Block continue if isinstance(code, list): if ifelse != If.Inactive: raise EmulatorError('Unexpected block in the middle of if/else statement.') exit = (yield from self.emulate_block(code)) if exit.longjump(): return exit continue condition = Condition.Always for command, next_condition in self._commands(code): if not self._check_condition(condition): break condition = next_condition if self.delayexpand: command = self.expand(command, True) head, tail = self.split_head( command, toupper=True, uncaret=False) head = head.lstrip('@') if head == 'SET': yield from self.execute_set(tail) elif head == 'SETLOCAL': 
setting = tail.strip().upper() delay = { 'DISABLEDELAYEDEXPANSION': False, 'ENABLEDELAYEDEXPANSION' : True, }.get(setting, self.delayexpand) cmdxt = { 'DISABLEEXTENSIONS': False, 'ENABLEEXTENSIONS' : True, }.get(setting, self.ext_setting) self.delayexpands.append(delay) self.ext_settings.append(cmdxt) self.environments.append(dict(self.environment)) elif head == 'ENDLOCAL' and len(self.environments) > 1: self.environments.pop() self.delayexpands.pop() elif head == 'IF': then, cmd = self.execute_if(tail) if not cmd: ifelse = If.Active | If.Block if then: ifelse |= If.Then continue elif then: self.ec = yield EmulatedCommand(cmd) elif head == 'ELSE': if If.Else not in ifelse: raise UnexpectedToken(head) if If.Then not in ifelse: if not (cmd := tail.lstrip()): ifelse |= If.Block continue else: self.ec = yield EmulatedCommand(cmd) elif head == 'EXIT': token, tail = self.split_head(tail, toupper=True) script_only = False if token == '/B': script_only = True token, tail = self.split_head(tail) try: exit_code = int(token, 10) except ValueError: exit_code = 0 return Exit(exit_code, script_only) elif head == 'CD' or head == 'CHDIR': directory, _ = self.split_head(tail, unquote=True, terminator_letters=B'') self.cwd = directory.rstrip() elif head == 'PUSHD': directory, _ = self.split_head(tail, unquote=True, terminator_letters=B'') self.dirstack.append(self.cwd) self.cwd = directory.rstrip() elif head == 'POPD': try: self.cwd = self.dirstack.pop() except IndexError: pass elif head == 'GOTO': label, tail = self.split_head(tail) if label.startswith(':'): label = label[1:] self.ec = yield EmulatedCommand(command) return Goto(label) else: self.ec = yield EmulatedCommand(command) ifelse = If.Inactive return Exit() def _decode(self, data: buf): if data[:3] == B'\xEF\xBB\xBF': return codecs.decode(data[3:], 'utf8') elif data[:2] == B'\xFF\xFE': return codecs.decode(data[2:], 'utf-16le') elif data[:2] == B'\xFE\xFF': return codecs.decode(data[2:], 'utf-16be') else: return 
codecs.decode(data, 'cp1252') def parse(self, text: str | buf): self.reset() if not isinstance(text, str): text = self._decode(text) text = '\n'.join( line.rstrip() for line in re.split(r'[\r\n]+', text.strip())) utf16 = text.encode('utf-16le') utf16 = memoryview(utf16).cast('H') quote = False caret = False check = 0 lines = self.code = [] path_to_root = [] def linebreak(k: int): nonlocal check line = text[check:k] check = k + 1 strip = line.strip() if not strip: return lines.append(line) if strip[0] != ':': return label = strip[1:].strip() if not label: return if label[0] == ':': return label = label.upper() index = [len(n) - 1 for n in path_to_root] index.append(len(lines) - 1) self.labels[label] = index for k, char in enumerate(utf16): if char == _QUOTE: if quote := not quote: caret = False continue if char == _LINEBREAK: if caret: caret = False else: linebreak(k) quote = False continue if quote: continue if caret: caret = False continue if char == _CARET: caret = True continue if char == _PAREN_OPEN: linebreak(k) path_to_root.append(lines) block = [] lines.append(block) lines = block if char == _PAREN_CLOSE: if not path_to_root: continue linebreak(k) lines = path_to_root.pop() linebreak(len(text))
Class variables
var environments
var code
var labels
var args
Static methods
def split_head(expression, toupper=False, uncaret=True, unquote=False, terminator_letters=b' \t\x0b', terminator_strings=())
-
Expand source code Browse git
@staticmethod
def split_head(
    expression: str,
    toupper: bool = False,
    uncaret: bool = True,
    unquote: bool = False,
    terminator_letters: bytes = B'\x20\x09\x0B',
    terminator_strings: tuple[bytes, ...] = (),
):
    """
    Split the leading token off a command string.

    The expression is scanned from left to right. While the scan is neither inside double quotes
    nor directly behind an unquoted caret, it stops at the first character contained in
    `terminator_letters` or at the first position where one of the `terminator_strings` begins.

    Args:
        expression: The text to split.
        toupper: Convert the returned head to upper case.
        uncaret: Drop escaping carets from the head; the escaped character itself is kept.
        unquote: Drop double quote characters from the head.
        terminator_letters: Single characters that end the head, given as byte values.
        terminator_strings: Multi-character sequences that end the head.

    Returns:
        A tuple `(head, tail)`. The tail starts at the terminator, has its leading whitespace
        removed, and is empty when no terminator was found.
    """
    # A set of integers accepts any code point in a membership test; testing an integer above
    # 255 against a bytes object raises a ValueError instead of returning False.
    stop_letters = frozenset(terminator_letters)
    stop_strings = tuple(
        t if isinstance(t, str) else bytes(t).decode('latin1')
        for t in terminator_strings
    )
    quote = False
    caret = False
    token = []
    tail = ''
    # Iterate code points rather than UTF-16 code units so that the position k is always a
    # valid index into the expression, even behind characters outside the BMP.
    for k, letter in enumerate(expression):
        char = ord(letter)
        if not quote and not caret:
            if char in stop_letters or any(expression.startswith(t, k) for t in stop_strings):
                tail = expression[k:]
                break
        if char == _QUOTE:
            quote = not quote
            if unquote:
                continue
        elif quote:
            # Everything inside quotes is literal, carets included.
            pass
        elif caret:
            # This character was escaped by the preceding caret.
            caret = False
        elif char == _CARET:
            caret = True
            if uncaret:
                continue
        token.append(letter)
    head = ''.join(token)
    if toupper:
        head = head.upper()
    return head, tail.lstrip()
Instance variables
var cwd
-
Expand source code Browse git
@property
def cwd(self):
    """
    The current working directory, read from the private `_cwd` attribute.
    """
    current = self._cwd
    return current
var ec
-
Expand source code Browse git
@property
def ec(self) -> int:
    """
    The current error level; this reads the `errorlevel` attribute.
    """
    level = self.errorlevel
    return level
var environment
-
Expand source code Browse git
@property
def environment(self):
    """
    The variable scope currently in effect: the last entry of the `environments` stack.
    """
    scopes = self.environments
    return scopes[-1]
var delayexpand
-
Expand source code Browse git
@property
def delayexpand(self):
    """
    The delayed expansion setting currently in effect: the last entry of the `delayexpands`
    stack.
    """
    settings = self.delayexpands
    return settings[-1]
var ext_setting
-
Expand source code Browse git
@property
def ext_setting(self):
    """
    The command extension setting currently in effect: the last entry of the `ext_settings`
    stack.
    """
    settings = self.ext_settings
    return settings[-1]
Methods
def reset(self)
-
Expand source code Browse git
def reset(self):
    """
    Return the emulator to its initial state.

    Labels, arguments and the directory stack are emptied, the scope stacks are reduced to a
    single entry taken from the configured defaults, and the virtual file system is restored
    to a copy of its seed.
    """
    self.args = list()
    self.labels = dict()
    self.dirstack = list()
    self.file_system = dict(self.file_sytem_seed)
    self.environments = [dict()]
    self.delayexpands = [self.delayed_expansion]
    self.ext_settings = [self.extensions_enabled]
    # Assigned last, after the scope stacks exist.
    self.ec = None
def create_file(self, path, data='')
-
Expand source code Browse git
def create_file(self, path: str, data: str = ''):
    """
    Store `data` in the virtual file system under the resolved form of `path`, replacing any
    previous content.
    """
    key = self._resolved(path)
    self.file_system[key] = data
def append_file(self, path, data)
-
Expand source code Browse git
def append_file(self, path: str, data: str):
    """
    Append `data` to the virtual file stored under the resolved form of `path`. A file that is
    missing or empty simply receives `data` as its content.
    """
    key = self._resolved(path)
    existing = self.file_system.get(key, None)
    if existing:
        data = F'{existing}{data}'
    self.file_system[key] = data
def remove_file(self, path)
-
Expand source code Browse git
def remove_file(self, path: str):
    """
    Delete the virtual file stored under the resolved form of `path`; a missing file is
    silently ignored.
    """
    files = self.file_system
    files.pop(self._resolved(path), None)
def ingest_file(self, path)
-
Expand source code Browse git
def ingest_file(self, path: str) -> str | None:
    """
    Return the content of the virtual file stored under the resolved form of `path`, or `None`
    when there is no such file.
    """
    key = self._resolved(path)
    return self.file_system.get(key)
def exists_file(self, path)
-
Expand source code Browse git
def exists_file(self, path: str) -> bool:
    """
    Report whether the virtual file system has an entry for the resolved form of `path`.
    """
    key = self._resolved(path)
    return key in self.file_system
def expand(self, block, delay=False)
-
Expand source code Browse git
def expand(self, block: BatchCode, delay: bool = False):
    """
    Substitute environment variable references in batch code.

    With `delay` unset, `%NAME%` references are replaced, including the modifier forms
    `%NAME:old=new%`, `%NAME:*old=new%` and `%NAME:~offset,length%`. With `delay` set, plain
    `!NAME!` references are replaced instead. Unknown variables expand to the empty string.

    Args:
        block: A line of code, or a (nested) list of lines.
        delay: Select delayed (`!NAME!`) rather than percent expansion.

    Returns:
        The expanded string, or a list of the same shape as `block` with every string expanded.

    Raises:
        EmulatorError: If a modifier is neither a replacement nor a substring specification.
    """
    def expansion(match: re.Match[str]):
        name = match.group(1)
        base = self.environment.get(name.upper(), '')
        if not (modifier := match.group(2)):
            return base
        # The pattern captures the modifier together with its leading colon; it has to be
        # removed, or replacements would search for ":old" rather than "old".
        spec = modifier[1:] if modifier.startswith(':') else modifier
        if '=' in spec:
            old, _, new = spec.partition('=')
            if len(old) > 1 and old.startswith('*'):
                # "*old" matches everything from the start of the value through the first
                # occurrence of "old".
                needle = old[1:]
                position = base.find(needle)
                if position < 0:
                    return base
                return new + base[position + len(needle):]
            if old.startswith('~'):
                # NOTE(review): a leading tilde limiting the replacement to the first match is
                # kept from the original code; it is not a documented cmd feature - confirm.
                # The count is passed positionally because the keyword requires Python 3.13.
                return base.replace(old[1:], new, 1)
            return base.replace(old, new)
        if not spec.startswith('~'):
            raise EmulatorError
        offset, _, length = spec[1:].partition(',')
        offset = batchint(offset)
        if offset < 0:
            offset = max(0, len(base) + offset)
        if not length:
            end = len(base)
        else:
            span = batchint(length)
            if span < 0:
                # A negative length is measured from the end of the value.
                end = max(0, len(base) + span)
            else:
                end = offset + span
        return base[offset:end]

    if delay:
        pattern = r'!([^!:\n]*)()!'
    else:
        pattern = rf'%([^%:\n]*)(:(?:~{formats.integer}(?:,{formats.integer})?|[^=%\n]+=[^%\r\n]*))?%'
    if isinstance(block, str):
        return re.sub(pattern, expansion, block)
    # Pass the mode on so that nested blocks are expanded the same way as their parent.
    return [self.expand(child, delay) for child in block]
def execute_set(self, command)
-
Expand source code Browse git
def execute_set(self, command: str):
    """
    Emulate the SET command; `command` is the text that follows the SET keyword.

    - `/P` prompts cannot be answered here, so the command is yielded as an `EmulatedCommand`
      and the value sent back into the generator is stored as the error level.
    - `/A` evaluates each comma-separated `name=expression` assignment and stores the results
      that could be computed.
    - `name=value` assigns a variable and `name=` removes it.
    - A bare `name` without an equals sign is the display form and leaves the environment
      untouched.

    Variable names are stored in upper case, which is how they are looked up elsewhere.
    """
    check, rest = self.split_head(command, toupper=True)
    if check == '/P':
        self.ec = yield EmulatedCommand(F'set {command}')
        return
    arithmetic = check == '/A'
    if arithmetic:
        command = rest
    if not command:
        return
    command, _ = self.split_head(command, terminator_letters=B'')
    if command.startswith('"'):
        # This is how it works based on testing, even if it seems insane.
        command, _, what = command[1:].rpartition('"')
        command = command or what
    if arithmetic:
        integers = {}
        updated = {}
        # Only variables that hold integers can take part in arithmetic.
        for name, value in self.environment.items():
            try:
                integers[name] = batchint(value)
            except ValueError:
                pass
        for assignment in command.split(','):
            name, _, expression = assignment.strip().partition('=')
            # Normalize the name like a regular assignment does; otherwise "set /a x = 1"
            # would create a variable "x " that no later lookup can find.
            name = name.strip().upper()
            value = cautious_eval_or_default(expression, environment=integers)
            if value is not None:
                integers[name] = value
                updated[name] = str(value)
        self.environment.update(updated)
    else:
        name, assigned, content = command.partition('=')
        if not assigned:
            # "set NAME" only displays variables; it must not delete NAME.
            return
        name = name.upper()
        content, _ = self.split_head(content, terminator_letters=B'')
        if not content:
            self.environment.pop(name, None)
        else:
            self.environment[name] = content
def execute_if(self, command)
-
Expand source code Browse git
def execute_if(self, command: str):
    """
    Evaluate the condition of an IF statement; `command` is the text that follows the IF
    keyword.

    Supported are the optional `/I` and `NOT` prefixes, the tests `ERRORLEVEL n`,
    `CMDEXTVERSION n`, `EXIST path` and `DEFINED name`, string equality via `==`, and the
    comparison operators EQU, NEQ, LSS, LEQ, GTR and GEQ. Keywords and operators are matched
    without regard to case.

    Returns:
        A tuple `(condition, rest)` of the truth value and the text that follows the condition.

    Raises:
        UnexpectedToken: For an unknown comparison operator, or for any comparison operator
            when `extensions_version` is below 1.
    """
    casefold = False
    negate = False
    check, rest = self.split_head(command, toupper=True)
    if check == '/I':
        casefold = True
        command = rest
        # Keywords must be upper-cased again after each prefix, or "if /i not ..." and
        # "if not exist ..." would not be recognized when written in lower case.
        check, rest = self.split_head(rest, toupper=True)
    if check == 'NOT':
        negate = True
        command = rest
        check, rest = self.split_head(rest, toupper=True)
    if check == 'ERRORLEVEL':
        limit, rest = self.split_head(rest)
        limit = int(limit.strip(), 10)
        condition = limit <= self.ec
    elif check == 'CMDEXTVERSION':
        limit, rest = self.split_head(rest)
        limit = int(limit.strip(), 10)
        condition = limit <= self.extensions_version
    elif check == 'EXIST':
        path, rest = self.split_head(rest, unquote=True)
        condition = self.exists_file(path)
    elif check == 'DEFINED':
        name, rest = self.split_head(rest)
        condition = name.upper() in self.environment
    else:
        # Not a keyword: parse the left operand again from the unmodified text.
        lhs, rest = self.split_head(
            command, toupper=False, unquote=True, terminator_strings=(B'==',)
        )
        if rest.startswith('=='):
            rest = rest[2:].lstrip()
            rhs, rest = self.split_head(rest, toupper=False, unquote=True)
            if casefold:
                lhs = lhs.casefold()
                rhs = rhs.casefold()
            condition = lhs == rhs
        else:
            cmp, rest = self.split_head(rest, toupper=True)
            if self.extensions_version < 1:
                raise UnexpectedToken(cmp)
            # Unquote the right operand exactly like the left one.
            rhs, rest = self.split_head(rest, unquote=True)
            # Reduce all six operators to EQU, LSS and LEQ.
            if cmp == 'GTR':
                rhs, lhs, cmp = lhs, rhs, 'LSS'
            if cmp == 'GEQ':
                rhs, lhs, cmp = lhs, rhs, 'LEQ'
            if cmp == 'NEQ':
                negate, cmp = not negate, 'EQU'
            try:
                ilh = batchint(lhs)
                irh = batchint(rhs)
            except ValueError:
                # At least one operand is not a number: compare as strings.
                if casefold:
                    lhs = lhs.casefold()
                    rhs = rhs.casefold()
                pair = IfEq(lhs, rhs)
            else:
                pair = IfEq(ilh, irh)
            if cmp == 'EQU':
                condition = pair.EQU()
            elif cmp == 'LSS':
                condition = pair.LSS()
            elif cmp == 'LEQ':
                condition = pair.LEQ()
            else:
                raise UnexpectedToken(cmp)
    if negate:
        condition = not condition
    return condition, rest
def goto(self, index)
-
Expand source code Browse git
def goto(self, index: list[int]) -> tuple[Block, BatchCode, int]:
    """
    Resolve a position in the parsed code.

    Args:
        index: One offset per nesting level, outermost first. An empty index refers to the
            first top-level entry.

    Returns:
        A tuple `(block, entry, offset)`: the innermost block that was reached, the entry
        stored at the final offset inside that block, and that final offset. For a script
        without any code, the empty top-level block is returned as both block and entry, with
        offset zero.
    """
    if not index:
        index = [0]
    code = cursor = self.code
    if not cursor:
        # An empty script has no first entry to descend into; report the empty block itself
        # instead of failing with an IndexError.
        return code, cursor, 0
    line = 0
    for line in index:
        code, cursor = cursor, cursor[line]
    assert isinstance(code, list)
    return code, cursor, line
def emulate(self, *args)
-
Expand source code Browse git
def emulate(self, *args: str) -> Generator[EmulatedCommand, int | None, ExecutionResult]:
    """
    Run the parsed script from its first line.

    The positional arguments replace the contents of `args`. Every command that is not handled
    internally is yielded as an `EmulatedCommand`. When a block ends with a `Goto`, execution
    resumes at the position recorded for that label; the label `EOF` ends the script.

    Returns:
        The `Exit` result that ended the script; its code is stored as the error level first.

    Raises:
        InvalidLabel: If a jump targets a label that was not recorded while parsing.
        TypeError: If a block finishes with something other than a `Goto` or an `Exit`.
    """
    self.args[:] = args
    position = [0]
    while True:
        block, _, offset = self.goto(position)
        outcome = yield from self.emulate_block(block, offset=offset, expand=True)
        if isinstance(outcome, Exit):
            self.ec = outcome.code
            return outcome
        if not isinstance(outcome, Goto):
            raise TypeError(outcome)
        label = outcome.label.upper()
        if label == 'EOF':
            return Exit()
        try:
            position = self.labels[label]
        except KeyError as KE:
            raise InvalidLabel(label) from KE
def emulate_block(self, block, offset=0, expand=False)
-
Expand source code Browse git
def emulate_block(
    self,
    block: Block,
    offset: int = 0,
    expand: bool = False,
) -> Generator[EmulatedCommand, int | None, ExecutionResult]:
    """
    Execute the entries of one block of parsed code.

    SET, SETLOCAL, ENDLOCAL, IF, ELSE, EXIT, CD/CHDIR, PUSHD, POPD and GOTO are interpreted
    here. Every other command is yielded as an `EmulatedCommand`, and the value sent back into
    the generator becomes the new error level.

    Args:
        block: The entries to run; strings are lines and lists are nested blocks.
        offset: Index of the first entry to run.
        expand: Apply percent expansion to each entry before running it.

    Returns:
        A `Goto` for a jump, an `Exit` for an EXIT command, or a default `Exit()` once the
        block has run to its end.

    Raises:
        EmulatorError: If blocks and if/else statements do not line up.
        UnexpectedToken: For an ELSE that does not follow the block of an IF.
    """
    it = block if offset <= 0 else itertools.islice(block, offset, None)
    ifelse = If.Inactive
    for code in it:
        if expand:
            code = self.expand(code)
        if If.Block in ifelse:
            # The preceding IF or ELSE announced that its body is the next block.
            if not isinstance(code, list):
                raise EmulatorError(F'Expected a block while parsing If/Else; {ifelse!r}')
            if not ifelse.skip_block():
                # NOTE(review): a nested block that simply ends also returns Exit(), so an
                # "exit /b" inside a block cannot be told apart from normal completion here.
                result = yield from self.emulate_block(code)
                if result.longjump():
                    return result
            if If.Else in ifelse:
                # That was the else block; the statement is complete.
                ifelse = If.Inactive
            else:
                # That was the then block; an ELSE may follow.
                ifelse |= If.Else
            ifelse &= ~If.Block
            continue
        if isinstance(code, list):
            if ifelse != If.Inactive:
                raise EmulatorError('Unexpected block in the middle of if/else statement.')
            result = yield from self.emulate_block(code)
            if result.longjump():
                return result
            continue
        condition = Condition.Always
        for command, next_condition in self._commands(code):
            if not self._check_condition(condition):
                break
            condition = next_condition
            if self.delayexpand:
                command = self.expand(command, True)
            head, tail = self.split_head(command, toupper=True, uncaret=False)
            head = head.lstrip('@')
            if head == 'SET':
                yield from self.execute_set(tail)
            elif head == 'SETLOCAL':
                delay = self.delayexpand
                cmdxt = self.ext_setting
                # SETLOCAL accepts up to two options, for example
                # "setlocal enableextensions enabledelayedexpansion"; apply each of them.
                for setting in tail.upper().split():
                    delay = {
                        'DISABLEDELAYEDEXPANSION': False,
                        'ENABLEDELAYEDEXPANSION' : True,
                    }.get(setting, delay)
                    cmdxt = {
                        'DISABLEEXTENSIONS': False,
                        'ENABLEEXTENSIONS' : True,
                    }.get(setting, cmdxt)
                self.delayexpands.append(delay)
                self.ext_settings.append(cmdxt)
                self.environments.append(dict(self.environment))
            elif head == 'ENDLOCAL' and len(self.environments) > 1:
                # Undo everything that SETLOCAL pushed, the extension setting included.
                self.environments.pop()
                self.delayexpands.pop()
                self.ext_settings.pop()
            elif head == 'IF':
                then, cmd = self.execute_if(tail)
                if not cmd:
                    # Nothing follows the condition: the body is the next block.
                    ifelse = If.Active | If.Block
                    if then:
                        ifelse |= If.Then
                    continue
                elif then:
                    # NOTE(review): the inline command is yielded verbatim and not dispatched
                    # again, so "if ... goto label" does not jump.
                    self.ec = yield EmulatedCommand(cmd)
            elif head == 'ELSE':
                if If.Else not in ifelse:
                    raise UnexpectedToken(head)
                if not (cmd := tail.lstrip()):
                    # Always announce the else block, also when the then branch ran; whether
                    # to run or skip it is decided by skip_block.
                    ifelse |= If.Block
                    continue
                if If.Then not in ifelse:
                    self.ec = yield EmulatedCommand(cmd)
            elif head == 'EXIT':
                token, tail = self.split_head(tail, toupper=True)
                script_only = False
                if token == '/B':
                    script_only = True
                    token, tail = self.split_head(tail)
                try:
                    exit_code = int(token, 10)
                except ValueError:
                    exit_code = 0
                return Exit(exit_code, script_only)
            elif head == 'CD' or head == 'CHDIR':
                directory, _ = self.split_head(tail, unquote=True, terminator_letters=B'')
                self.cwd = directory.rstrip()
            elif head == 'PUSHD':
                directory, _ = self.split_head(tail, unquote=True, terminator_letters=B'')
                self.dirstack.append(self.cwd)
                self.cwd = directory.rstrip()
            elif head == 'POPD':
                try:
                    self.cwd = self.dirstack.pop()
                except IndexError:
                    # POPD on an empty directory stack does nothing.
                    pass
            elif head == 'GOTO':
                label, _ = self.split_head(tail)
                if label.startswith(':'):
                    label = label[1:]
                self.ec = yield EmulatedCommand(command)
                return Goto(label)
            else:
                self.ec = yield EmulatedCommand(command)
            # Any command other than an IF or ELSE that announces a block ends the pending
            # if/else statement.
            ifelse = If.Inactive
    return Exit()
def parse(self, text)
-
Expand source code Browse git
def parse(self, text: str | buf):
    """
    Parse a batch script into the nested line structure stored in `code`.

    The emulator is reset first. Binary input is decoded, line endings are normalized and
    trailing whitespace is removed from every line. The text is then split into lines at
    line breaks that are not escaped by a caret. Unquoted, unescaped parentheses open and
    close nested blocks, which are stored as nested lists. Lines that start with a single
    colon are recorded in `labels` under their upper-case name, together with the path of
    offsets that leads to them.
    """
    self.reset()
    if not isinstance(text, str):
        text = self._decode(text)
    text = '\n'.join(
        line.rstrip() for line in re.split(r'[\r\n]+', text.strip()))
    quote = False
    caret = False
    check = 0
    lines = self.code = []
    path_to_root = []

    def linebreak(k: int):
        # Store the text between the previous boundary and position k as one line.
        nonlocal check
        line = text[check:k]
        check = k + 1
        strip = line.strip()
        if not strip:
            return
        lines.append(line)
        if strip[0] != ':':
            return
        label = strip[1:].strip()
        if not label:
            return
        if label[0] == ':':
            # A line starting with a double colon is not a label.
            return
        label = label.upper()
        index = [len(n) - 1 for n in path_to_root]
        index.append(len(lines) - 1)
        self.labels[label] = index

    # Iterate code points rather than UTF-16 code units: the positions are used to slice the
    # text itself and would otherwise drift after every character outside the BMP.
    for k, char in enumerate(map(ord, text)):
        if char == _QUOTE:
            if quote := not quote:
                caret = False
            continue
        if char == _LINEBREAK:
            if caret:
                # An escaped line break continues the current line.
                caret = False
            else:
                linebreak(k)
                quote = False
            continue
        if quote:
            continue
        if caret:
            caret = False
            continue
        if char == _CARET:
            caret = True
            continue
        if char == _PAREN_OPEN:
            linebreak(k)
            path_to_root.append(lines)
            block = []
            lines.append(block)
            lines = block
        if char == _PAREN_CLOSE:
            if not path_to_root:
                # Ignore a closing parenthesis that has no open block.
                continue
            linebreak(k)
            lines = path_to_root.pop()
    linebreak(len(text))