import ast
from collections import namedtuple
from brian2.utils.stringtools import deindent
__all__ = ["abstract_code_dependencies"]
[docs]
def get_read_write_funcs(parsed_code):
allids = set()
read = set()
write = set()
funcs = set()
for node in ast.walk(parsed_code):
if node.__class__ is ast.Name:
allids.add(node.id)
if node.ctx.__class__ is ast.Store:
write.add(node.id)
elif node.ctx.__class__ is ast.Load:
read.add(node.id)
else:
raise SyntaxError
elif node.__class__ is ast.Call:
funcs.add(node.func.id)
read = read - funcs
# check that there's no funky stuff going on with functions
if funcs.intersection(write):
raise SyntaxError("Cannot assign to functions in abstract code")
return allids, read, write, funcs
[docs]
def abstract_code_dependencies(code, known_vars=None, known_funcs=None):
"""
Analyses identifiers used in abstract code blocks
Parameters
----------
code : str
The abstract code block.
known_vars : set
The set of known variable names.
known_funcs : set
The set of known function names.
Returns
-------
results : namedtuple with the following fields
``all``
The set of all identifiers that appear in this code block,
including functions.
``read``
The set of values that are read, excluding functions.
``write``
The set of all values that are written to.
``funcs``
The set of all function names.
``known_all``
The set of all identifiers that appear in this code block and
are known.
``known_read``
The set of known values that are read, excluding functions.
``known_write``
The set of known values that are written to.
``known_funcs``
The set of known functions that are used.
``unknown_read``
The set of all unknown variables whose values are read. Equal
to ``read-known_vars``.
``unknown_write``
The set of all unknown variables written to. Equal to
``write-known_vars``.
``unknown_funcs``
The set of all unknown function names, equal to
``funcs-known_funcs``.
``undefined_read``
The set of all unknown variables whose values are read before they
are written to. If this set is nonempty it usually indicates an
error, since a variable that is read should either have been
defined in the code block (in which case it will appear in
``newly_defined``) or already be known.
``newly_defined``
The set of all variable names which are newly defined in this
abstract code block.
"""
if known_vars is None:
known_vars = set()
if known_funcs is None:
known_funcs = set()
if not isinstance(known_vars, set):
known_vars = set(known_vars)
if not isinstance(known_funcs, set):
known_funcs = set(known_funcs)
code = deindent(code, docstring=True)
parsed_code = ast.parse(code, mode="exec")
# Get the list of all variables that are read from and written to,
# ignoring the order
allids, read, write, funcs = get_read_write_funcs(parsed_code)
# Now check if there are any values that are unknown and read before
# they are written to
defined = known_vars.copy()
newly_defined = set()
undefined_read = set()
for line in parsed_code.body:
_, cur_read, cur_write, _ = get_read_write_funcs(line)
undef = cur_read - defined
undefined_read |= undef
newly_defined |= (cur_write - defined) - undefined_read
defined |= cur_write
# Return the results as a named tuple
results = dict(
all=allids,
read=read,
write=write,
funcs=funcs,
known_all=allids.intersection(known_vars.union(known_funcs)),
known_read=read.intersection(known_vars),
known_write=write.intersection(known_vars),
known_funcs=funcs.intersection(known_funcs),
unknown_read=read - known_vars,
unknown_write=write - known_vars,
unknown_funcs=funcs - known_funcs,
undefined_read=undefined_read,
newly_defined=newly_defined,
)
return namedtuple("AbstractCodeDependencies", list(results.keys()))(**results)