'''
Cython automatic extension builder/manager
Inspired by IPython's Cython cell magics, see:
https://github.com/ipython/ipython/blob/master/IPython/extensions/cythonmagic.py
'''
import glob
import importlib.util
import os
import shutil
import sys
import time
import hashlib
from distutils.core import Distribution, Extension
from distutils.command.build_ext import build_ext
import numpy
try:
import Cython
import Cython.Compiler as Cython_Compiler
import Cython.Build as Cython_Build
from Cython.Utils import get_cython_cache_dir as base_cython_cache_dir
except ImportError:
Cython = None
from brian2.utils.logger import std_silent, get_logger
from brian2.utils.stringtools import deindent
from brian2.utils.filelock import FileLock
from brian2.core.preferences import prefs
# Public interface of this module: the shared extension manager instance.
__all__ = ['cython_extension_manager']
# Module-level logger obtained from brian2's logging utilities.
logger = get_logger(__name__)
def get_cython_cache_dir():
    """
    Return the directory in which compiled Cython modules are stored.

    The value of the ``codegen.runtime.cython.cache_dir`` preference is used
    when it is set. Otherwise, if Cython is installed, a ``brian_extensions``
    sub-directory of Cython's own cache directory is used. If neither applies,
    ``None`` is returned.
    """
    configured_dir = prefs.codegen.runtime.cython.cache_dir
    # An explicit preference wins; without Cython there is no default to use.
    if configured_dir is not None or Cython is None:
        return configured_dir
    return os.path.join(base_cython_cache_dir(), 'brian_extensions')
def get_cython_extensions():
    """
    Return the file extensions of files that the Cython build process managed
    by this module may leave in the cache directory.

    A new ``set`` is returned on every call, so callers may modify it freely.
    NOTE(review): presumably used to recognise cache files, e.g. when clearing
    the cache; confirm against callers.
    """
    suffixes = ('.pyx', '.pxd', '.pyd', '.cpp', '.c', '.so', '.o', '.o.d',
                '.lock', '.dll', '.obj', '.exp', '.lib')
    return set(suffixes)
