# Source code for brian2.units.unitsafefunctions

"""
Unit-aware replacements for numpy functions.
"""
from functools import wraps

import pkg_resources
import numpy as np

from .fundamentalunits import (Quantity, wrap_function_dimensionless,
                               wrap_function_remove_dimensions,
                               fail_for_dimension_mismatch, is_dimensionless,
                               DIMENSIONLESS)

# Public names exported by ``from ... import *``: the unit-aware
# replacements for numpy functions defined in this module.
__all__ = [
    # logarithm / exponential
    'log', 'log10', 'exp',
    # trigonometric functions and their inverses
    'sin', 'cos', 'tan',
    'arcsin', 'arccos', 'arctan',
    # hyperbolic functions and their inverses
    'sinh', 'cosh', 'tanh',
    'arcsinh', 'arccosh', 'arctanh',
    # functions dispatching to Quantity methods
    'diagonal', 'ravel', 'trace', 'dot',
    # unit-checked selection
    'where',
    # array constructors
    'ones_like', 'zeros_like',
    'arange', 'linspace',
]

def where(condition, *args, **kwds):  # pylint: disable=C0111
    # Unit-aware np.where (numpy's docstring is attached to this function
    # right after its definition).
    #
    # Only the three-argument form ``where(condition, x, y)`` needs unit
    # handling. With just a condition (or an invalid number of arguments)
    # numpy does all the work and raises the appropriate error itself.
    if len(args) != 2:
        return np.where(condition, *args, **kwds)

    x, y = args
    # x and y must share their dimensions
    fail_for_dimension_mismatch(x, y,
                                'x and y need to have the same dimensions')
    if is_dimensionless(x):
        return np.where(condition, *args, **kwds)

    # Both operands have the same units, so x's dimensions describe the
    # result; numpy operates on the plain values.
    plain_result = np.where(condition, np.asarray(x), np.asarray(y))
    return Quantity.with_dimensions(plain_result, x.dimensions)
# The unit-aware ``where`` presents numpy's own documentation.
where.__doc__ = np.where.__doc__

# Functions that work on dimensionless quantities only
(sin, sinh, arcsin, arcsinh,
 cos, cosh, arccos, arccosh,
 tan, tanh, arctan, arctanh) = map(wrap_function_dimensionless,
                                   (np.sin, np.sinh, np.arcsin, np.arcsinh,
                                    np.cos, np.cosh, np.arccos, np.arccosh,
                                    np.tan, np.tanh, np.arctan, np.arctanh))
log, log10, exp = map(wrap_function_dimensionless,
                      (np.log, np.log10, np.exp))

# Array constructors wrapped with ``wrap_function_remove_dimensions``
ones_like, zeros_like = map(wrap_function_remove_dimensions,
                            (np.ones_like, np.zeros_like))
def wrap_function_to_method(func):
    '''
    Wrap a numpy function so that, when its first argument is a `Quantity`,
    the `Quantity` method of the same name is called instead.

    For any other first argument ``func`` itself is called. In both cases all
    remaining positional and keyword arguments are forwarded unchanged.
    '''
    method_name = func.__name__

    # ``wraps`` copies the name and docstring of ``func`` onto the wrapper
    @wraps(func)
    def method_dispatcher(x, *args, **kwds):  # pylint: disable=C0111
        if not isinstance(x, Quantity):
            # no need to wrap anything
            return func(x, *args, **kwds)
        return getattr(x, method_name)(*args, **kwds)

    return method_dispatcher
