pooling

A collection of pooling layers with a common API.

swem.models.pooling.PoolingConfig

Configuration for pooling layers.

swem.models.pooling.SwemPoolingLayer

Base class for all pooling layers.

swem.models.pooling.AttentionPooling

Pooling with an attention mechanism.

swem.models.pooling.HierarchicalPooling

Hierarchical Pooling layer (see "Baseline Needs More Love", Shen et al., 2018).

swem.models.pooling.MaxPooling

Simple max pooling layer.

swem.models.pooling.MeanPooling

Simple mean pooling layer.

PoolingConfig

class swem.models.pooling.PoolingConfig(type: Literal['HierarchicalPooling', 'AttentionPooling', 'MaxPooling', 'MeanPooling'], window_size: int | None = None, input_dim: int | None = None)

Configuration for pooling layers.

Parameters
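  • type (Literal['HierarchicalPooling', 'AttentionPooling', 'MaxPooling', 'MeanPooling']) – Name of the pooling layer to construct.

  • window_size (int | None) – Size of the pooling window; corresponds to the window_size argument of HierarchicalPooling.

  • input_dim (int | None) – Size of the input vectors; corresponds to the input_dim argument of AttentionPooling.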

SwemPoolingLayer

class swem.models.pooling.SwemPoolingLayer

Base class for all pooling layers.

forward(input: torch.FloatTensor, mask: torch.FloatTensor | None = None) → torch.FloatTensor

The pooling computation. This should be overridden by subclasses.

Parameters
  • input (torch.FloatTensor) – The sequence of input vectors to pool.

  • mask (torch.FloatTensor | None) – Optional mask indicating which inputs to ignore.

Return type

torch.FloatTensor

static from_config(config: PoolingConfig | dict[str, str | int]) → SwemPoolingLayer

Construct a pooling layer from the config.

A dictionary representation of the config can be passed instead of a PoolingConfig.

Raises

NotImplementedError – If the type specified by ‘type’ is unknown.

Returns

A pooling layer defined by the config.

Return type

SwemPoolingLayer

Parameters

config (PoolingConfig | dict[str, str | int]) – The pooling config, or a dictionary representation of it.
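
The dispatch described for from_config can be illustrated with a minimal, library-free sketch. The names `PoolingConfigSketch` and `describe_layer` are hypothetical stand-ins and not part of swem; the function returns a description string rather than a real layer, and which config field feeds which layer is assumed from the constructor signatures documented on this page.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class PoolingConfigSketch:
    """Hypothetical stand-in mirroring the documented PoolingConfig fields."""
    type: str
    window_size: Optional[int] = None
    input_dim: Optional[int] = None


def describe_layer(config: Union[PoolingConfigSketch, dict]) -> str:
    """Hypothetical dispatcher: returns a description instead of a real layer."""
    if isinstance(config, dict):
        # A dictionary representation is converted to a config object first.
        config = PoolingConfigSketch(**config)
    if config.type == "HierarchicalPooling":
        return f"HierarchicalPooling(window_size={config.window_size})"
    if config.type == "AttentionPooling":
        return f"AttentionPooling(input_dim={config.input_dim})"
    if config.type in ("MaxPooling", "MeanPooling"):
        return f"{config.type}()"
    # Unknown types are rejected, as documented for from_config.
    raise NotImplementedError(f"Unknown pooling type: {config.type!r}")


print(describe_layer({"type": "HierarchicalPooling", "window_size": 3}))
print(describe_layer(PoolingConfigSketch(type="MeanPooling")))
```

With the library itself, the documented signature suggests that the equivalent call would be SwemPoolingLayer.from_config({"type": "MeanPooling"}).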

AttentionPooling

class swem.models.pooling.AttentionPooling(input_dim: int)

Pooling with an attention mechanism.

Pools along the sequence dimension by computing a weighted sum of the input vectors. The weights are computed by a small feed-forward network (with output size 1) applied to each input vector, followed by a softmax over the sequence dimension.

An optional mask specifies which inputs to ignore in the softmax.

Parameters

input_dim (int) – The size of the input vectors.

Shapes:
  • input: \((\text{batch_size}, \text{seq_len}, \text{enc_dim})\)

  • mask: \((\text{batch_size}, \text{seq_len})\)

  • output: \((\text{batch_size}, \text{enc_dim})\)

forward(input: torch.FloatTensor, mask: torch.FloatTensor | None = None) → torch.FloatTensor

Pooling forward pass.

Parameters
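  • input (torch.FloatTensor) – The sequence of input vectors to pool.

  • mask (torch.FloatTensor | None) – Optional mask indicating which inputs to ignore in the softmax.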

Return type

torch.FloatTensor
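
To make the attention pooling computation concrete, here is a minimal pure-Python sketch for a single sequence (no batch dimension). It is not the library's implementation: `attention_pool` is a hypothetical name, `score_weights` is a single linear scorer standing in for the small feed-forward network, and the mask convention (1 = keep, 0 = ignore) is an assumption. The case where every position is masked is not handled.

```python
import math
from typing import List, Optional

Vector = List[float]


def attention_pool(seq: List[Vector], score_weights: Vector,
                   mask: Optional[List[float]] = None) -> Vector:
    """Weighted sum over sequence positions with softmax weights.

    `score_weights` is a hypothetical linear scorer (assumption), and
    `mask` uses 1 = keep, 0 = ignore (assumption).
    """
    # One scalar score per sequence position (scorer output size 1).
    scores = [sum(w * x for w, x in zip(score_weights, vec)) for vec in seq]
    if mask is not None:
        # Masked positions get -inf so that they receive zero softmax weight.
        scores = [s if m else -math.inf for s, m in zip(scores, mask)]
    # Numerically stable softmax over the sequence dimension.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the input vectors: one value per encoding dimension.
    dim = len(seq[0])
    return [sum(w * vec[d] for w, vec in zip(weights, seq)) for d in range(dim)]


seq = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
uniform_scorer = [0.0, 0.0]  # all scores equal, so the softmax is uniform
print(attention_pool(seq, uniform_scorer, mask=[1, 1, 0]))  # [0.5, 0.5]
```

With the third position masked out, the two remaining positions share the weight equally, so the output is their average.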

HierarchicalPooling

class swem.models.pooling.HierarchicalPooling(window_size: int)

Hierarchical Pooling layer (see "Baseline Needs More Love", Shen et al., 2018).

Applies mean pooling along the sequence dimension over windows of the given size, then max pooling along the sequence dimension over the resulting window means.

The mask input is ignored; it is only there for compatibility.

Parameters

window_size (int) – Size of the pooling window in the mean pooling step.

Shapes:
  • input: \((\text{batch_size}, \text{seq_len}, \text{enc_dim})\)

  • mask: \((\text{batch_size}, \text{seq_len})\)

  • output: \((\text{batch_size}, \text{enc_dim})\)

forward(input: torch.FloatTensor, mask: torch.FloatTensor | None = None) → torch.FloatTensor

Pooling forward pass.

Parameters
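  • input (torch.FloatTensor) – The sequence of input vectors to pool.

  • mask (torch.FloatTensor | None) – Ignored; accepted only for compatibility with the other pooling layers.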

Return type

torch.FloatTensor
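
The two-step computation can be sketched in pure Python for a single sequence. This is an illustration under stated assumptions, not the library's code: `hierarchical_pool` is a hypothetical name, the windows are assumed to slide with stride 1, and a sequence shorter than the window is assumed to be treated as one window.

```python
from typing import List

Vector = List[float]


def hierarchical_pool(seq: List[Vector], window_size: int) -> Vector:
    """Mean over sliding windows, then max over the window means."""
    dim = len(seq[0])
    # Assumption: stride 1; a short sequence is pooled as a single window.
    width = min(window_size, len(seq))
    window_means = []
    for start in range(len(seq) - width + 1):
        window = seq[start:start + width]
        # Mean pooling within the window, per encoding dimension.
        window_means.append([sum(vec[d] for vec in window) / width
                             for d in range(dim)])
    # Max pooling across all window means, per encoding dimension.
    return [max(mean[d] for mean in window_means) for d in range(dim)]


seq = [[1.0, 4.0], [3.0, 0.0], [5.0, 2.0]]
# Window means: [2.0, 2.0] and [4.0, 1.0]
print(hierarchical_pool(seq, window_size=2))  # [4.0, 2.0]
```

Averaging within windows first keeps some local word-order information that a plain max or mean over the whole sequence would discard.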

MaxPooling

class swem.models.pooling.MaxPooling

Simple max pooling layer.

Max pooling along the sequence dimension.

The mask input is ignored; it is only there for compatibility.

Shapes:
  • input: \((\text{batch_size}, \text{seq_len}, \text{enc_dim})\)

  • mask: \((\text{batch_size}, \text{seq_len})\)

  • output: \((\text{batch_size}, \text{enc_dim})\)

forward(input: torch.FloatTensor, mask: torch.FloatTensor | None = None) → torch.FloatTensor

Pooling forward pass.

Parameters
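  • input (torch.FloatTensor) – The sequence of input vectors to pool.

  • mask (torch.FloatTensor | None) – Ignored; accepted only for compatibility with the other pooling layers.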

Return type

torch.FloatTensor
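
As a shape illustration, the following pure-Python sketch reduces a batch of shape (batch_size, seq_len, enc_dim) to (batch_size, enc_dim) by taking the per-dimension maximum over the sequence positions. The name `max_pool` is hypothetical and the nested lists stand in for tensors.

```python
from typing import List

Vector = List[float]


def max_pool(batch: List[List[Vector]]) -> List[Vector]:
    """Per-dimension maximum over the sequence positions of each batch item."""
    pooled = []
    for seq in batch:
        dim = len(seq[0])
        pooled.append([max(vec[d] for vec in seq) for d in range(dim)])
    return pooled


# batch_size = 2, seq_len = 3, enc_dim = 2
batch = [
    [[1.0, -2.0], [0.5, 3.0], [-1.0, 0.0]],
    [[4.0, 4.0], [2.0, 6.0], [8.0, 1.0]],
]
print(max_pool(batch))  # [[1.0, 3.0], [8.0, 6.0]]
```

Since the mask is ignored, padded positions presumably take part in the maximum like any other position.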

MeanPooling

class swem.models.pooling.MeanPooling

Simple mean pooling layer.

Mean pooling along the sequence dimension, ignoring inputs according to the given mask.

Shapes:
  • input: \((\text{batch_size}, \text{seq_len}, \text{enc_dim})\)

  • mask: \((\text{batch_size}, \text{seq_len})\)

  • output: \((\text{batch_size}, \text{enc_dim})\)

forward(input: torch.FloatTensor, mask: torch.FloatTensor | None = None) → torch.FloatTensor

Pooling forward pass.

Parameters
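  • input (torch.FloatTensor) – The sequence of input vectors to pool.

  • mask (torch.FloatTensor | None) – Optional mask indicating which inputs to exclude from the mean.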

Return type

torch.FloatTensor
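
A masked mean can be sketched in pure Python for a single sequence as follows. This is not the library's implementation: `mean_pool` is a hypothetical name, and the mask convention (1.0 = keep, 0.0 = ignore) is an assumption. A mask that excludes every position would divide by zero and is not handled here.

```python
from typing import List, Optional

Vector = List[float]


def mean_pool(seq: List[Vector], mask: Optional[List[float]] = None) -> Vector:
    """Mean over the sequence positions that the mask keeps."""
    if mask is None:
        # Without a mask, every position contributes equally.
        mask = [1.0] * len(seq)
    # Divide by the number of kept positions, not by the sequence length.
    kept = sum(mask)
    dim = len(seq[0])
    return [sum(m * vec[d] for m, vec in zip(mask, seq)) / kept
            for d in range(dim)]


seq = [[2.0, 4.0], [4.0, 8.0], [100.0, 100.0]]
# The third position (for example padding) is excluded from the mean.
print(mean_pool(seq, mask=[1.0, 1.0, 0.0]))  # [3.0, 6.0]
```

Dividing by the number of kept positions is what keeps padded sequences of different lengths comparable.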