"""Contains the FluentModule class"""

import pytypeutils as tus
import typing
import torch
import torch.nn as nn
import operator
from functools import reduce

class Reshape(nn.Module):
    """Reshapes the input to match the given shape, preserving the batch
    dimension.

    The first dimension of the input is assumed to be the batch dimension and
    is kept as-is; the remaining dimensions are reshaped to ``shape``. A view
    is returned whenever the input's memory layout allows it. Otherwise the
    data is copied, for example after a :class:`Transpose`, which yields a
    non-contiguous tensor on which ``Tensor.view`` would raise.

    :Example:

    .. code-block:: python

        import torchluent
        import torch

        a = torchluent.Reshape(28*28)
        data = torch.randn(5, 28, 28)
        reshaped = a(data)
        print(reshaped.shape)  # torch.Size([5, 784])

    :ivar tuple[int] shape: the new shape for the input, excluding the batch
        dimension
    """

    def __init__(self, *args):
        super().__init__()
        tus.check(args=(args, tuple))
        tus.check_listlike(args=(args, int, (1, None)))
        self.shape = args

    def forward(self, x):
        """Reshapes x to ``(x.shape[0], *self.shape)``."""
        # reshape rather than view: identical result for contiguous inputs, but
        # also works on non-contiguous ones (e.g. the output of Transpose)
        return x.reshape(x.shape[0], *self.shape)

    def extra_repr(self):
        return ', '.join(str(i) for i in self.shape)
class Transpose(nn.Module):
    """Transposes two dimensions. Does not affect the batch dimension.

    Dimension indices are relative to the features, i.e. dimension 0 is the
    first dimension after the batch dimension.

    :Example:

    .. code-block:: python

        import torchluent
        import torch

        transposer = torchluent.Transpose(0, 1)
        data = torch.randn(5, 100, 50)
        newdata = transposer(data)
        print(newdata.shape)  # torch.Size([5, 50, 100])

    :ivar int dim1: the first dimension to transpose
    :ivar int dim2: the second dimension to transpose
    """

    def __init__(self, dim1: int, dim2: int):
        super().__init__()
        tus.check(dim1=(dim1, int), dim2=(dim2, int))
        if dim1 < 0:
            raise ValueError(f'dim1={dim1} must be nonnegative')
        if dim2 < 0:
            raise ValueError(f'dim2={dim2} must be nonnegative')
        self.dim1 = dim1
        self.dim2 = dim2

    def forward(self, x):
        """Swaps feature dimensions dim1 and dim2 of x."""
        # +1 skips over the batch dimension at index 0
        return x.transpose(self.dim1 + 1, self.dim2 + 1)

    def extra_repr(self):
        return f'{self.dim1}, {self.dim2}'
class InitListModule(nn.Module):
    """Initializes a list of states, optionally with the state it is passed.

    The output is the pair ``(x, states)``, which is the format consumed by
    the wrapped/state-saving modules.

    :ivar bool include_first: True to include x in the list, False to make an
        empty list.
    """

    def __init__(self, include_first: bool):
        super().__init__()
        tus.check(include_first=(include_first, bool))
        self.include_first = include_first

    def forward(self, x):
        """Returns ``(x, [x])`` or ``(x, [])`` depending on include_first.

        A fresh list is created on every call, so states from different
        forward passes never share storage.
        """
        # Both branches must be parenthesized; without them the conditional
        # binds only to the middle element and a 3-tuple is returned.
        return (x, [x]) if self.include_first else (x, [])

    def extra_repr(self):
        return f'include_first={self.include_first}'
class WrapModule(nn.Module):
    """Adapts a single-tensor module to the ``(x, arr)`` pair format.

    Only the tensor half of the pair (index 0) is fed through the child; the
    list of saved states (index 1) is handed through untouched.

    :ivar nn.Module child: the wrapped module
    """

    def __init__(self, child: nn.Module):
        super().__init__()
        tus.check(child=(child, nn.Module))
        self.child = child

    def forward(self, x_and_arr):
        transformed = self.child(x_and_arr[0])
        saved_states = x_and_arr[1]
        return transformed, saved_states
class SaveStateModule(nn.Module):
    """Appends the current tensor to the list of saved states.

    Expects the ``(x, arr)`` pair format; ``arr`` is mutated in place and the
    very same pair object is returned.
    """

    def forward(self, x_and_arr):
        current, saved = x_and_arr[0], x_and_arr[1]
        saved.append(current)
        return x_and_arr