class CythonExtensionManager(object):
    """
    Build, cache and import Cython extension modules from code strings.

    Each module is compiled into the directory returned by
    `get_cython_cache_dir`. Unless an explicit name is given, the module name
    is derived from a hash of the code together with the Python, Cython and
    NumPy versions. An already compiled module file with that name is imported
    instead of being rebuilt. Modules imported during this session are
    additionally memoised in ``self._code_cache``.
    """

    def __init__(self):
        # Maps the key tuple built in `create_extension` to the imported module
        self._code_cache = {}

    def create_extension(self, code, force=False, name=None,
                         define_macros=None,
                         include_dirs=None,
                         library_dirs=None,
                         runtime_library_dirs=None,
                         extra_compile_args=None,
                         extra_link_args=None,
                         libraries=None,
                         compiler=None,
                         sources=None,
                         owner_name='',
                         ):
        """
        Compile (if necessary) and import a Cython module for ``code``.

        Parameters
        ----------
        code : str
            Cython source code of the module (deindented before use).
        force : bool, optional
            If ``True``, the current time is added to the hash key. This
            forces a fresh module name and therefore a new compilation.
        name : str, optional
            Explicit module name. By default a name is derived from the hash
            of the code and the relevant library versions.
        define_macros, include_dirs, library_dirs, runtime_library_dirs,
        extra_compile_args, extra_link_args, libraries : list, optional
            Passed on to the distutils `Extension` used for the build. The
            lists given by the caller are never modified.
        compiler : optional
            Compiler set on the distutils ``build_ext`` command.
        sources : list of str, optional
            Additional ``.pyx`` files, with optional ``.pxd`` headers next to
            them. They are copied to the cache directory and compiled
            together with the module.
        owner_name : str, optional
            Name of the object requesting the module, used only for logging.

        Returns
        -------
        module or None
            The imported extension module, or ``None`` if Cython reported a
            compilation error.

        Raises
        ------
        ImportError
            If Cython is not installed.
        IOError
            If the cache directory cannot be created.
        ValueError
            If one of ``sources`` does not have a ``.pyx`` ending.
        """
        if sources is None:
            sources = []
        self._simplify_paths()

        if Cython is None:
            raise ImportError('Cython is not available')

        code = deindent(code)

        # expanduser is a no-op for paths without a leading '~'
        lib_dir = os.path.expanduser(get_cython_cache_dir())
        try:
            os.makedirs(lib_dir, exist_ok=True)
        except OSError as ex:
            raise IOError("Couldn't create Cython cache directory '%s', try setting the "
                          "cache directly with prefs.codegen.runtime.cython.cache_dir." % lib_dir) from ex

        numpy_version = '.'.join(numpy.__version__.split('.')[:2])  # Only use major.minor version
        key = code, sys.version_info, sys.executable, Cython.__version__, numpy_version

        if force:
            # Force a new module name by adding the current time to the
            # key which is hashed to determine the module name.
            key += time.time(),

        if key in self._code_cache:
            return self._code_cache[key]

        if name is not None:
            module_name = name
        else:
            module_name = "_cython_magic_" + hashlib.md5(str(key).encode('utf-8')).hexdigest()
        if owner_name:
            logger.diagnostic('"{owner_name}" using Cython module "{module_name}"'.format(owner_name=owner_name,
                                                                                         module_name=module_name))

        module_path = os.path.join(lib_dir, module_name + self.so_ext)

        load_kwargs = dict(define_macros=define_macros,
                           include_dirs=include_dirs,
                           library_dirs=library_dirs,
                           extra_compile_args=extra_compile_args,
                           extra_link_args=extra_link_args,
                           libraries=libraries,
                           code=code,
                           lib_dir=lib_dir,
                           module_name=module_name,
                           runtime_library_dirs=runtime_library_dirs,
                           compiler=compiler,
                           key=key,
                           sources=sources)

        if prefs['codegen.runtime.cython.multiprocess_safe']:
            # Hold a lock file next to the module while checking for and
            # building it. Presumably this keeps concurrent processes from
            # compiling the same module at the same time.
            with FileLock(os.path.join(lib_dir, module_name + '.lock')):
                return self._load_module(module_path, **load_kwargs)
        return self._load_module(module_path, **load_kwargs)

    @property
    def so_ext(self):
        """The extension suffix for compiled modules."""
        try:
            return self._so_ext
        except AttributeError:
            # Computed lazily and only once, because it requires setting up
            # a distutils build command.
            self._so_ext = self._get_build_extension().get_ext_filename('')
            return self._so_ext

    def _clear_distutils_mkpath_cache(self):
        """clear distutils mkpath cache

        prevents distutils from skipping re-creation of dirs that have been removed
        """
        try:
            from distutils.dir_util import _path_created
        except ImportError:
            pass
        else:
            _path_created.clear()

    def _get_build_extension(self, compiler=None):
        """
        Return a finalized distutils ``build_ext`` command.

        The command reads the standard distutils configuration files, except
        for a ``setup.cfg`` in the current working directory. ``compiler`` is
        set on the command when given.
        """
        self._clear_distutils_mkpath_cache()
        dist = Distribution()
        config_files = dist.find_config_files()
        try:
            config_files.remove('setup.cfg')
        except ValueError:
            pass
        dist.parse_config_files(config_files)
        build_extension = build_ext(dist)
        if compiler is not None:
            build_extension.compiler = compiler
        build_extension.finalize_options()
        return build_extension

    def _load_module(self, module_path, define_macros, include_dirs, library_dirs,
                     extra_compile_args, extra_link_args, libraries, code,
                     lib_dir, module_name, runtime_library_dirs, compiler,
                     key, sources):
        """
        Import the module at ``module_path``, building it first if the file
        does not exist yet.

        Returns the imported module, which is also stored in
        ``self._code_cache`` under ``key``. Returns ``None`` if Cython failed
        to compile ``code``.
        """
        if not os.path.isfile(module_path):
            built = self._build_module(define_macros=define_macros,
                                       include_dirs=include_dirs,
                                       library_dirs=library_dirs,
                                       extra_compile_args=extra_compile_args,
                                       extra_link_args=extra_link_args,
                                       libraries=libraries,
                                       code=code,
                                       lib_dir=lib_dir,
                                       module_name=module_name,
                                       runtime_library_dirs=runtime_library_dirs,
                                       compiler=compiler,
                                       sources=sources)
            if not built:
                return None

        module = self._import_module(module_path, module_name, lib_dir)
        self._code_cache[key] = module
        return module

    def _build_module(self, define_macros, include_dirs, library_dirs,
                      extra_compile_args, extra_link_args, libraries, code,
                      lib_dir, module_name, runtime_library_dirs, compiler,
                      sources):
        """
        Write ``code`` to ``<lib_dir>/<module_name>.pyx`` and build it, plus
        any additional ``sources``, as an extension module in ``lib_dir``.

        Returns ``True`` on success and ``False`` if Cython raised a
        ``CompileError``.
        """
        # Work on copies. The include directories are extended below, and the
        # caller's lists must not grow with every build.
        define_macros = list(define_macros or [])
        c_include_dirs = list(include_dirs or [])
        library_dirs = list(library_dirs or [])
        extra_compile_args = list(extra_compile_args or [])
        extra_link_args = list(extra_link_args or [])
        libraries = list(libraries or [])

        if 'numpy' in code:
            c_include_dirs.append(numpy.get_include())
        # TODO: We should probably have a special folder just for header
        # files that are shared between different codegen targets
        # (function-level import, presumably to avoid an import cycle)
        import brian2.synapses as synapses
        c_include_dirs.append(os.path.dirname(synapses.__file__))

        pyx_file = os.path.join(lib_dir, module_name + '.pyx')
        with open(pyx_file, 'w') as f:
            f.write(code)

        final_sources = self._copy_additional_sources(sources or [], lib_dir)

        extension = Extension(
            name=module_name,
            sources=[pyx_file],
            define_macros=define_macros,
            include_dirs=c_include_dirs,
            library_dirs=library_dirs,
            runtime_library_dirs=runtime_library_dirs,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
            libraries=libraries,
            language='c++')
        build_extension = self._get_build_extension(compiler=compiler)
        opts = dict(
            quiet=True,
            annotate=False,
            force=True,
        )
        try:
            # suppresses the output on stdout
            with std_silent():
                build_extension.extensions = Cython_Build.cythonize([extension] + final_sources, **opts)
                build_extension.build_temp = os.path.dirname(pyx_file)
                build_extension.build_lib = lib_dir
                build_extension.run()
                if prefs['codegen.runtime.cython.delete_source_files']:
                    # we can delete the source files to save disk space
                    self._delete_source_files(lib_dir, module_name, pyx_file)
        except Cython_Compiler.Errors.CompileError:
            return False
        return True

    @staticmethod
    def _copy_additional_sources(sources, lib_dir):
        """
        Copy additional ``.pyx`` files (and a ``.pxd`` header with the same
        base name next to them, if present) to ``lib_dir``.

        Returns the paths of the copied ``.pyx`` files. Raises ``ValueError``
        for a source without a ``.pyx`` ending.
        """
        final_sources = []
        for source in sources:
            if not source.lower().endswith('.pyx'):
                raise ValueError('Additional Cython source files need to '
                                 'have an .pyx ending')
            source_dir, source_basename = os.path.split(source)
            target = os.path.join(lib_dir, source_basename)
            shutil.copyfile(source, target)
            header_name = os.path.splitext(source_basename)[0] + '.pxd'
            header_path = os.path.join(source_dir, header_name)
            if os.path.exists(header_path):
                shutil.copyfile(header_path, os.path.join(lib_dir, header_name))
            final_sources.append(target)
        return final_sources

    @staticmethod
    def _delete_source_files(lib_dir, module_name, pyx_file):
        """
        Best-effort removal of the generated ``.pyx``/``.cpp`` files and the
        intermediate build files of ``module_name``. Failures are only logged.
        """
        cpp_file = os.path.join(lib_dir, module_name + '.cpp')
        try:
            os.remove(pyx_file)
            os.remove(cpp_file)
            # NOTE(review): this pattern presumably targets the files that
            # distutils writes below build_temp (= lib_dir) at the absolute
            # source path with its first character stripped. The [1:] only
            # removes a POSIX leading '/'; confirm the behaviour on Windows.
            temp_dir = os.path.join(lib_dir, os.path.dirname(pyx_file)[1:], module_name + '.*')
            for fname in glob.glob(temp_dir):
                os.remove(fname)
        except (OSError, IOError) as ex:
            logger.debug('Deleting Cython source files failed with error: %s' % str(ex))

    @staticmethod
    def _import_module(module_path, module_name, lib_dir):
        """
        Import and return the module file at ``module_path`` as ``module_name``.
        """
        # Temporarily insert the Cython directory to the Python path so that
        # code importing from an external module that was declared via
        # sources works
        sys.path.insert(0, lib_dir)
        try:
            spec = importlib.util.spec_from_file_location(module_name, module_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        finally:
            # Restore sys.path even if executing the module failed
            sys.path.pop(0)
        return module

    def _simplify_paths(self):
        """
        Remove duplicate entries from the ``lib`` and ``include`` environment
        variables, if they are set.
        """
        # NOTE(review): these look like the variables read by the MSVC
        # toolchain; presumably de-duplicated to keep them from growing too
        # long. Confirm.
        if 'lib' in os.environ:
            os.environ['lib'] = simplify_path_env_var(os.environ['lib'])
        if 'include' in os.environ:
            os.environ['include'] = simplify_path_env_var(os.environ['include'])
# Module-level manager instance exported via __all__; its in-process cache of
# imported modules is therefore shared by everyone using this module.
cython_extension_manager = CythonExtensionManager()
def simplify_path_env_var(path):
    """
    Remove duplicate entries from an ``os.pathsep``-separated path list.

    The first occurrence of each entry is kept and the original order is
    preserved. For example, ``'a:b:a'`` becomes ``'a:b'`` on POSIX.
    """
    seen = set()
    kept = []
    for entry in path.split(os.pathsep):
        if entry in seen:
            continue
        seen.add(entry)
        kept.append(entry)
    return os.pathsep.join(kept)