import ast
import inspect
import re

from brian2.utils.stringtools import deindent, get_identifiers, indent

from .rendering import NodeRenderer
# Public API of this module (names exported by ``from ... import *``).
# NOTE(review): "extract_abstract_code_functions" is listed here but no
# definition of it appears in this module as reviewed -- confirm it exists,
# otherwise ``from ... import *`` raises AttributeError.
__all__ = [
    "AbstractCodeFunction",
    "abstract_code_from_function",
    "extract_abstract_code_functions",
    "substitute_abstract_code_functions",
]
class AbstractCodeFunction:
    """
    The information defining an abstract code function

    Has attributes corresponding to initialisation parameters

    Parameters
    ----------
    name : str
        The function name.
    args : list of str
        The arguments to the function.
    code : str
        The abstract code string consisting of the body of the function less
        the return statement.
    return_expr : str or None
        The expression returned, or None if there is nothing returned.
    """

    def __init__(self, name, args, code, return_expr):
        self.name = name
        self.args = args
        self.code = code
        self.return_expr = return_expr

    def __str__(self):
        # Render the function back as Python-like source. The ``return`` line
        # goes through the same ``indent`` helper as the body, so both lines
        # share one indentation level; a hard-coded indent that differs from
        # the body's would make the rendered text unparseable.
        signature = f"def {self.name}({', '.join(self.args)}):"
        body = indent(self.code)
        return_line = indent(f"return {self.return_expr}")
        return f"{signature}\n{body}\n{return_line}\n"

    __repr__ = __str__
def _parse_function_def(source):
    """Return the ``ast.FunctionDef`` that must be the first statement of `source`."""
    statements = ast.parse(source, mode="exec").body
    # Guard against empty sources and sources starting with something other
    # than a ``def`` (e.g. a lambda assignment or a class), which have no
    # ``args`` attribute.
    if not statements or not isinstance(statements[0], ast.FunctionDef):
        raise TypeError("Expected a function definition")
    return statements[0]


def abstract_code_from_function(func):
    """
    Converts the body of the function to abstract code

    Parameters
    ----------
    func : function, str or ast.FunctionDef
        The function to convert. A callable is converted from its source code
        (obtained with `inspect.getsource` and deindented); a string must
        contain a function definition as its first statement.

    Returns
    -------
    func : AbstractCodeFunction
        The corresponding abstract code function. The statements before the
        first ``return`` form its code, anything after that ``return`` is
        ignored. A function without a ``return`` (or with a bare ``return``)
        has a ``return_expr`` of ``None``.

    Raises
    ------
    TypeError
        If `func` is not a callable, string or `ast.FunctionDef`, or if its
        source does not start with a function definition.
    SyntaxError
        If the signature uses ``*args``, ``**kwargs``, keyword-only arguments
        or default values, or if unsupported features are used in the body
        such as if statements or indexing (reported while rendering with
        `NodeRenderer`).
    """
    if callable(func):
        funcnode = _parse_function_def(deindent(inspect.getsource(func)))
    elif isinstance(func, str):
        funcnode = _parse_function_def(func)
    elif isinstance(func, ast.FunctionDef):
        funcnode = func
    else:
        raise TypeError("Unsupported function type")

    arguments = funcnode.args
    if arguments.vararg is not None:
        raise SyntaxError("No support for variable number of arguments")
    if arguments.kwarg is not None:
        raise SyntaxError("No support for arbitrary keyword arguments")
    # Inlined calls only bind positional arguments, so keyword-only
    # parameters would silently stay undefined in the inlined code.
    if arguments.kwonlyargs:
        raise SyntaxError("No support for keyword-only arguments")
    if arguments.defaults:
        raise SyntaxError("No support for default values in functions")

    renderer = NodeRenderer()
    lines = []
    return_expr = None
    for node in funcnode.body:
        if isinstance(node, ast.Return):
            # A bare ``return`` has no value: treat it like a missing return
            # instead of handing ``None`` to the renderer.
            if node.value is not None:
                return_expr = renderer.render_node(node.value)
            break
        lines.append(renderer.render_node(node))

    # Positional-only parameters (``def f(a, /, b)``, Python 3.8+) are stored
    # separately from ``args`` in the AST but are ordinary positional inputs.
    positional = [*getattr(arguments, "posonlyargs", []), *arguments.args]
    args = [arg.arg for arg in positional]
    return AbstractCodeFunction(funcnode.name, args, "\n".join(lines), return_expr)
class VarRewriter(ast.NodeTransformer):
    """
    Rewrites every variable name by prepending the prefix ``pre``

    Every `ast.Name` node is renamed, whether or not it is local to the
    function being inlined. The names of called functions are left untouched,
    but their positional and keyword arguments are rewritten.

    Parameters
    ----------
    pre : str
        The prefix prepended to each variable name.
    """

    def __init__(self, pre):
        self.pre = pre

    def visit_Name(self, node):
        return ast.Name(id=self.pre + node.id, ctx=node.ctx)

    def visit_Call(self, node):
        # The called function is global (e.g. ``exp``), so only its arguments
        # are renamed. Keyword arguments are kept (with renamed values)
        # rather than silently dropped, which would change the call's meaning.
        args = [self.visit(arg) for arg in node.args]
        keywords = [
            ast.keyword(arg=keyword.arg, value=self.visit(keyword.value))
            for keyword in node.keywords
        ]
        return ast.Call(
            func=ast.Name(id=node.func.id, ctx=ast.Load()),
            args=args,
            keywords=keywords,
        )