class StrippingModule(nn.Module):
    """Runs a child that yields ``(x, arr)`` and keeps only ``x``.

    :ivar nn.Module child: the child whose state list is being discarded
    """

    def __init__(self, child: nn.Module):
        super().__init__()
        tus.check(child=(child, nn.Module))
        self.child = child

    def forward(self, x):
        child_output = self.child(x)
        return child_output[0]
class FluentModule:
    """This constructs torch modules in a fluent-style interface.

    :Example:

    .. code-block:: python

        from torchluent import FluentModule

        net = (
            FluentModule(28*28)
            .dense(128)
            .operator('ReLU')
            .dense(10)
            .operator('ReLU')
            .build()
        )

    .. note::

        This module's shape and all shape arguments are in practice prefixed
        by a batch dimension. The batch dimension is not altered by any of
        these calls, including reshaping, unless otherwise specified.

    :ivar list[nn.Module] sequence: the actual sequence of modules that we
        have constructed so far.
    :ivar tuple[int] shape: the current feature shape (excluding the batch
        dimension)
    :ivar bool is_verbose: if we are currently outputting each function call
        and the corresponding effects
    :ivar bool wrapped: if we are currently storing a list of hidden states
    """

    def __init__(self,
                 shape: typing.Union[int, typing.Sequence[int]],
                 assume_wrapped: bool = False):
        """
        :param shape: the feature shape of the input, excluding the batch
            dimension. A bare int ``n`` is accepted and treated as ``(n,)``,
            as in ``FluentModule(28*28)``.
        :param assume_wrapped: True if the input already has the ``(x, arr)``
            form produced by :meth:`wrap`
        :raises ValueError: if any entry of shape is not positive
        """
        # bool is a subclass of int but is never a meaningful feature count,
        # so it is left for the type check below to reject
        if isinstance(shape, int) and not isinstance(shape, bool):
            shape = (shape,)
        tus.check(shape=(shape, (list, tuple)),
                  assume_wrapped=(assume_wrapped, bool))
        tus.check_listlike(shape=(shape, int, (1, None)))
        for features in shape:
            if features <= 0:
                raise ValueError(f'shape={shape} must be positive')
        self.shape = tuple(shape)
        self.sequence = []
        self.is_verbose = False
        self.wrapped = assume_wrapped

    def verbose(self) -> 'FluentModule':
        """Turns on verbose mode, which causes this to output every function
        call and the resulting shape.

        :returns: self
        :rtype: FluentModule
        """
        self.is_verbose = True
        print(f'  {self.shape}')
        return self

    def silent(self) -> 'FluentModule':
        """Disables verbose mode

        :returns: self
        :rtype: FluentModule
        """
        self.is_verbose = False
        return self

    def wrap(self, with_input: bool = False) -> 'FluentModule':
        """Changes the output to the form (x, arr) where arr is a list of
        states stored in locations specified with save_state()

        :param with_input: if True we immediately save_state()
        :type with_input: bool
        :returns: self
        :rtype: FluentModule
        :raises ValueError: if wrap() was already called (or the input was
            assumed to be wrapped)
        """
        tus.check(with_input=(with_input, bool))
        if self.wrapped:
            raise ValueError('already wrapped')
        self.wrapped = True
        self.sequence.append(InitListModule(with_input))
        return self

    def _wrap(self, mod):
        """Adapts mod to the (x, arr) format when we are wrapped."""
        return WrapModule(mod) if self.wrapped else mod

    def save_state(self):
        """Stores the current state into the list for the result. Requires
        that wrap() has already been called.

        :returns: self
        :rtype: FluentModule
        :raises ValueError: if wrap() has not been called
        """
        if not self.wrapped:
            raise ValueError('cannot save_state() without wrap()')
        self.sequence.append(SaveStateModule())
        return self

    def dense(self, out_features: int, bias: bool = True) -> 'FluentModule':
        """A dense layer, also known as a linear layer or a fully connected
        layer. A dense layer requires that this already be in flattened form,
        i.e., len(self.shape) == 1.

        :param out_features: the number of neurons to project to
        :param bias: determines if a bias (additive) term is applied to each
            of the output features
        :type out_features: int
        :type bias: bool
        :returns: self
        :rtype: FluentModule
        :raises ValueError: if out_features is not positive or the current
            shape is not flat
        """
        tus.check(out_features=(out_features, int), bias=(bias, bool))
        if out_features <= 0:
            raise ValueError(f'out_features={out_features} must be positive')
        if len(self.shape) != 1:
            raise ValueError(
                f'cannot perform operation {self.shape} -> dense -> '
                + f'{out_features} (current shape is not flat). consider '
                + 'calling flatten() first')
        self.sequence.append(
            self._wrap(nn.Linear(self.shape[0], out_features, bias)))
        self.shape = (out_features,)
        if self.is_verbose:
            print(f'  Linear -> {self.shape}')
        return self

    def reshape(self, shape: typing.Tuple[int, ...]) -> 'FluentModule':
        """Reshapes the data to the specified shape. Must correspond to the
        same total number of features.

        .. note::

            The batch dimension is preserved.

        :param shape: the new shape for the data
        :type shape: tuple[int]
        :returns: self
        :rtype: FluentModule
        :raises ValueError: if an entry is not positive or the total number
            of features differs from the current shape
        """
        tus.check(shape=(shape, (list, tuple)))
        tus.check_listlike(shape=(shape, int, (1, None)))
        for features in shape:
            if features <= 0:
                raise ValueError(f'shape={shape} must be positive')
        old_num_features = reduce(operator.mul, self.shape)
        new_num_features = reduce(operator.mul, shape)
        if old_num_features != new_num_features:
            raise ValueError(
                f'cannot view {self.shape} as {shape}: expected '
                + f'{old_num_features} but got {new_num_features}')
        self.sequence.append(self._wrap(Reshape(*shape)))
        self.shape = tuple(shape)
        if self.is_verbose:
            print(f'  Reshape -> {self.shape}')
        return self

    def flatten(self) -> 'FluentModule':
        """Reshapes this such that the data has only one dimension.

        .. note::

            The batch dimension is preserved.

        :returns: self
        :rtype: FluentModule
        """
        return self.reshape((reduce(operator.mul, self.shape),))

    def transpose(self, dim1: int, dim2: int) -> 'FluentModule':
        """Transposes the two specified dimensions, where dimension 0 is the
        first dimension after the batch dimension (i.e., really index 0 in
        self.shape).

        :Example:

        .. code-block:: python

            from torchluent import FluentModule
            import torch

            net = FluentModule((1, 12, 24)).transpose(0, 2).build()
            inp = torch.randn((5, 1, 12, 24))
            out = net(inp)
            print(out.shape)  # torch.Size([5, 24, 12, 1])

        :returns: self
        :rtype: FluentModule
        :raises ValueError: if either dimension is out of range for the
            current shape
        """
        tus.check(dim1=(dim1, int), dim2=(dim2, int))
        if not 0 <= dim1 < len(self.shape) or not 0 <= dim2 < len(self.shape):
            raise ValueError(f'cannot transpose {dim1} and {dim2} for '
                             + f'shape {self.shape}')
        self.sequence.append(self._wrap(Transpose(dim1, dim2)))
        new_shape = list(self.shape)
        new_shape[dim1], new_shape[dim2] = new_shape[dim2], new_shape[dim1]
        # keep shape a tuple, as documented and as every other method stores it
        self.shape = tuple(new_shape)
        if self.is_verbose:
            print(f'  Transpose[{dim1}, {dim2}] -> {self.shape}')
        return self

    def operator(self, oper, *args, **kwargs) -> 'FluentModule':
        """An operator is some operation which does not change the shape of
        the data. The operator may be specified as a string, in which case it
        should name a module in torch.nn, or it may be the module class
        itself, not yet instantiated (i.e. 'ReLU' or nn.ReLU but not
        nn.ReLU()).

        :Example:

        .. code-block:: python

            from torchluent import FluentModule

            net = (
                FluentModule(28*28)
                .dense(10)
                .operator('LeakyReLU', negative_slope=0.05)
                .build()
            )

        :param oper: the name of the operator or a callable which returns one
        :param args: passed to the operator
        :param kwargs: passed to the operator
        :returns: self
        :rtype: FluentModule
        :raises ValueError: if oper is a string that is not an attribute of
            torch.nn
        """
        if isinstance(oper, str):
            if not hasattr(nn, oper):
                raise ValueError(f'torch.nn has no attribute {oper}')
            oper = getattr(nn, oper)
        mod = oper(*args, **kwargs)
        # validate before reporting, so verbose output only lists real layers
        tus.check(**{'oper(*args, **kwargs)': (mod, nn.Module)})
        if self.is_verbose:
            print(f'  {type(mod).__name__}')
        self.sequence.append(self._wrap(mod))
        return self

    def then(self, module, *args, **kwargs) -> 'FluentModule':
        """Applies a generic torch module transformation. To determine the
        output shape, this just runs some random data (batch size 1) through
        the module. If the module is a string then it is assumed to be the
        name of an attribute in torch.nn, and it is initialized with the
        specified arguments.

        The module is probed in eval mode and without gradients; its original
        training/eval mode is restored afterwards.

        :param module: the module that should modify the data
        :type module: union[nn.Module, str, type]
        :returns: self
        :rtype: FluentModule
        :raises ValueError: if module names no attribute of torch.nn, or if
            the module does not preserve the batch dimension
        """
        if isinstance(module, str):
            if not hasattr(nn, module):
                raise ValueError(f'torch.nn has no attribute {module}')
            module = getattr(nn, module)
        if not isinstance(module, nn.Module):
            module = module(*args, **kwargs)
        tus.check(module=(module, nn.Module))

        # eval() avoids e.g. BatchNorm rejecting a batch of one, but the mode
        # must be restored or the built network would silently keep this
        # submodule in eval mode (Dropout/BatchNorm disabled while training).
        was_training = module.training
        module.eval()
        try:
            with torch.no_grad():
                data = torch.randn(self.shape).unsqueeze(0)
                output = module(data)
        finally:
            module.train(was_training)

        tus.check(output=(output, torch.Tensor))
        if output.shape[0] != 1:
            raise ValueError('module killed batch dimension; '
                             + f'output shape: {output.shape}')
        self.sequence.append(self._wrap(module))
        self.shape = tuple(output.shape[1:])
        if self.is_verbose:
            print(f'  {type(module).__name__} -> {self.shape}')
        return self

    def then_with(self, dims, mod, *args, **kwargs) -> 'FluentModule':
        """This applies the given nn.Module or string for an attribute in nn
        with the given dimensions passed as inputs.

        dims should either be a single number, which is treated like a tuple
        of a single element, or a tuple of numbers, which is treated as if
        each element is (i, num) where i is the index, or a tuple of
        (arg_index, num).

        Our current shape is injected into args such that for each pair
        (arg_index, num) in dims, args[arg_index] = self.shape[num].

        This allows for an extremely generic interface for modules which do
        not have a dedicated function for them.

        :Example:

        .. code-block:: python

            from torchluent import FluentModule

            net = (
                FluentModule((1, 7, 7))
                .verbose()
                .then_with(0, 'ConvTranspose2d', 16,
                           kernel_size=2, stride=2, padding=2)
                .operator('LeakyReLU')
                .then_with(0, 'ConvTranspose2d', 32,
                           kernel_size=2, stride=2, padding=2)
                .operator('LeakyReLU')
                .then_with(0, 'ConvTranspose2d', 1,
                           kernel_size=3, stride=2, padding=2)
                .operator('LeakyReLU')
                .build()
            )

        :param dims: one of int, tuple[int], and tuple[tuple[int, int]]. each
            element is treated as if by (arg_index, num) where num is the
            dimension in self.shape that corresponds to args[arg_index]
        :param mod: either a str (for an attribute in nn) or a callable which
            returns a module.
        :returns: self
        :rtype: FluentModule
        :raises ValueError: on invalid, duplicate or out-of-range dims, or if
            mod names no attribute of torch.nn
        """
        if isinstance(dims, int):
            dims = (dims,)
        tus.check(dims=(dims, (list, tuple)))
        dims = list(dims)
        for i, dim in enumerate(dims):
            if isinstance(dim, int):
                # a bare number means "argument i gets self.shape[dim]"
                dim = (i, dim)
                dims[i] = dim
            tus.check(**{f'dims[{i}]': (dim, tuple)})
            tus.check_listlike(**{f'dims[{i}]': (dim, int, 2)})
            if dim[0] < 0:
                raise ValueError(f'dims[{i}][0] = {dim[0]} '
                                 + 'should be nonnegative')
            if dim[0] >= len(args) + len(dims):
                raise ValueError(f'dims[{i}][0] = {dim[0]} requires more '
                                 + 'arguments than were specified')
            if not 0 <= dim[1] < len(self.shape):
                raise ValueError(f'dims[{i}][1]={dim[1]} is not valid for '
                                 + f'the current shape {self.shape}')
        if len({arg_index for arg_index, _ in dims}) != len(dims):
            raise ValueError(f'arg_index must be unique in dims={dims}')

        if isinstance(mod, str):
            if not hasattr(nn, mod):
                raise ValueError(f'no module {mod} in torch.nn')
            mod = getattr(nn, mod)

        # inserting in ascending arg_index order places every value at its
        # final position
        dims.sort(key=lambda pair: pair[0])
        newargs = list(args)
        for arg_index, shape_index in dims:
            newargs.insert(arg_index, self.shape[shape_index])
        return self.then(mod(*newargs, **kwargs))

    def conv1d(self, *args, **kwargs) -> 'FluentModule':
        """Applies a 1d convolution to the current data. The current shape
        should be in the form (channels, length). This accepts all the same
        arguments as nn.Conv1d except for in_channels, which it will
        calculate from the current shape.

        .. seealso::

            `torch.nn.Conv1d
            <https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html>`_

        :returns: self
        :rtype: FluentModule
        """
        if len(self.shape) != 2:
            raise ValueError(f'cannot perform conv1d on shape {self.shape} - '
                             + 'expected shape (channels, length)')
        return self.then_with(0, 'Conv1d', *args, **kwargs)

    def conv2d(self, *args, **kwargs) -> 'FluentModule':
        """Applies a convolution to the current data. The current shape should
        be in the form (channels, height, width). This accepts all the same
        arguments as nn.Conv2d except for in_channels, which it will
        calculate from the current shape.

        .. seealso::

            `torch.nn.Conv2d
            <https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html>`_

        :returns: self
        :rtype: FluentModule
        """
        if len(self.shape) != 3:
            raise ValueError(f'cannot perform conv2d on shape {self.shape} - '
                             + 'expected shape (channels, height, width)')
        return self.then_with(0, 'Conv2d', *args, **kwargs)

    def conv3d(self, *args, **kwargs) -> 'FluentModule':
        """Applies a convolution to the current data. The current shape should
        be in the form (channels, depth, height, width). This accepts all the
        same arguments as nn.Conv3d except for in_channels, which it will
        calculate from the current shape.

        .. seealso::

            `torch.nn.Conv3d
            <https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html>`_

        :returns: self
        :rtype: FluentModule
        """
        if len(self.shape) != 4:
            raise ValueError(f'cannot perform conv3d on shape {self.shape} - '
                             + 'expected shape '
                             + '(channels, depth, height, width)')
        return self.then_with(0, 'Conv3d', *args, **kwargs)

    def maxpool1d(self, *args, **kwargs) -> 'FluentModule':
        """The arguments and keyword arguments are identical to MaxPool1d

        .. seealso::

            `torch.nn.MaxPool1d
            <https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html>`_

        :returns: self
        :rtype: FluentModule
        """
        if len(self.shape) != 2:
            raise ValueError(f'cannot perform maxpool1d on shape {self.shape} - '
                             + 'expected shape (channels, length)')
        return self.then('MaxPool1d', *args, **kwargs)

    def maxpool2d(self, *args, **kwargs) -> 'FluentModule':
        """The arguments and keyword arguments are identical to MaxPool2d

        .. seealso::

            `torch.nn.MaxPool2d
            <https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html>`_

        :returns: self
        :rtype: FluentModule
        """
        if len(self.shape) != 3:
            raise ValueError(f'cannot perform maxpool2d on shape '
                             + f'{self.shape} - expected shape '
                             + '(channels, height, width)')
        return self.then('MaxPool2d', *args, **kwargs)

    def maxpool3d(self, *args, **kwargs) -> 'FluentModule':
        """The arguments and keyword arguments are identical to MaxPool3d

        .. seealso::

            `torch.nn.MaxPool3d
            <https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html>`_

        :returns: self
        :rtype: FluentModule
        """
        if len(self.shape) != 4:
            raise ValueError(f'cannot perform maxpool3d on shape '
                             + f'{self.shape} - expected shape '
                             + '(channels, depth, height, width)')
        return self.then('MaxPool3d', *args, **kwargs)

    def build(self, with_stripped=False
              ) -> typing.Union[nn.Module, typing.Tuple[nn.Module, nn.Module]]:
        """Constructs the actual torch module created through other
        invocations to this instance.

        :param with_stripped: if True, wrap() must have been called and the
            output changes to (net, stripped_net), where stripped_net shares
            net's modules but returns only x instead of (x, arr).
        :type with_stripped: bool
        :returns: a ready-to-use torch module, or the pair described above
        :rtype: nn.Module
        :raises ValueError: if with_stripped is True but wrap() was not called
        """
        tus.check(with_stripped=(with_stripped, bool))
        if with_stripped and not self.wrapped:
            raise ValueError('cannot strip unless already wrapped')
        res = nn.Sequential(*self.sequence)
        if with_stripped:
            return res, StrippingModule(res)
        return res