Module refinery.lib.inline
Implements programmatic inlining, specifically for the units in `refinery.units.blockwise`.
Expand source code Browse git
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Implements programmatic inlining, specifically for the units in `refinery.units.blockwise`.
"""
from __future__ import annotations
from typing import Any, Callable, Generator, Optional, Type, TypeVar, Union
import ast
import inspect
import functools
import sys
from ast import (
Assign,
BinOp,
BitAnd,
Bytes,
Call,
Constant,
Expr,
For,
FunctionDef,
GeneratorExp,
Load,
Name,
NodeTransformer,
Num,
Return,
Store,
Str,
Yield,
arg,
comprehension,
)
_R = TypeVar('_R')


class inline(Callable[..., _R]):
    """
    A thin wrapper around a callable. Calling an instance forwards all positional and keyword
    arguments to the wrapped function and returns its result. The metadata of the wrapped function
    is copied onto the instance with `functools.update_wrapper`, which also makes the original
    function available as the `__wrapped__` attribute.
    """

    def __init__(self, function: Callable[..., _R]):
        # Keep the assignment ahead of update_wrapper: the latter merges the
        # attribute dictionary of the wrapped function into this instance.
        self._function = function
        functools.update_wrapper(wrapper=self, wrapped=function)

    def __call__(self, *args, **kwargs):
        target = self._function
        return target(*args, **kwargs)
class PassAsConstant:
    """
    A container for a single value, stored in the `value` attribute. When an instance is passed
    to `iterspread` as an inline argument, the wrapped value is bound to the corresponding
    variable as it is, rather than being advanced with `next` in every step of the iteration.
    """

    def __init__(self, value):
        # The wrapped object is stored unchanged.
        self.value = value
def getsource(f):
    """
    Return the source code of `f` as reported by `inspect.getsource`, with the common leading
    indentation of all lines removed. This allows the source of a method or a nested function to
    be parsed as a top-level definition.
    """
    source = inspect.getsource(f)
    # The leading line break makes the first source line subject to the same
    # indentation removal as all following lines.
    return inspect.cleandoc(F'\n{source}')
class ArgumentCountMismatch(ValueError):
    """
    Raised by `iterspread` when the number of inline arguments does not match the number of
    parameters of the method that is being inlined.
    """
class NoFunctionDefinitionFound(ValueError):
    """
    Raised by `iterspread` when the parsed source code of the given callable does not contain any
    function definition.
    """
def iterspread(
    method: Callable[..., _R],
    iterator,
    *inline_args,
    mask: Optional[int] = None
) -> Callable[..., Generator[_R, None, None]]:
    """
    This function receives an arbitrary callable `method`, a primary iterator called `iterator`, and
    an arbitrary number of additional arguments, collected in the `inline_args` variable. The function
    will essentially turn this:

        def method(self, a, b):
            return a + b

    into this:

        def method(_var_self):
            for _var_a in _arg_a:
                _var_b = next(_arg_b)
                yield _var_a + _var_b

    where `_arg_a` and `_arg_b` are variables in the global namespace of the generated function that
    are bound to the primary iterator, and the single element of `inline_args`, respectively. If one
    of the elements in `inline_args` is an `int`, `str`, or `bytes` value, or if it is wrapped in a
    `PassAsConstant`, then it is not advanced with `next` inside the loop. Instead, the variable
    (`_var_b` in the above example) is bound to that value once, in the namespace of the generated
    function. Additional elements of `inline_args` that are collected by a `*args` parameter of
    `method` are treated as iterators and advanced once per step.

    If `method` is a bound method, its first parameter is kept as the only parameter of the generated
    function and its second parameter is the loop variable. Otherwise, the first parameter of `method`
    is the loop variable. The generated function keeps the name of `method` and always expects exactly
    one positional argument.

    If `mask` is given, every yielded value is combined with `mask` using a bitwise and.

    Every `return` statement in the source of `method` is replaced by a `yield` expression. Execution
    continues after a `yield`, so the transformation is only faithful for a `method` that returns at
    the very end of its body.

    Raises `NoFunctionDefinitionFound` if the source of `method` contains no function definition,
    and `ArgumentCountMismatch` if the number of `inline_args` does not fit the signature of `method`.

    Spreading the application of `method` like this provides a high performance increase over making a
    function call to `method` in each step of the iteration.
    """
    code = ast.parse(getsource(method))
    closure_vars = inspect.getclosurevars(method)

    # The generated function is executed with `context` as its global namespace. It therefore has
    # to contain the global variables that the method refers to, not just the nonlocal ones, or
    # these references raise a NameError. Nonlocal variables take precedence over global ones.
    context = dict(closure_vars.globals)
    context.update(closure_vars.nonlocals)

    # Names that refer to the environment of the method must not be renamed.
    reserved_names = set(context)
    reserved_names.update(closure_vars.builtins)

    function_head = next(
        (node for node in ast.walk(code) if isinstance(node, FunctionDef)), None)
    if function_head is None:
        raise NoFunctionDefinitionFound(
            F'the source code of {method!r} contains no function definition')

    function_name = function_head.name
    is_bound_method = inspect.ismethod(method)

    def as_arg(name: str):
        return F'_arg_{name}'

    def as_var(name: str):
        return F'_var_{name}'

    def as_tmp(name: str):
        return F'_tmp_{name}'

    def apply_node_transformation(cls: Type[NodeTransformer]):
        nonlocal code
        code = ast.fix_missing_locations(cls().visit(code))

    if sys.version_info >= (3, 8):
        def constant(value):
            return Constant(value=value)
    else:
        @functools.singledispatch
        def constant(value: Union[int, str, bytes]):
            raise NotImplementedError(F'The type {type(value).__name__} is not supported for inlining')

        @constant.register # noqa
        def _(value: bytes): return Bytes(s=value)

        @constant.register
        def _(value: int): return Num(n=value)

        @constant.register
        def _(value: str): return Str(s=value)

    def advance(name: str):
        # Builds the expression next(<name>).
        return Call(
            func=Name(id='next', ctx=Load()),
            args=[Name(id=name, ctx=Load())],
            keywords=[])

    class RenameVariables(NodeTransformer):
        # Prefixes every name that does not belong to the environment of the method. This has to
        # happen first, before any generated names are inserted into the syntax tree.
        def visit_Name(self, node: Name) -> Any:
            if node.id not in reserved_names:
                node.id = as_var(node.id)
            return node

    class SpreadFunction(NodeTransformer):
        # Wraps the body of the outermost function definition into the loop over the primary
        # iterator and binds all inline arguments.
        def visit_FunctionDef(self, node: FunctionDef):
            if node is not function_head:
                return node

            parameters = [a.arg for a in node.args.args]
            inlined_start = 2 if is_bound_method else 1
            iterator_name = parameters[inlined_start - 1]
            # For a bound method, keep its actual first parameter rather than assuming that it
            # is called 'self'; the body refers to it by its renamed form.
            self_name = parameters[0] if is_bound_method else 'self'
            function_args = parameters[inlined_start:]
            arity = len(function_args)
            vararg = node.args.vararg

            if vararg is None:
                if len(inline_args) != arity:
                    raise ArgumentCountMismatch(
                        F'expected {arity} inline arguments, received {len(inline_args)}')
            elif len(inline_args) < arity:
                # Without this check, the named parameters that received no value would be
                # dropped silently and cause a NameError when the generated code runs.
                raise ArgumentCountMismatch(
                    F'expected at least {arity} inline arguments, received {len(inline_args)}')
            else:
                context[as_arg(vararg.arg)] = inline_args[arity:]

            loop_body = []

            for name, value in zip(function_args, inline_args):
                if isinstance(value, PassAsConstant):
                    context[as_var(name)] = value.value
                elif isinstance(value, (int, str, bytes)):
                    context[as_var(name)] = value
                else:
                    # Everything else is treated as an iterator that is advanced in each step.
                    context[as_arg(name)] = value
                    loop_body.append(Assign(
                        targets=[Name(id=as_var(name), ctx=Store())],
                        value=advance(as_arg(name))))

            if vararg is not None:
                # _var_args = tuple(next(_tmp_args) for _tmp_args in _arg_args)
                name = vararg.arg
                loop_body.append(Assign(
                    targets=[Name(id=as_var(name), ctx=Store())],
                    value=Call(
                        func=Name(id='tuple', ctx=Load()),
                        args=[GeneratorExp(
                            elt=advance(as_tmp(name)),
                            generators=[comprehension(
                                is_async=0,
                                target=Name(id=as_tmp(name), ctx=Store()),
                                iter=Name(id=as_arg(name), ctx=Load()),
                                ifs=[]
                            )]
                        )],
                        keywords=[]
                    )
                ))

            loop_body.extend(node.body)
            context[as_arg(iterator_name)] = iterator

            node.body = [For(
                target=Name(id=as_var(iterator_name), ctx=Store()),
                iter=Name(id=as_arg(iterator_name), ctx=Load()),
                body=loop_body,
                orelse=[]
            )]
            node.args.args = [arg(arg=as_var(self_name))]
            node.args.vararg = None
            # The only remaining positional parameter has no default value. Leftover defaults of
            # the removed parameters would make the compilation of the syntax tree fail.
            node.args.defaults = []
            node.decorator_list = []
            return node

    class ReturnToYield(NodeTransformer):
        # Turns every return statement into a yield expression, applying the mask if given.
        def visit_Return(self, node: Return) -> Any:
            value = node.value
            if mask is not None:
                value = BinOp(left=value, op=BitAnd(), right=constant(mask))
            return Expr(value=Yield(value=value))

    apply_node_transformation(RenameVariables)
    apply_node_transformation(SpreadFunction)
    apply_node_transformation(ReturnToYield)

    compiled = compile(code, '<inlined>', 'exec', optimize=2)
    exec(compiled, context, context)
    return context[function_name]
Functions
def getsource(f)
-
Expand source code Browse git
def getsource(f):
    """
    Return the source code of `f` as reported by `inspect.getsource`, with the common leading
    indentation of all lines removed. This allows the source of a method or a nested function to
    be parsed as a top-level definition.
    """
    source = inspect.getsource(f)
    # The leading line break makes the first source line subject to the same
    # indentation removal as all following lines.
    return inspect.cleandoc(F'\n{source}')
def iterspread(method, iterator, *inline_args, mask=None)
-
This function receives an arbitrary callable `method`, a primary iterator called `iterator`, and an arbitrary number of additional arguments, collected in the `inline_args` variable. The function will essentially turn this:

    def method(self, a, b):
        return a + b

into this:

    def iterspread_method(self):
        for _var_a in _arg_a:
            _var_b = next(_arg_b)
            yield _var_a + _var_b

where `_arg_a` and `_arg_b` are closure variables that are bound to the primary iterator, and the single element of `inline_args`, respectively. If one of the elements in `inline_args` is a constant, then this constant will instead be set initially in front of the loop:

    def iterspread_method(self):
        _var_b = 5
        for _var_a in _arg_a:
            yield _var_a + _var_b

Spreading the application of `method` like this provides a high performance increase over making a function call to `method` in each step of the iteration.
Expand source code Browse git
def iterspread(
    method: Callable[..., _R],
    iterator,
    *inline_args,
    mask: Optional[int] = None
) -> Callable[..., Generator[_R, None, None]]:
    """
    This function receives an arbitrary callable `method`, a primary iterator called `iterator`, and
    an arbitrary number of additional arguments, collected in the `inline_args` variable. The function
    will essentially turn this:

        def method(self, a, b):
            return a + b

    into this:

        def method(_var_self):
            for _var_a in _arg_a:
                _var_b = next(_arg_b)
                yield _var_a + _var_b

    where `_arg_a` and `_arg_b` are variables in the global namespace of the generated function that
    are bound to the primary iterator, and the single element of `inline_args`, respectively. If one
    of the elements in `inline_args` is an `int`, `str`, or `bytes` value, or if it is wrapped in a
    `PassAsConstant`, then it is not advanced with `next` inside the loop. Instead, the variable
    (`_var_b` in the above example) is bound to that value once, in the namespace of the generated
    function. Additional elements of `inline_args` that are collected by a `*args` parameter of
    `method` are treated as iterators and advanced once per step.

    If `method` is a bound method, its first parameter is kept as the only parameter of the generated
    function and its second parameter is the loop variable. Otherwise, the first parameter of `method`
    is the loop variable. The generated function keeps the name of `method` and always expects exactly
    one positional argument.

    If `mask` is given, every yielded value is combined with `mask` using a bitwise and.

    Every `return` statement in the source of `method` is replaced by a `yield` expression. Execution
    continues after a `yield`, so the transformation is only faithful for a `method` that returns at
    the very end of its body.

    Raises `NoFunctionDefinitionFound` if the source of `method` contains no function definition,
    and `ArgumentCountMismatch` if the number of `inline_args` does not fit the signature of `method`.

    Spreading the application of `method` like this provides a high performance increase over making a
    function call to `method` in each step of the iteration.
    """
    code = ast.parse(getsource(method))
    closure_vars = inspect.getclosurevars(method)

    # The generated function is executed with `context` as its global namespace. It therefore has
    # to contain the global variables that the method refers to, not just the nonlocal ones, or
    # these references raise a NameError. Nonlocal variables take precedence over global ones.
    context = dict(closure_vars.globals)
    context.update(closure_vars.nonlocals)

    # Names that refer to the environment of the method must not be renamed.
    reserved_names = set(context)
    reserved_names.update(closure_vars.builtins)

    function_head = next(
        (node for node in ast.walk(code) if isinstance(node, FunctionDef)), None)
    if function_head is None:
        raise NoFunctionDefinitionFound(
            F'the source code of {method!r} contains no function definition')

    function_name = function_head.name
    is_bound_method = inspect.ismethod(method)

    def as_arg(name: str):
        return F'_arg_{name}'

    def as_var(name: str):
        return F'_var_{name}'

    def as_tmp(name: str):
        return F'_tmp_{name}'

    def apply_node_transformation(cls: Type[NodeTransformer]):
        nonlocal code
        code = ast.fix_missing_locations(cls().visit(code))

    if sys.version_info >= (3, 8):
        def constant(value):
            return Constant(value=value)
    else:
        @functools.singledispatch
        def constant(value: Union[int, str, bytes]):
            raise NotImplementedError(F'The type {type(value).__name__} is not supported for inlining')

        @constant.register # noqa
        def _(value: bytes): return Bytes(s=value)

        @constant.register
        def _(value: int): return Num(n=value)

        @constant.register
        def _(value: str): return Str(s=value)

    def advance(name: str):
        # Builds the expression next(<name>).
        return Call(
            func=Name(id='next', ctx=Load()),
            args=[Name(id=name, ctx=Load())],
            keywords=[])

    class RenameVariables(NodeTransformer):
        # Prefixes every name that does not belong to the environment of the method. This has to
        # happen first, before any generated names are inserted into the syntax tree.
        def visit_Name(self, node: Name) -> Any:
            if node.id not in reserved_names:
                node.id = as_var(node.id)
            return node

    class SpreadFunction(NodeTransformer):
        # Wraps the body of the outermost function definition into the loop over the primary
        # iterator and binds all inline arguments.
        def visit_FunctionDef(self, node: FunctionDef):
            if node is not function_head:
                return node

            parameters = [a.arg for a in node.args.args]
            inlined_start = 2 if is_bound_method else 1
            iterator_name = parameters[inlined_start - 1]
            # For a bound method, keep its actual first parameter rather than assuming that it
            # is called 'self'; the body refers to it by its renamed form.
            self_name = parameters[0] if is_bound_method else 'self'
            function_args = parameters[inlined_start:]
            arity = len(function_args)
            vararg = node.args.vararg

            if vararg is None:
                if len(inline_args) != arity:
                    raise ArgumentCountMismatch(
                        F'expected {arity} inline arguments, received {len(inline_args)}')
            elif len(inline_args) < arity:
                # Without this check, the named parameters that received no value would be
                # dropped silently and cause a NameError when the generated code runs.
                raise ArgumentCountMismatch(
                    F'expected at least {arity} inline arguments, received {len(inline_args)}')
            else:
                context[as_arg(vararg.arg)] = inline_args[arity:]

            loop_body = []

            for name, value in zip(function_args, inline_args):
                if isinstance(value, PassAsConstant):
                    context[as_var(name)] = value.value
                elif isinstance(value, (int, str, bytes)):
                    context[as_var(name)] = value
                else:
                    # Everything else is treated as an iterator that is advanced in each step.
                    context[as_arg(name)] = value
                    loop_body.append(Assign(
                        targets=[Name(id=as_var(name), ctx=Store())],
                        value=advance(as_arg(name))))

            if vararg is not None:
                # _var_args = tuple(next(_tmp_args) for _tmp_args in _arg_args)
                name = vararg.arg
                loop_body.append(Assign(
                    targets=[Name(id=as_var(name), ctx=Store())],
                    value=Call(
                        func=Name(id='tuple', ctx=Load()),
                        args=[GeneratorExp(
                            elt=advance(as_tmp(name)),
                            generators=[comprehension(
                                is_async=0,
                                target=Name(id=as_tmp(name), ctx=Store()),
                                iter=Name(id=as_arg(name), ctx=Load()),
                                ifs=[]
                            )]
                        )],
                        keywords=[]
                    )
                ))

            loop_body.extend(node.body)
            context[as_arg(iterator_name)] = iterator

            node.body = [For(
                target=Name(id=as_var(iterator_name), ctx=Store()),
                iter=Name(id=as_arg(iterator_name), ctx=Load()),
                body=loop_body,
                orelse=[]
            )]
            node.args.args = [arg(arg=as_var(self_name))]
            node.args.vararg = None
            # The only remaining positional parameter has no default value. Leftover defaults of
            # the removed parameters would make the compilation of the syntax tree fail.
            node.args.defaults = []
            node.decorator_list = []
            return node

    class ReturnToYield(NodeTransformer):
        # Turns every return statement into a yield expression, applying the mask if given.
        def visit_Return(self, node: Return) -> Any:
            value = node.value
            if mask is not None:
                value = BinOp(left=value, op=BitAnd(), right=constant(mask))
            return Expr(value=Yield(value=value))

    apply_node_transformation(RenameVariables)
    apply_node_transformation(SpreadFunction)
    apply_node_transformation(ReturnToYield)

    compiled = compile(code, '<inlined>', 'exec', optimize=2)
    exec(compiled, context, context)
    return context[function_name]
Classes
class inline (function)
-
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
            # Etc.
This class can then be used as follows::
    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
Expand source code Browse git
class inline(Callable[..., _R]):
    """
    A thin wrapper around a callable. Calling an instance forwards all positional and keyword
    arguments to the wrapped function and returns its result. The metadata of the wrapped function
    is copied onto the instance with `functools.update_wrapper`, which also makes the original
    function available as the `__wrapped__` attribute.
    """

    def __init__(self, function: Callable[..., _R]):
        # Keep the assignment ahead of update_wrapper: the latter merges the
        # attribute dictionary of the wrapped function into this instance.
        self._function = function
        functools.update_wrapper(wrapper=self, wrapped=function)

    def __call__(self, *args, **kwargs):
        target = self._function
        return target(*args, **kwargs)
Ancestors
- collections.abc.Callable
- typing.Generic
class PassAsConstant (value)
-
Expand source code Browse git
class PassAsConstant:
    """
    A container for a single value, stored in the `value` attribute. When an instance is passed
    to `iterspread` as an inline argument, the wrapped value is bound to the corresponding
    variable as it is, rather than being advanced with `next` in every step of the iteration.
    """

    def __init__(self, value):
        # The wrapped object is stored unchanged.
        self.value = value
class ArgumentCountMismatch (*args, **kwargs)
-
Inappropriate argument value (of correct type).
Expand source code Browse git
class ArgumentCountMismatch(ValueError):
    """
    Raised by `iterspread` when the number of inline arguments does not match the number of
    parameters of the method that is being inlined.
    """
Ancestors
- builtins.ValueError
- builtins.Exception
- builtins.BaseException
class NoFunctionDefinitionFound (*args, **kwargs)
-
Inappropriate argument value (of correct type).
Expand source code Browse git
class NoFunctionDefinitionFound(ValueError):
    """
    Raised by `iterspread` when the parsed source code of the given callable does not contain any
    function definition.
    """
Ancestors
- builtins.ValueError
- builtins.Exception
- builtins.BaseException