class FunctionRewriter(ast.NodeTransformer):
    """
    Inlines calls of one abstract code function using temporary variables

    ``numcalls`` is the number of inlinings performed so far. It is part of
    every generated name, so that repeated (and recursive) inlining never
    aliases names; `substitute_abstract_code_functions` keeps it up to date
    between recursive runs.

    The ``pre`` attribute collects the statements to be inserted above the
    line currently being processed, i.e. the inlined code. Visiting a line
    returns that line with each inlined call replaced by the name of the
    variable holding the call's result.

    Parameters
    ----------
    func : AbstractCodeFunction
        The function whose calls are inlined (its ``name``, ``args``,
        ``code`` and ``return_expr`` are used).
    numcalls : int, optional
        Index used for the first inlined call.

    Raises
    ------
    TypeError
        When visiting a call of ``func`` that uses keyword arguments or the
        wrong number of positional arguments.
    """

    def __init__(self, func, numcalls=0):
        self.func = func
        self.numcalls = numcalls
        self.pre = []
        self.suspend = False

    def visit_Call(self, node):
        # We suspend operations during an inlining operation, then resume
        # afterwards (see below), so we only ever expand one inline function
        # call at a time, i.e. no f(f(x)). That case is handled by the
        # recursion in substitute_abstract_code_functions.
        if self.suspend:
            return node
        # Calls of other functions are left in place, but their arguments are
        # still visited so that e.g. ``exp(f(x))`` gets ``f`` inlined.
        if not isinstance(node.func, ast.Name) or node.func.id != self.func.name:
            return self.generic_visit(node)
        # Only positional calls with matching arity can be inlined; anything
        # else would leave some function-local variables undefined.
        if node.keywords:
            raise TypeError(
                "Keyword arguments are not supported when inlining "
                f"'{self.func.name}'"
            )
        if len(node.args) != len(self.func.args):
            raise TypeError(
                f"'{self.func.name}' takes {len(self.func.args)} argument(s) "
                f"but {len(node.args)} were given"
            )
        # Suspend while processing arguments (no recursion); always resume,
        # even if visiting an argument fails.
        self.suspend = True
        try:
            args = [self.visit(arg) for arg in node.args]
        finally:
            self.suspend = False
        # The basename is used for function-local variables
        basename = f"_inline_{self.func.name}_{self.numcalls}"
        # Assign all the function-local variables
        for argname, arg in zip(self.func.args, args):
            self.pre.append(
                ast.Assign(
                    targets=[ast.Name(id=f"{basename}_{argname}", ctx=ast.Store())],
                    value=arg,
                )
            )
        # Rewrite the lines of code of the function using the names defined
        # above
        rewriter = VarRewriter(f"{basename}_")
        for funcline in ast.parse(self.func.code).body:
            self.pre.append(rewriter.visit(funcline))
        # And rewrite the return expression
        return_expr = rewriter.visit(
            ast.parse(self.func.return_expr, mode="eval").body
        )
        self.pre.append(
            ast.Assign(
                targets=[ast.Name(id=basename, ctx=ast.Store())], value=return_expr
            )
        )
        # Finally we replace the function call with the output of the inlining
        self.numcalls += 1
        return ast.Name(id=basename, ctx=ast.Load())
def _next_inline_index(identifiers, funcname):
    """
    Return the first unused inline index for the function ``funcname``

    Inlining ``funcname`` creates names of the form ``_inline_<name>_<i>`` and
    ``_inline_<name>_<i>_<var>``. This returns one more than the largest such
    ``i`` among ``identifiers``, or 0 if there is none. Identifiers that only
    share the prefix (e.g. ``_inline_f_sq_0`` when ``funcname`` is ``f``) are
    ignored instead of being parsed as numbers.
    """
    pattern = re.compile(rf"_inline_{re.escape(funcname)}_(\d+)(?:_|$)")
    indices = [
        int(match.group(1))
        for match in map(pattern.match, identifiers)
        if match is not None
    ]
    return max(indices, default=-1) + 1


def substitute_abstract_code_functions(code, funcs):
    """
    Performs inline substitution of all the functions in the code

    Parameters
    ----------
    code : str
        The abstract code to make inline substitutions into.
    funcs : list, dict, set or other iterable of AbstractCodeFunction
        The function substitutions to use, note in the case of a dict, the
        keys are ignored and the function name is used.

    Returns
    -------
    code : str
        The code with inline substitutions performed.
    """
    if not hasattr(funcs, "values"):
        # Any non-mapping iterable is keyed by function name (a later
        # duplicate wins); for a mapping only the values are used.
        funcs = {func.name: func for func in funcs}
    code = deindent(code)
    lines = ast.parse(code, mode="exec").body
    # This is a slightly nasty hack, but basically we just check by looking at
    # the existing identifiers how many inline operations have already been
    # performed by previous calls to this function
    identifiers = get_identifiers(code)
    funcstarts = {
        func.name: _next_inline_index(identifiers, func.name)
        for func in funcs.values()
    }
    # Now we rewrite all the lines, replacing each line with a sequence of
    # lines performing the inlining
    newlines = []
    for line in lines:
        for func in funcs.values():
            rewriter = FunctionRewriter(func, funcstarts[func.name])
            line = rewriter.visit(line)
            newlines.extend(rewriter.pre)
            funcstarts[func.name] = rewriter.numcalls
        newlines.append(line)
    # Now we render to a code string
    renderer = NodeRenderer()
    newcode = "\n".join(renderer.render_node(line) for line in newlines)
    # We recurse until no changes in the code to ensure that all functions
    # are expanded if one function refers to another, etc.
    if newcode == code:
        return newcode
    return substitute_abstract_code_functions(newcode, funcs)