pooling
A collection of pooling layers with a common API.
PoolingConfig | Configuration for pooling layers.
SwemPoolingLayer | Base class for all pooling layers.
AttentionPooling | Pooling with an attention mechanism.
HierarchicalPooling | Hierarchical pooling layer (see Baseline Needs More Love).
MaxPooling | Simple max pooling layer.
MeanPooling | Simple mean pooling layer.
PoolingConfig
- class swem.models.pooling.PoolingConfig(type: Literal['HierarchicalPooling', 'AttentionPooling', 'MaxPooling', 'MeanPooling'], window_size: int | None = None, input_dim: int | None = None)
Configuration for pooling layers.
- Parameters
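type (Literal['HierarchicalPooling', 'AttentionPooling', 'MaxPooling', 'MeanPooling']) – Name of the pooling layer to construct.
window_size (int | None) – Size of the mean-pooling window; used by HierarchicalPooling.
input_dim (int | None) – Size of the input vectors; used by AttentionPooling.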
- Return type
None
SwemPoolingLayer
- class swem.models.pooling.SwemPoolingLayer
Base class for all pooling layers.
- Return type
None
- forward(input: torch.FloatTensor, mask: torch.FloatTensor | None = None) → torch.FloatTensor
The pooling computation. This should be overridden by subclasses.
- Parameters
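input (torch.FloatTensor) – The sequence to pool, of shape (batch_size, seq_len, enc_dim).
mask (torch.FloatTensor | None) – Optional mask over the sequence positions; some layers ignore it.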
- Return type
torch.FloatTensor
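The contract above (one `forward(input, mask)` signature, implemented by each subclass) can be illustrated without PyTorch. This is a framework-free sketch, not the library's code: `PoolingLayerSketch` and `FirstTokenPooling` are hypothetical names, each input is a single sequence given as a plain list of shape (seq_len, enc_dim), and using Python's `abc` module to enforce the override is an assumption rather than how `SwemPoolingLayer` necessarily does it.

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence

# One example: seq_len rows, each an enc_dim-long vector (hypothetical plain-list layout).
Matrix = Sequence[Sequence[float]]


class PoolingLayerSketch(ABC):
    """Hypothetical stand-in for the common pooling API."""

    @abstractmethod
    def forward(self, input: Matrix, mask: Optional[Sequence[int]] = None) -> List[float]:
        """Reduce a (seq_len, enc_dim) sequence to one enc_dim vector."""

    def __call__(self, input: Matrix, mask: Optional[Sequence[int]] = None) -> List[float]:
        # Mirror the nn.Module convention of invoking forward through the instance.
        return self.forward(input, mask)


class FirstTokenPooling(PoolingLayerSketch):
    """Toy subclass (hypothetical): keeps the first vector and ignores the mask."""

    def forward(self, input: Matrix, mask: Optional[Sequence[int]] = None) -> List[float]:
        return list(input[0])


pooled = FirstTokenPooling()([[1.0, 2.0], [3.0, 4.0]])  # [1.0, 2.0]
```

Because `forward` is abstract here, instantiating the base class directly raises `TypeError`, which is one way to turn "should be overridden" into an enforced rule.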
- static from_config(config: PoolingConfig | dict[str, str | int]) → SwemPoolingLayer
Construct a pooling layer from the config.
Instead of a PoolingConfig, a dictionary representation of the config can also be passed.
- Raises
NotImplementedError – If the type specified by ‘type’ is unknown.
- Returns
A pooling layer defined by the config.
- Return type
SwemPoolingLayer
- Parameters
config (PoolingConfig | dict[str, str | int]) – A PoolingConfig or its dictionary representation.
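The dispatch that `from_config` describes can be sketched as follows. Everything here is hypothetical: `PoolingConfigSketch` only mirrors the documented `PoolingConfig` fields, the `*Stub` classes stand in for the real layers, and routing `window_size` to `HierarchicalPooling` and `input_dim` to `AttentionPooling` is inferred from the constructor signatures rather than taken from the library source.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Union


@dataclass
class PoolingConfigSketch:
    """Hypothetical mirror of the documented PoolingConfig fields."""
    type: str
    window_size: Optional[int] = None
    input_dim: Optional[int] = None


# Hypothetical stand-ins for the real layers; they only record their arguments.
class MaxPoolingStub:
    pass


class MeanPoolingStub:
    pass


@dataclass
class HierarchicalPoolingStub:
    window_size: int


@dataclass
class AttentionPoolingStub:
    input_dim: int


def from_config_sketch(config: Union[PoolingConfigSketch, Dict[str, Union[str, int]]]):
    # Accept the dictionary representation by converting it to a config object.
    if isinstance(config, dict):
        config = PoolingConfigSketch(**config)
    if config.type == "MaxPooling":
        return MaxPoolingStub()
    if config.type == "MeanPooling":
        return MeanPoolingStub()
    if config.type == "HierarchicalPooling":
        return HierarchicalPoolingStub(window_size=config.window_size)
    if config.type == "AttentionPooling":
        return AttentionPoolingStub(input_dim=config.input_dim)
    # Unknown types raise NotImplementedError, matching the documented contract.
    raise NotImplementedError(f"Unknown pooling type: {config.type!r}")


layer = from_config_sketch({"type": "HierarchicalPooling", "window_size": 3})
```

Converting the dictionary to a config object first keeps a single code path for both accepted input forms.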
AttentionPooling
- class swem.models.pooling.AttentionPooling(input_dim: int)
Pooling with an attention mechanism.
Pools along the sequence dimension by computing a weighted sum of the input vectors. The weights are obtained by applying a small feed-forward network (with output size 1) to each input vector, followed by a softmax over the sequence dimension.
An optional mask specifies which inputs the softmax should ignore.
- Parameters
input_dim (int) – The size of the input vectors.
- Shapes:
input: \((\text{batch_size}, \text{seq_len}, \text{enc_dim})\)
mask: \((\text{batch_size}, \text{seq_len})\)
output: \((\text{batch_size}, \text{enc_dim})\)
- forward(input: torch.FloatTensor, mask: torch.FloatTensor | None = None) → torch.FloatTensor
Pooling forward pass.
- Parameters
input (torch.FloatTensor) –
mask (torch.FloatTensor | None) –
- Return type
torch.FloatTensor
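To make the computation concrete, here is a pure-Python sketch for a single sequence (one batch element). It is not the library's implementation: `attention_pool` is a hypothetical name, the scoring network is reduced to one linear map `w · x_t + b` (the real layer may use hidden layers), and the mask convention (1 keeps a position, 0 excludes it) is an assumption.

```python
import math
from typing import List, Optional, Sequence


def attention_pool(
    seq: Sequence[Sequence[float]],
    w: Sequence[float],
    b: float = 0.0,
    mask: Optional[Sequence[int]] = None,
) -> List[float]:
    """Attention-pool one (seq_len, enc_dim) sequence into an enc_dim vector."""
    # Score each position with a single linear map (hypothetical stand-in for the scoring network).
    scores = [sum(wi * xi for wi, xi in zip(w, x)) + b for x in seq]
    if mask is not None:
        # Assumed convention: mask[t] == 0 excludes position t from the softmax.
        scores = [s if m else float("-inf") for s, m in zip(scores, mask)]
    # Softmax over the sequence dimension (max subtracted for numerical stability).
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the input vectors, one value per feature dimension.
    return [sum(a * x[d] for a, x in zip(weights, seq)) for d in range(len(seq[0]))]


seq = [[1.0, 0.0], [0.0, 1.0]]
uniform = attention_pool(seq, w=[0.0, 0.0])              # equal scores -> [0.5, 0.5]
masked = attention_pool(seq, w=[0.0, 0.0], mask=[1, 0])  # only position 0 -> [1.0, 0.0]
```

At least one position must stay unmasked; if every score is `-inf`, the softmax is undefined.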
HierarchicalPooling
- class swem.models.pooling.HierarchicalPooling(window_size: int)
Hierarchical pooling layer (see Baseline Needs More Love).
First applies mean pooling over windows of the given size along the sequence dimension, then max pooling over the resulting window means.
The mask argument is ignored; it is accepted only so that all pooling layers share the same forward signature.
- Parameters
window_size (int) – Size of the pooling window in the mean pooling step.
- Shapes:
input: \((\text{batch_size}, \text{seq_len}, \text{enc_dim})\)
mask: \((\text{batch_size}, \text{seq_len})\)
output: \((\text{batch_size}, \text{enc_dim})\)
- forward(input: torch.FloatTensor, mask: torch.FloatTensor | None = None) → torch.FloatTensor
Pooling forward pass.
- Parameters
input (torch.FloatTensor) –
mask (torch.FloatTensor | None) –
- Return type
torch.FloatTensor
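A pure-Python sketch of the two steps for a single sequence follows. `hierarchical_pool` is a hypothetical name. The windows here slide with stride 1; the text above does not say whether the library uses overlapping or non-overlapping windows, so treat that, and the requirement `seq_len >= window_size`, as assumptions.

```python
from typing import List, Sequence


def hierarchical_pool(seq: Sequence[Sequence[float]], window_size: int) -> List[float]:
    """Mean-pool each window of consecutive vectors, then max-pool over the windows."""
    seq_len, dim = len(seq), len(seq[0])
    if not 1 <= window_size <= seq_len:
        raise ValueError("window_size must be between 1 and seq_len")
    # Step 1: mean of each stride-1 window, computed per feature dimension.
    window_means = [
        [sum(seq[t + k][d] for k in range(window_size)) / window_size for d in range(dim)]
        for t in range(seq_len - window_size + 1)
    ]
    # Step 2: maximum over all windows, separately for each feature dimension.
    return [max(m[d] for m in window_means) for d in range(dim)]


seq = [[1.0, 0.0], [3.0, 2.0], [5.0, -2.0], [-1.0, 4.0]]
pooled = hierarchical_pool(seq, window_size=2)  # [4.0, 1.0]
```

With `window_size=1` the mean step is the identity and the result reduces to plain max pooling.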
MaxPooling
- class swem.models.pooling.MaxPooling
Simple max pooling layer.
Max pooling along the sequence dimension.
The mask argument is ignored; it is accepted only so that all pooling layers share the same forward signature.
- Shapes:
input: \((\text{batch_size}, \text{seq_len}, \text{enc_dim})\)
mask: \((\text{batch_size}, \text{seq_len})\)
output: \((\text{batch_size}, \text{enc_dim})\)
- Return type
None
- forward(input: torch.FloatTensor, mask: torch.FloatTensor | None = None) → torch.FloatTensor
Pooling forward pass.
- Parameters
input (torch.FloatTensor) –
mask (torch.FloatTensor | None) –
- Return type
torch.FloatTensor
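The sketch below computes the same reduction for one sequence in plain Python and shows a side effect of ignoring the mask. `max_pool` is a hypothetical name, and the example assumes padding is represented by zero vectors, which the text above does not specify.

```python
from typing import List, Optional, Sequence


def max_pool(seq: Sequence[Sequence[float]], mask: Optional[Sequence[int]] = None) -> List[float]:
    """Feature-wise maximum over the sequence; `mask` is accepted but ignored."""
    # zip(*seq) yields one column (all seq_len values) per feature dimension.
    return [max(column) for column in zip(*seq)]


# The second position is meant to be (zero) padding, but the mask has no effect:
padded = max_pool([[-1.0, -2.0], [0.0, 0.0]], mask=[1, 0])  # [0.0, 0.0]
```

Because the mask is ignored, zero padding can win the maximum for features whose real values are all negative, as the last line shows.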
MeanPooling
- class swem.models.pooling.MeanPooling
Simple mean pooling layer.
Mean pooling along the sequence dimension, ignoring the inputs excluded by the given mask.
- Shapes:
input: \((\text{batch_size}, \text{seq_len}, \text{enc_dim})\)
mask: \((\text{batch_size}, \text{enc_dim})\)
output: \((\text{batch_size}, \text{enc_dim})\)
- Return type
None
- forward(input: torch.FloatTensor, mask: torch.FloatTensor | None = None) → torch.FloatTensor
Pooling forward pass.
- Parameters
input (torch.FloatTensor) –
mask (torch.FloatTensor | None) –
- Return type
torch.FloatTensor
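Unlike MaxPooling, this layer uses the mask. The pure-Python sketch below (hypothetical `mean_pool`, one sequence, assumed mask convention 1 = keep, 0 = ignore) shows the effect: a masked-out row is excluded from both the sum and the count.

```python
from typing import List, Optional, Sequence


def mean_pool(seq: Sequence[Sequence[float]], mask: Optional[Sequence[int]] = None) -> List[float]:
    """Feature-wise mean over the positions the mask keeps (all positions if mask is None)."""
    if mask is None:
        mask = [1] * len(seq)
    kept = [x for x, m in zip(seq, mask) if m]
    if not kept:
        raise ValueError("mask excludes every position")
    # Sum each feature over the kept positions and divide by their count.
    return [sum(column) / len(kept) for column in zip(*kept)]


pooled = mean_pool([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]], mask=[1, 1, 0])  # [2.0, 3.0]
```

Dividing by the number of kept positions, rather than by seq_len, keeps padded sequences from being pulled toward the padding value.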