@wraps(np.arange)
def arange(*args, **kwargs):
    # Unit-aware version of np.arange (its docstring is copied from numpy by
    # ``@wraps``). ``start``, ``stop`` and ``step`` can be given positionally
    # (``[start,] stop[, step]``) or as keywords. start/stop and stop/step
    # must have matching units, and the result carries the units of ``stop``.
    # Raises TypeError for a missing ``stop``, an argument given both
    # positionally and by keyword, or more than three positional arguments.
    # Any remaining keyword arguments (e.g. ``dtype``) are passed on to numpy,
    # which validates them.

    # default values
    start = kwargs.pop('start', 0)
    step = kwargs.pop('step', 1)
    stop = kwargs.pop('stop', None)
    if len(args) == 1:
        if stop is not None:
            raise TypeError('Duplicate definition of "stop"')
        stop = args[0]
    elif len(args) == 2:
        if start != 0:
            raise TypeError('Duplicate definition of "start"')
        if stop is not None:
            raise TypeError('Duplicate definition of "stop"')
        start, stop = args
    elif len(args) == 3:
        if start != 0:
            raise TypeError('Duplicate definition of "start"')
        if stop is not None:
            raise TypeError('Duplicate definition of "stop"')
        if step != 1:
            raise TypeError('Duplicate definition of "step"')
        start, stop, step = args
    elif len(args) > 3:
        raise TypeError('Need between 1 and 3 non-keyword arguments')
    if stop is None:
        raise TypeError('Missing stop argument.')
    fail_for_dimension_mismatch(start, stop,
                                error_message=('Start value {start} and stop '
                                               'value {stop} have to have the '
                                               'same units.'),
                                start=start, stop=stop)
    fail_for_dimension_mismatch(stop, step,
                                error_message=('Stop value {stop} and step '
                                               'value {step} have to have the '
                                               'same units.'),
                                stop=stop, step=step)
    dim = getattr(stop, 'dim', DIMENSIONLESS)
    # Pass start/stop/step positionally. Since numpy 2.0 the first argument
    # of np.arange is positional-only, so ``np.arange(start=...)`` raises a
    # TypeError there. The positional form works with every numpy version.
    values = np.arange(np.asarray(start), np.asarray(stop), np.asarray(step),
                       **kwargs)
    return Quantity(values, dim=dim, copy=False)
@wraps(np.linspace)
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None):
    # Unit-aware version of np.linspace (its docstring is copied from numpy by
    # ``@wraps``). ``start`` and ``stop`` must have the same units; the
    # samples carry the units of ``start``. With ``retstep=True`` a
    # ``(samples, step)`` pair is returned, both with those units, mirroring
    # numpy's return value.
    # Raises TypeError if ``dtype`` is requested on numpy < 1.9.0.
    fail_for_dimension_mismatch(start, stop,
                                error_message=('Start value {start} and stop '
                                               'value {stop} have to have the '
                                               'same units.'),
                                start=start, stop=stop)
    dim = getattr(start, 'dim', DIMENSIONLESS)
    linspace_kwds = dict(num=num, endpoint=endpoint, retstep=retstep)
    if dtype is not None:
        # Only check the numpy version when the argument is actually used.
        if (pkg_resources.parse_version(np.__version__) <
                pkg_resources.parse_version('1.9.0')):
            raise TypeError('The "dtype" argument needs numpy >= 1.9.0')
        linspace_kwds['dtype'] = dtype
    result = np.linspace(np.asarray(start), np.asarray(stop), **linspace_kwds)
    if retstep:
        # numpy returns a (samples, step) tuple here. Wrap each part
        # separately instead of turning the tuple into one (ragged) array.
        samples, step = result
        return (Quantity(samples, dim=dim, copy=False),
                Quantity(step, dim=dim))
    return Quantity(result, dim=dim, copy=False)
# these functions discard subclass info -- maybe a bug in numpy?
ravel = wrap_function_to_method(np.ravel)
diagonal = wrap_function_to_method(np.diagonal)
trace = wrap_function_to_method(np.trace)
dot = wrap_function_to_method(np.dot)

# This is a very minor detail: setting the __module__ attribute allows the
# automatic reference doc generation mechanism to attribute the functions to
# this module. Maybe also helpful for IDEs and other code introspection tools.
# Every wrapped or ``@wraps``-decorated public function is listed here,
# including log10, ones_like and zeros_like. ``where`` is a plain ``def`` in
# this module, so its __module__ is already correct.
for _func in (sin, sinh, arcsin, arcsinh,
              cos, cosh, arccos, arccosh,
              tan, tanh, arctan, arctanh,
              log, log10, exp,
              ones_like, zeros_like,
              ravel, diagonal, trace, dot,
              arange, linspace):
    _func.__module__ = __name__
del _func  # keep the loop variable out of the module namespace
def setup():
    '''
    Setup function for doctests (used by nosetest). We do not want to test
    this module's docstrings as they are inherited from numpy.

    Raises
    ------
    unittest.SkipTest
        Always, so the test runner skips this module's doctests.
    '''
    # Both nose and pytest recognise the standard-library ``SkipTest`` as a
    # skip. Importing it from ``unittest`` keeps this working even where nose
    # is not installed (or no longer imports on current Python versions).
    from unittest import SkipTest
    raise SkipTest('Do not test numpy docstrings')