Linear
The module for the ensemble-extended linear layer.
- class supertransformerlib.Linear.Linear(input_shape: Union[Tensor, List[int], int], output_shape: Union[Tensor, List[int], int], ensemble_shapes: Optional[Union[Tensor, List[int], int]] = None)
A linear layer that applies a head-dependent linear map from an input shape to an output shape. JIT is supported for instances of the class.
An instance is made by providing an input_shape, an output_shape, and optionally ensemble_shapes (the head shapes referred to below).
These are used to initialize a head-dependent linear remap from the input shape to the output shape, which is then applied by calling the instance.
The input is expected to have the form
[…, heads, input_shape]
and the returned output has the form
[…, heads, output_shape]
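The shape contract above can be illustrated without the library itself. The following is a minimal NumPy sketch of a head-dependent linear map, not the implementation used by supertransformerlib: the sizes, the `head_linear` name, and the presence of a bias term are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, chosen only for illustration.
heads, d_in, d_out = 4, 8, 16

# One weight matrix and one bias vector per head.
# (Whether the library adds a bias is an assumption here.)
weight = rng.standard_normal((heads, d_in, d_out))
bias = rng.standard_normal((heads, d_out))

def head_linear(x):
    """Map [..., heads, d_in] to [..., heads, d_out], one kernel per head."""
    return np.einsum("...hi,hio->...ho", x, weight) + bias

# Any number of leading dimensions may precede the head dimension.
x = rng.standard_normal((2, 5, heads, d_in))
y = head_linear(x)
print(y.shape)  # (2, 5, 4, 16)
```

Each head sees only its own slice of the input and its own kernel, so `y[..., h, :]` depends only on `x[..., h, :]`, `weight[h]`, and `bias[h]`.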
Leaving the ensemble_shapes parameter (the head shape) as None disables head-dependent behavior, so a single kernel is broadcast across any head dimensions in the input. input_shape, output_shape, and ensemble_shapes may each be a single integer, in which case only one dimension is assumed.
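As a rough sketch of these two conventions, the snippet below normalizes an integer to a one-element shape and shows how a kernel with no head dimensions broadcasts over the input. The helpers `as_shape` and `kernel_shape` are hypothetical, and flattening multi-dimensional shapes into a single axis is an assumption about one reasonable way to do this, not a description of the library's internals.

```python
import numpy as np
from typing import List, Optional, Union

Shape = Union[int, List[int]]

def as_shape(shape: Shape) -> List[int]:
    """Treat a bare integer as a single-dimension shape."""
    return [shape] if isinstance(shape, int) else list(shape)

def kernel_shape(input_shape: Shape, output_shape: Shape,
                 ensemble_shapes: Optional[Shape] = None) -> List[int]:
    """Hypothetical helper: kernel shape with flattened in/out dimensions."""
    ens = [] if ensemble_shapes is None else as_shape(ensemble_shapes)
    d_in = int(np.prod(as_shape(input_shape)))
    d_out = int(np.prod(as_shape(output_shape)))
    return ens + [d_in, d_out]

print(kernel_shape(8, 16))                   # [8, 16]
print(kernel_shape(8, 16, 4))                # [4, 8, 16]
print(kernel_shape([2, 4], [4, 4], [3, 5]))  # [3, 5, 8, 16]

# With no ensemble shape, one shared kernel broadcasts over every head.
rng = np.random.default_rng(0)
shared = rng.standard_normal(kernel_shape(8, 16))
x = rng.standard_normal((2, 4, 8))  # [batch, heads, input]
y = x @ shared
print(y.shape)  # (2, 4, 16)
```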