Source code for neuralnet_pytorch.utils.layer_utils

import functools
import numbers

import numpy as np
import sympy as sp
from torch._six import container_abcs

from . import root_logger

__all__ = ['validate', 'no_dim_change_op', 'add_simple_repr', 'add_custom_repr', 'deprecated', 'get_non_none']


def _make_input_shape(m, n):
    """
    Build a parser that expands a single dimension into a symbolic shape.

    :param m: number of leading symbolic dimensions, named ``b0 ... b{m-1}``.
    :param n: number of trailing symbolic dimensions, named ``x0 ... x{n-1}``.
    :return: a function ``parse(x)``. If ``x`` is iterable, it is returned
        unchanged. Otherwise ``parse`` returns
        ``(b0, ..., b{m-1}, x, x0, ..., x{n-1})``, where every generated
        symbol carries the sympy ``integer=True`` assumption.
    """
    def parse(x):
        # Already a full shape (any iterable): pass it through untouched.
        if isinstance(x, container_abcs.Iterable):
            return x

        # Both symbol groups must carry the same ``integer`` assumption;
        # a misspelled keyword would silently leave the leading symbols untyped.
        leading = sp.symbols('b:{}'.format(m), integer=True)
        trailing = sp.symbols('x:{}'.format(n), integer=True)
        return leading + (x,) + trailing
    return parse


def validate(func):
    """
    A decorator to make sure output shape is a tuple of ``int`` s.

    The value returned by the decorated method is normalized as follows:

    * ``None`` is returned unchanged.
    * A scalar number is converted with :func:`int`; ``NaN`` becomes ``None``.
    * Any other iterable becomes a tuple. In that tuple, ``None`` and ``NaN``
      entries become ``None``, other numbers are converted with :func:`int`,
      and non-numeric entries (e.g. symbolic dimensions) are kept as-is.

    The wrapper preserves the decorated method's ``__name__`` and ``__doc__``
    (via :func:`functools.wraps`), which matters when it is turned into a
    property.

    :param func: a method taking only ``self`` and returning a shape.
    :return: the wrapped method.
    """
    def _normalize_dim(x):
        # NaN is treated exactly like ``None``: an unknown dimension.
        if x is None or (isinstance(x, numbers.Number) and np.isnan(x)):
            return None
        if isinstance(x, numbers.Number):
            return int(x)
        return x

    @functools.wraps(func)
    def wrapper(self):
        shape = func(self)
        if shape is None:
            return None
        if isinstance(shape, numbers.Number):
            # Same normalization as sequence entries, so a NaN scalar no
            # longer crashes inside ``int()``.
            return _normalize_dim(shape)
        return tuple(_normalize_dim(x) for x in shape)

    return wrapper


def no_dim_change_op(cls):
    """
    Class decorator for ops that leave the tensor shape untouched.

    Replaces :attr:`~neuralnet_pytorch.layers._LayerMethod.output_shape` on
    ``cls`` with a property that reports ``self.input_shape`` as a tuple,
    normalized by :func:`validate`, or ``None`` when the input shape is
    ``None``.

    :param cls: a subclass of :class:`~neuralnet_pytorch.layers.Module`.
    :return: ``cls`` itself, with ``output_shape`` replaced.
    """
    def output_shape(self):
        if self.input_shape is None:
            return None
        return tuple(self.input_shape)

    setattr(cls, 'output_shape', property(validate(output_shape)))
    return cls


def add_simple_repr(cls):
    """
    Class decorator that appends the output shape to the inherited repr.

    The new ``repr`` is the parent class's ``repr`` of the instance, followed
    by ``' -> '`` and the instance's ``output_shape``.

    :param cls: a subclass of :class:`~neuralnet_pytorch.layers.Module`.
    :return: ``cls`` itself, with ``__repr__`` replaced.
    """
    def __repr__(self):
        # Resolve the repr from the class *after* ``cls`` in the MRO, so the
        # replaced method does not recurse into itself.
        parent_repr = super(cls, self).__repr__()
        suffix = ' -> {}'.format(self.output_shape)
        return parent_repr + suffix

    cls.__repr__ = __repr__
    return cls


def add_custom_repr(cls):
    """
    Class decorator that installs a ``Name(extra) -> shape`` repr.

    The decorated class must define ``extra_repr()``. Its result is placed
    inside the parentheses after the instance's class name, followed by
    ``' -> '`` and the instance's ``output_shape``.

    :param cls: a subclass of :class:`~neuralnet_pytorch.layers.Module`.
    :return: ``cls`` itself, with ``__repr__`` replaced.
    """
    def __repr__(self):
        class_name = self.__class__.__name__
        details = self.extra_repr()
        return class_name + '({}) -> {}'.format(details, self.output_shape)

    cls.__repr__ = __repr__
    return cls


def deprecated(new_func, version):
    """
    Decorator factory that marks a function as deprecated.

    Every call to the decorated function logs a warning through
    ``root_logger`` naming the replacement, then forwards the call with the
    same arguments and returns its result.

    :param new_func: the function to use instead; only its ``__name__`` is
        used, in the warning message.
    :param version: the version in which the deprecated function will be
        removed; interpolated into the warning message.
    :return: a decorator that wraps the deprecated function. The wrapper keeps
        the original ``__name__`` and ``__doc__``.
    """
    def _deprecated(func):
        """prints out a deprecation warning"""
        @functools.wraps(func)
        def func_wrapper(*args, **kwargs):
            # ``stack_info`` records where the deprecated function was called
            # from. ``exc_info`` is not used: no exception is being handled
            # here, so it would only attach an empty "NoneType: None" trace.
            root_logger.warning(
                '%s is deprecated and will be removed in version %s. Use %s instead.',
                func.__name__, version, new_func.__name__,
                stack_info=True,
            )
            return func(*args, **kwargs)
        return func_wrapper
    return _deprecated


def get_non_none(array):
    """
    Return the first element of ``array`` that is not ``None``.

    Iteration stops at the first such element, so any remaining items of a
    one-shot iterator are left unconsumed.

    :param array: any iterable; non-iterables fail the ``assert``.
    :return: the first non-``None`` element, or ``None`` if every element is
        ``None`` or ``array`` is empty.
    """
    assert isinstance(array, container_abcs.Iterable)
    for candidate in array:
        if candidate is not None:
            return candidate
    return None