Building Blocks

Curated Transformers provides building blocks to create your own transformer models.

Embedding Layers

These modules implement full embedding layers.

class curated_transformers.layers.EmbeddingDropouts(embed_output_dropout=Identity(), proj_output_dropout=Identity())

Dropouts used in a transformer embedding layer.

By default, all dropouts are disabled by setting them to the PyTorch Identity module, so only the dropouts that are needed have to be set.

Parameters:
  • embed_output_dropout (Module) – Dropout of the embeddings.

  • proj_output_dropout (Module) – Dropout of the output of the projection layer.

class curated_transformers.layers.EmbeddingLayerNorms(embed_output_layer_norm=Identity(), proj_output_layer_norm=Identity())

Layer normalizations used in a transformer embedding layer.

By default, all normalizations are disabled by setting them to the PyTorch Identity module, so only the normalizations that are needed have to be set.

Parameters:
  • embed_output_layer_norm (Module) – Normalization of the embeddings.

  • proj_output_layer_norm (Module) – Normalization of the output of the projection layer.

class curated_transformers.layers.TransformerEmbeddings(*, dropouts, embedding_width, hidden_width, layer_norms, n_pieces, n_positions, n_types, device=None)

Transformer embeddings layer.

This is a generic transformer embedding layer. The layer always has piece embeddings and can optionally have position embeddings, type embeddings, and a projection of embeddings to the model’s hidden size.

Construct an embeddings layer.

Parameters:
  • dropouts (EmbeddingDropouts) – Dropouts to use in the embeddings layer.

  • embedding_width (int) – Width of the embeddings.

  • hidden_width (int) – Hidden width of the transformer. If this width differs from embedding_width, a projection layer is added to ensure that the output of the embeddings layer has the same width as the transformer.

  • layer_norms (EmbeddingLayerNorms) – Layer norms to use in the embeddings layer.

  • n_pieces (int) – Number of piece embeddings.

  • n_positions (Optional[int]) – Number of position embeddings. Position embeddings are disabled by using None. Position embeddings can be used to inform the model of input order.

  • n_types (Optional[int]) – Number of type embeddings. Type embeddings are disabled by using None. Type embeddings can be used to inform the model of the spans of different sequences in the input.

  • device (Optional[device]) – Device on which the module is to be initialized.

forward(piece_ids, *, positions=None, type_ids=None)

Apply the embedding layer to the piece identifiers.

Parameters:
  • piece_ids (Tensor) –

    Piece identifiers to embed.

    Shape: (batch_size, seq_len)

  • positions (Optional[Tensor]) –

    Positions with which to fetch the positional embeddings for the sequences.

    Shape: (batch_size, seq_len)

  • type_ids (Optional[Tensor]) –

    Type identifiers to indicate the spans of different sequences in the input. Useful when performing tasks like sequence classification and question answering.

    Shape: (batch_size, seq_len)

Return type:

Tensor
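The embedding computation described above can be sketched in plain Python for a single sequence. This is a framework-free illustration, not the library's implementation: lists stand in for tensors, the helper names (`embed_sequence`, `add`, `matvec`) are hypothetical, dropout and layer normalization are omitted, and the default of type 0 when no type identifiers are given is an assumption.

```python
from typing import List, Optional, Sequence

Vector = List[float]
Matrix = List[List[float]]


def add(a: Vector, b: Vector) -> Vector:
    # Elementwise sum of two vectors of the same width.
    return [x + y for x, y in zip(a, b)]


def matvec(v: Vector, w: Matrix) -> Vector:
    # Row vector times matrix: (embedding_width,) x (embedding_width, hidden_width).
    return [sum(v[i] * w[i][j] for i in range(len(v))) for j in range(len(w[0]))]


def embed_sequence(
    piece_ids: Sequence[int],
    piece_table: Matrix,
    position_table: Optional[Matrix] = None,
    type_table: Optional[Matrix] = None,
    type_ids: Optional[Sequence[int]] = None,
    proj: Optional[Matrix] = None,
) -> List[Vector]:
    """Embed one sequence of piece identifiers (dropout and layer norms omitted)."""
    outputs = []
    for pos, piece_id in enumerate(piece_ids):
        # Piece embeddings are always present.
        vec = list(piece_table[piece_id])
        if position_table is not None:
            # Positions are taken to be 0..seq_len-1 in this sketch.
            vec = add(vec, position_table[pos])
        if type_table is not None:
            # Assumption: every piece gets type 0 when no type_ids are given.
            type_id = type_ids[pos] if type_ids is not None else 0
            vec = add(vec, type_table[type_id])
        if proj is not None:
            # Projection to hidden_width, needed when it differs from embedding_width.
            vec = matvec(vec, proj)
        outputs.append(vec)
    return outputs
```

Disabling position or type embeddings (n_positions or n_types set to None) corresponds to passing no table here.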

Encoder/Decoder Layers

These modules implement full encoder/decoder layers.

class curated_transformers.layers.TransformerDropouts(attn_output_dropout=Identity(), ffn_output_dropout=Identity(), parallel_attn_dropout=Identity())

Dropouts used in a transformer layer.

By default, all dropouts are disabled by setting them to the PyTorch Identity module, so only the dropouts that are needed have to be set.

Parameters:
  • attn_output_dropout (Module) – Dropout of the output of the attention layer.

  • ffn_output_dropout (Module) – Dropout of the output of the feed-forward layer.

  • parallel_attn_dropout (Module) – Dropout applied after summing the outputs of the attention and feed-forward layers. Only used when parallel attention is enabled.

classmethod layer_output_dropouts(p)

Utility method to construct attention and feed-forward layer dropouts.

Parameters:

p (float) – Dropout probability.

Return type:

TransformerDropouts

Returns:

Dropouts of attention and feed-forward layers set to p.

classmethod parallel_attention_dropout(p)

Utility method to construct parallel attention dropout.

Parameters:

p (float) – Dropout probability.

Return type:

TransformerDropouts

Returns:

Dropouts of parallel attention set to p.

class curated_transformers.layers.TransformerLayerNorms(attn_input_layer_norm=Identity(), attn_residual_layer_norm=Identity(), ffn_input_layer_norm=Identity(), ffn_residual_layer_norm=Identity())

Layer normalizations used in a transformer layer.

By default, all normalizations are disabled by setting them to the PyTorch Identity module, so only the normalizations that are needed have to be set.

Parameters:
  • attn_input_layer_norm (Module) – Normalization of the input to the attention layer.

  • attn_residual_layer_norm (Module) – Normalization of the output of the attention layer after the residual connection.

  • ffn_input_layer_norm (Module) – Normalization of the input to the feed-forward layer.

  • ffn_residual_layer_norm (Module) – Normalization of the output of the feed-forward layer after the residual connection.

class curated_transformers.layers.DecoderLayer(*, attention_layer, dropouts, feed_forward_layer, layer_norms, use_parallel_attention)

Transformer decoder layer (Vaswani et al., 2017).

Construct a transformer layer.

Parameters:
  • attention_layer (SelfAttention) – The attention layer to use in the transformer layer.

  • dropouts (TransformerDropouts) – Dropouts to use in the transformer layer.

  • feed_forward_layer (PointwiseFeedForward) – The pointwise feed-forward layer to use in the transformer layer.

  • layer_norms (TransformerLayerNorms) – Layer norms to use in the layer.

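  • use_parallel_attention (bool) – Use parallel attention: the attention and feed-forward layers are applied to the same input and their outputs are summed, instead of applying the feed-forward layer to the output of the attention layer.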

forward(input, attention_mask, *, cache=None, positions=None, store_cache=False)

Apply the decoder layer to the given piece hidden representations.

Parameters:
  • input (Tensor) –

    Hidden representations to apply the layer to.

    Shape: (batch_size, seq_len, width)

  • attention_mask (AttentionMask) – Attention mask. Sequence elements for which the corresponding mask element is set to False are ignored during attention calculation.

  • cache (Optional[KeyValueCache]) – Key/value cache to avoid recomputing key/value representations for tokens that were previously seen.

  • positions (Optional[Tensor]) – Input positions. Positions are needed to look up rotary embeddings. Normally, these positions are calculated automatically. But if the positions deviate for some reason, they can be provided through this argument.

  • store_cache (bool) – Whether to cache the key/value representations for future reuse.

Return type:

Tuple[Tensor, Optional[KeyValueCache]]

Returns:

Layer output and the key/value cache.

Shape: (batch_size, seq_len, width)
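The difference between the default (sequential) layer and use_parallel_attention=True can be sketched with vectors and stand-in sublayers. This is a hypothetical illustration of the residual wiring only: it assumes a pre-norm arrangement, whereas the library decides which normalizations are active through TransformerLayerNorms and also applies the dropouts, both of which are omitted here.

```python
from typing import Callable, List

Vector = List[float]
Sublayer = Callable[[Vector], Vector]


def add(a: Vector, b: Vector) -> Vector:
    # Elementwise sum, used for the residual connections.
    return [x + y for x, y in zip(a, b)]


def sequential_layer(x: Vector, attn: Sublayer, ffn: Sublayer,
                     attn_norm: Sublayer, ffn_norm: Sublayer) -> Vector:
    # The feed-forward sublayer sees the output of the attention residual block.
    h = add(x, attn(attn_norm(x)))
    return add(h, ffn(ffn_norm(h)))


def parallel_layer(x: Vector, attn: Sublayer, ffn: Sublayer,
                   attn_norm: Sublayer, ffn_norm: Sublayer) -> Vector:
    # Both sublayers read the same input; their outputs are summed with the residual.
    return add(x, add(attn(attn_norm(x)), ffn(ffn_norm(x))))
```

Because both sublayers in the parallel variant depend only on the layer input, they can be computed concurrently.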

class curated_transformers.layers.EncoderLayer(*, attention_layer, dropouts, feed_forward_layer, layer_norms, use_parallel_attention)

Transformer encoder layer (Vaswani et al., 2017).

Construct a transformer layer.

Parameters:
  • attention_layer (SelfAttention) – The attention layer to use in the transformer layer.

  • dropouts (TransformerDropouts) – Dropouts to use in the transformer layer.

  • feed_forward_layer (PointwiseFeedForward) – The pointwise feed-forward layer to use in the transformer layer.

  • layer_norms (TransformerLayerNorms) – Layer norms to use in the layer.

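  • use_parallel_attention (bool) – Use parallel attention: the attention and feed-forward layers are applied to the same input and their outputs are summed, instead of applying the feed-forward layer to the output of the attention layer.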

forward(input, attention_mask)

Apply the encoder layer to the given piece hidden representations.

Parameters:
  • input (Tensor) –

    Hidden representations to apply the layer to.

    Shape: (batch_size, seq_len, width)

  • attention_mask (AttentionMask) – Attention mask. Sequence elements for which the corresponding mask element is set to False are ignored during attention calculation.

Return type:

Tuple[Tensor, Optional[KeyValueCache]]

Returns:

Layer output and the key/value cache.

Shape: (batch_size, seq_len, width)

Attention

These modules and their helper classes implement the Transformer attention mechanism.

class curated_transformers.layers.QkvMode(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

How the query, key and value projections are handled in the self-attention layer.

MERGED_SPLIT_AFTER = 2

MERGED_SPLIT_AFTER - Use a merged projection for query, key and value, and split heads after splitting the query, key and value representations.

MERGED_SPLIT_BEFORE = 1

MERGED_SPLIT_BEFORE - Use a merged projection for query, key and value, and split heads before splitting the query, key and value representations. This ordering is incompatible with head sharing in keys or values.

SEPARATE = 0

SEPARATE - Use separate projections for query, key and value.

class curated_transformers.layers.QkvSplit

Query, key, value splitting strategies.

After the input projection of the attention layer, we have an array with shape (batch_size, seq_len, n_heads * head_width) where n_heads is the sum of the number of query, key, and value heads. We need to split up the array into separate arrays for query, key, and value heads.

Subclasses of this class implement different splitting strategies.

abstract split(*, projection, head_width, n_query_heads, n_key_value_heads)

Split the attention heads in the projection into query, key, and value heads.

Parameters:
  • projection (Tensor) –

    The fused query, key, value projection.

    Shape: (batch_size, seq_len, (n_query_heads + 2 * n_key_value_heads) * head_width)

  • head_width (int) – Head width.

  • n_query_heads (int) – Number of query heads.

  • n_key_value_heads (int) – Number of key/value heads.

Return type:

Tuple[Tensor, Tensor, Tensor]

Returns:

Query, key, value tensors.

Shapes:

  • Query: (batch_size, n_query_heads, seq_len, head_width)

  • Key: (batch_size, n_key_value_heads, seq_len, head_width)

  • Value: (batch_size, n_key_value_heads, seq_len, head_width)
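How a splitting strategy might cut the fused projection can be sketched for a single sequence position, with the projection as a flat list. The two layouts below (`split_contiguous` and `split_grouped`) are hypothetical examples, not the library's QkvSplit subclasses; which layout is correct depends on how a given model orders its fused query/key/value weights.

```python
from typing import List, Tuple

Heads = List[List[float]]


def chunk(projection: List[float], head_width: int) -> Heads:
    # Cut a flat vector into consecutive heads of head_width elements.
    return [projection[i : i + head_width] for i in range(0, len(projection), head_width)]


def check_width(projection: List[float], head_width: int,
                n_query_heads: int, n_key_value_heads: int) -> None:
    # The fused width is (n_query_heads + 2 * n_key_value_heads) * head_width.
    expected = (n_query_heads + 2 * n_key_value_heads) * head_width
    if len(projection) != expected:
        raise ValueError(f"expected width {expected}, got {len(projection)}")


def split_contiguous(projection: List[float], head_width: int,
                     n_query_heads: int, n_key_value_heads: int) -> Tuple[Heads, Heads, Heads]:
    # Hypothetical layout: [all query heads | all key heads | all value heads].
    check_width(projection, head_width, n_query_heads, n_key_value_heads)
    heads = chunk(projection, head_width)
    q = heads[:n_query_heads]
    k = heads[n_query_heads : n_query_heads + n_key_value_heads]
    v = heads[n_query_heads + n_key_value_heads :]
    return q, k, v


def split_grouped(projection: List[float], head_width: int,
                  n_query_heads: int, n_key_value_heads: int) -> Tuple[Heads, Heads, Heads]:
    # Hypothetical layout: per key/value head, its query heads, then the key head, then the value head.
    check_width(projection, head_width, n_query_heads, n_key_value_heads)
    group_size = n_query_heads // n_key_value_heads
    heads = chunk(projection, head_width)
    q: Heads = []
    k: Heads = []
    v: Heads = []
    stride = group_size + 2
    for g in range(n_key_value_heads):
        block = heads[g * stride : (g + 1) * stride]
        q.extend(block[:group_size])
        k.append(block[group_size])
        v.append(block[group_size + 1])
    return q, k, v
```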

class curated_transformers.layers.AttentionHeads(*, n_query_heads, n_key_value_heads, qkv_split)

Construct an attention head configuration. This constructor must not be used directly; its signature may change even within a semver version. Use the factory methods instead.

Parameters:
  • n_query_heads (int) – Number of query heads.

  • n_key_value_heads (int) – Number of key/value heads.

  • qkv_split (QkvSplit) – How query, key, and value should be split when using MERGED_SPLIT_AFTER. Not used for other query, key, value modes.

classmethod key_value_broadcast(*, n_query_heads, n_key_value_heads, qkv_split)

Construct a head configuration where query has a larger number of heads than key and value. Key/value heads are broadcast to correspond to the number of query heads.

Parameters:
  • n_query_heads (int) – Number of query heads. Must be a multiple of n_key_value_heads.

  • n_key_value_heads (int) – Number of key and value heads.

  • qkv_split (QkvSplit) – How query, key, and value should be split when using MERGED_SPLIT_AFTER. Not used for other query, key, value modes.

Return type:

AttentionHeads

classmethod multi_query(n_query_heads, qkv_split)

Construct a multi-query attention configuration: key has one head, value has one head, query has n_query_heads heads (Shazeer et al., 2019). The key head and the value head are broadcast to the shape of the query.

Parameters:
  • n_query_heads (int) – Number of query heads.

  • qkv_split (QkvSplit) – How query, key, and value should be split when using MERGED_SPLIT_AFTER. Not used for other query, key, value modes.

Return type:

AttentionHeads

classmethod uniform(n_attention_heads, qkv_split)

Construct a head configuration where query, key, and value have the same number of attention heads.

Parameters:
  • n_attention_heads (int) – Number of attention heads.

  • qkv_split (QkvSplit) – How query, key, and value should be split when using MERGED_SPLIT_AFTER. Not used for other query, key, value modes.

Return type:

AttentionHeads
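The three factory methods differ only in how many key/value heads exist and how they are shared among the query heads. A minimal sketch of the broadcasting, assuming that consecutive query heads share a key/value head; the helper names are hypothetical and strings stand in for head tensors.

```python
from typing import List, TypeVar

T = TypeVar("T")


def kv_head_for_query_head(query_head: int, n_query_heads: int, n_key_value_heads: int) -> int:
    # Each key/value head serves a contiguous group of query heads.
    if n_query_heads % n_key_value_heads != 0:
        raise ValueError("n_query_heads must be a multiple of n_key_value_heads")
    group_size = n_query_heads // n_key_value_heads
    return query_head // group_size


def broadcast_kv(kv_heads: List[T], n_query_heads: int) -> List[T]:
    # Repeat key/value heads so that there is one per query head.
    n_kv = len(kv_heads)
    return [kv_heads[kv_head_for_query_head(h, n_query_heads, n_kv)] for h in range(n_query_heads)]
```

With as many key/value heads as query heads this is the uniform configuration; with a single key/value head it is multi-query attention.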

class curated_transformers.layers.AttentionMask(bool_mask)

Mask for attention calculation. Sequence elements for which the corresponding mask element is set to False are ignored during attention calculation.

Parameters:

bool_mask (Tensor) – The boolean mask.

apply_logit_mask(input)

Use the attention mask to mask attention logits.

Parameters:

input (Tensor) –

Attention logits to apply the mask to.

Shape: (batch_size, heads, query_len, key_len)

Return type:

Tensor

Returns:

Logits with the attention mask applied.

Shape: (batch_size, heads, query_len, key_len)

property device: device

Return the device of the mask.

dim()

Return the number of dimensions in the mask.

Return type:

int

extend_length(count, fill_value)

Extend the attention mask in the sequence length dimension by count elements, each set to fill_value.

Parameters:
  • count (int) – Number of new elements to insert.

  • fill_value (bool) – Value to store in the new elements.

Return type:

AttentionMask

Returns:

Extended mask.

filter_batch_items(mask)

Filter batch sequences from the attention mask.

Sequences for which the mask is True are retained.

Parameters:

mask (Tensor) –

Mask of batch items to retain.

Shape: (batch_size,)

Return type:

AttentionMask

Returns:

Filtered mask.

logit_mask(dtype)

Generate the logit mask for the given dtype.

Elements of the mask that are False are set to the minimum value of the dtype and the rest to zero. During softmax calculation, adding this mask to the logits will result in (near-)zero probabilities for the elements that are False.

Parameters:

dtype (dtype) – Data type of the logit mask.

Return type:

Tensor

Returns:

Logit mask.
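The logit mask can be illustrated with Python floats, where the minimum value of the dtype corresponds to the most negative finite float, `-sys.float_info.max`. The helpers `logit_mask` and `masked_softmax` are hypothetical and operate on a single row of logits rather than a tensor.

```python
import math
import sys
from typing import List


def logit_mask(bool_mask: List[bool]) -> List[float]:
    # False -> most negative finite float, True -> 0.0.
    min_value = -sys.float_info.max
    return [0.0 if keep else min_value for keep in bool_mask]


def masked_softmax(logits: List[float], bool_mask: List[bool]) -> List[float]:
    # Add the logit mask, then apply a numerically stable softmax.
    masked = [x + m for x, m in zip(logits, logit_mask(bool_mask))]
    top = max(masked)
    exps = [math.exp(x - top) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]
```

Using the finite minimum rather than negative infinity avoids NaN results when every element of a row is masked.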

merge_mask(other)

Merge this attention mask with another attention mask.

Parameters:

other (AttentionMask) – Attention mask to merge.

Return type:

AttentionMask

Returns:

Merged mask.

property shape: Size

Return the shape of the mask.

class curated_transformers.layers.KeyValueCache(key, value)

Cache type for layers that cache keys and values.

Parameters:
  • key (Tensor) – Key.

  • value (Tensor) – Value.
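The idea behind the key/value cache is that the keys and values of previously seen tokens are stored, so that each decoding step only computes them for the new tokens and appends them. A sketch for a single head without a batch dimension; `ToyKeyValueCache` and `update_cache` are hypothetical stand-ins, not the library's KeyValueCache.

```python
from dataclasses import dataclass
from typing import List, Optional

Rows = List[List[float]]  # (seq_len, head_width) for a single head


@dataclass
class ToyKeyValueCache:
    key: Rows
    value: Rows


def update_cache(new_key: Rows, new_value: Rows,
                 cache: Optional[ToyKeyValueCache]) -> ToyKeyValueCache:
    # Cached keys/values are reused; only the new tokens' keys/values are appended
    # along the sequence dimension.
    if cache is None:
        return ToyKeyValueCache(key=list(new_key), value=list(new_value))
    return ToyKeyValueCache(key=cache.key + new_key, value=cache.value + new_value)
```

Here a two-token prompt is processed first and a single generated token is appended in the next step.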

class curated_transformers.layers.AttentionScorer(*args, **kwargs)

Base class of attention scoring implementations.


abstract forward(*, query, key, value, attention_mask, use_causal_mask)

Apply attention scores to the given key, query and value.

Sequence elements that are marked with False in the attention mask are ignored by the attention mechanism (if a mask is provided).

Parameters:
  • query (Tensor) –

    Query tensor.

    Shape: (batch_size, heads, seq_len, width)

  • key (Tensor) –

    Key tensor.

    Shape: (batch_size, heads, seq_len, width)

  • value (Tensor) –

    Value tensor.

    Shape: (batch_size, heads, seq_len, width)

  • attention_mask (AttentionMask) – Attention mask. Sequence elements for which the corresponding mask element is set to False are ignored in attention.

  • use_causal_mask (bool) – Mask out succeeding sequence elements when True.

Return type:

Tensor

Returns:

Attention values.

Shape: (batch_size, heads, seq_len, width)

class curated_transformers.layers.AttentionLinearBiases(*, n_attention_heads, is_causal, is_inverted)

Bases: Module

ALiBi: Linear biases for attention (Press et al., 2022).

Construct an ALiBi module.

Parameters:
  • n_attention_heads (int) – Number of attention heads.

  • is_causal (bool) – Use causal attention.

  • is_inverted (bool) – If True, the biases are inverted, i.e., penalties become rewards.

forward(*, attention_scores, inplace=True)

Apply linear biases to (unmasked) attention scores.

Parameters:
  • attention_scores (Tensor) –

    Attention scores.

    Shape: (batch_size, heads, query_len, key_len)

  • inplace (bool) – Update the attention scores in place.

Return type:

Tensor

Returns:

Attention scores with linear biases.

Shape: (batch_size, heads, query_len, key_len)
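The linear biases of Press et al. (2022) use one slope per head, taken from a geometric sequence, and penalize attention scores in proportion to the distance between query and key. A sketch for a power-of-two number of heads; `alibi_slopes` and `alibi_bias` are hypothetical helpers, and the inverted variant (is_inverted) is not modeled.

```python
from typing import List


def alibi_slopes(n_heads: int) -> List[float]:
    # Geometric sequence that starts at 2^(-8/n) and uses 2^(-8/n) as its ratio.
    if n_heads < 1 or (n_heads & (n_heads - 1)) != 0:
        raise ValueError("this sketch only handles a power-of-two number of heads")
    ratio = 2.0 ** (-8.0 / n_heads)
    return [ratio ** (h + 1) for h in range(n_heads)]


def alibi_bias(slope: float, query_pos: int, key_pos: int) -> float:
    # Penalty proportional to the query/key distance, added to the score before softmax.
    return -slope * abs(query_pos - key_pos)
```

For keys at or before the query, the distance |i - j| equals i - j, which matches the causal formulation.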

class curated_transformers.layers.ScaledDotProductAttention(*, dropout_prob, linear_biases)

Bases: AttentionScorer

Scaled dot-product attention (Vaswani et al., 2017).

Construct a scaled dot-product attention module.

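Parameters:
  • dropout_prob (float) – Dropout probability.

  • linear_biases (Optional[AttentionLinearBiases]) – ALiBi linear biases (Press et al., 2022) to apply to the attention scores. Linear biases are not applied when set to None.
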
forward(*, query, key, value, attention_mask, use_causal_mask)

Apply attention scores to the given key, query and value.

Sequence elements that are marked with False in the attention mask are ignored by the attention mechanism (if a mask is provided).

Parameters:
  • query (Tensor) –

    Query tensor.

    Shape: (batch_size, heads, seq_len, width)

  • key (Tensor) –

    Key tensor.

    Shape: (batch_size, heads, seq_len, width)

  • value (Tensor) –

    Value tensor.

    Shape: (batch_size, heads, seq_len, width)

  • attention_mask (AttentionMask) – Attention mask. Sequence elements for which the corresponding mask element is set to False are ignored in attention.

  • use_causal_mask (bool) – Mask out succeeding sequence elements when True.

Return type:

Tensor

Returns:

Attention values.

Shape: (batch_size, heads, seq_len, width)
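For reference, the computation softmax(QK^T / sqrt(width)) V can be sketched for a single head without batching. The function below is a hypothetical illustration: it assumes query and key have the same length when the causal mask is used, and it omits the dropout and optional ALiBi biases of ScaledDotProductAttention.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(xs: List[float]) -> List[float]:
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(
    query: Matrix, key: Matrix, value: Matrix, mask: List[bool], use_causal_mask: bool = False
) -> Matrix:
    """Single head, no batch. query: (q_len, width); key, value: (k_len, width); mask: (k_len,)."""
    width = len(query[0])
    output = []
    for i, q in enumerate(query):
        scores = []
        for j, k in enumerate(key):
            # A key is ignored if its mask is False or, with a causal mask, if it follows the query.
            allowed = mask[j] and (not use_causal_mask or j <= i)
            score = sum(a * b for a, b in zip(q, k)) / math.sqrt(width)
            scores.append(score if allowed else -math.inf)
        weights = softmax(scores)
        # Weighted sum of the value vectors.
        output.append([sum(w * v[d] for w, v in zip(weights, value)) for d in range(len(value[0]))])
    return output
```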

class curated_transformers.layers.SelfAttention(*, attention_heads, attention_scorer, hidden_width, qkv_mode, rotary_embeds=None, use_bias, device=None)

Bases: Module

Transformer self-attention layer (Vaswani et al., 2017).

Construct a self-attention layer with optional rotary position embeddings and attention linear biases.

Parameters:
  • attention_heads (AttentionHeads) – Attention head configuration.

  • attention_scorer (AttentionScorer) – Attention scorer used to calculate the attention values.

  • hidden_width (int) – Hidden width of the layer.

  • qkv_mode (QkvMode) – Handling mode for query, key and value.

  • rotary_embeds (Optional[QueryKeyRotaryEmbeddings]) – Rotary embeddings. Rotary embeddings will not be used when set to None.

  • use_bias (bool) – Use biases for linear layers.

  • device (Optional[device]) – Device on which the module is to be initialized.

forward(input, attention_mask, use_causal_mask=False, cache=None, store_cache=False, positions=None)

Apply self-attention layer to the input.

Parameters:
  • input (Tensor) –

    Input to apply self-attention to.

    Shape: (batch_size, seq_len, width)

  • attention_mask (AttentionMask) – Attention mask. Sequence elements for which the corresponding mask element is set to False are ignored in attention.

  • use_causal_mask (bool) – Mask out succeeding sequence elements when True.

  • cache (Optional[KeyValueCache]) – Key/value cache to avoid recomputing key/value representations for tokens that were previously seen.

  • store_cache (bool) – Whether to cache the key/value representations for future reuse.

  • positions (Optional[Tensor]) –

    Input positions. Positions are needed to look up rotary embeddings. Normally, these positions are calculated automatically. But if the positions deviate for some reason, they can be provided through this argument.

    Shape: (batch_size, seq_len)

Return type:

Tuple[Tensor, Optional[KeyValueCache]]

Returns:

Layer output and the key/value cache.

Shape: (batch_size, seq_len, width)

Embeddings

These modules implement various positional embeddings used by the Transformer.

class curated_transformers.layers.SinusoidalPositionalEmbedding(*, width, max_len, normalize=True, device=None)

Bases: Module

Sinusoidal positional embeddings (Vaswani et al., 2017).

Construct a sinusoidal positional embedding module.

Parameters:
  • width (int) – Width of the embedding.

  • max_len (int) – Maximum sequence length, i.e. the number of positions for which embeddings are available.

  • normalize (bool) – Perform L2 normalization of the embedding.

  • device (Optional[device]) – Device on which the module is to be initialized.

forward(input)

Returns the positional embedding for the input.

Parameters:

input (Tensor) –

Input tensor.

Shape: (batch_size, seq_len)

Return type:

Tensor

Returns:

Positional embedding for the input.

Shape: (seq_len, width)
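The sinusoidal table of Vaswani et al. (2017) uses PE[pos][2i] = sin(pos / 10000^(2i/width)) and PE[pos][2i+1] = cos(pos / 10000^(2i/width)). A plain-Python sketch, assuming interleaved sine/cosine columns (some implementations concatenate the sine and cosine halves instead); `sinusoidal_embeddings` is a hypothetical helper.

```python
import math
from typing import List


def sinusoidal_embeddings(max_len: int, width: int, normalize: bool = True) -> List[List[float]]:
    if width % 2 != 0:
        raise ValueError("width must be even in this sketch")
    table = []
    for pos in range(max_len):
        row = []
        for i in range(width // 2):
            angle = pos / (10000 ** (2 * i / width))
            row.extend([math.sin(angle), math.cos(angle)])
        if normalize:
            # L2-normalize each position's embedding.
            norm = math.sqrt(sum(x * x for x in row))
            row = [x / norm for x in row]
        table.append(row)
    return table
```

The output shape (seq_len, width) is consistent with taking the first seq_len rows of such a table.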

class curated_transformers.layers.RotaryEmbeddings(width, *, seq_len=512, base=10000, device=None)

Bases: Module

Rotary embeddings (Su et al., 2021).

Construct a rotary embedding module. The rotary embedding will be precomputed for up to seq_len positions. The embedding will be recomputed when a longer sequence is found in the input.

Parameters:
  • width (int) – Rotary embedding width. Must be even.

  • seq_len (int) – Number of positions to initially precompute.

  • base (int) – The base used for \(\theta_i\). Determines the cycle length of the embeddings.

  • device (Optional[device]) – Device on which the module is to be initialized.

forward(input, *, positions=None)

Apply rotary embeddings to the input.

Parameters:
  • input (Tensor) –

    Input to apply the rotary embeddings to.

    Shape: (batch_size, n_heads, seq_len, width_per_head)

  • positions (Optional[Tensor]) –

    Positions of the inputs. If no positions are provided, they are assumed to be [0, seq_len).

    Shape: (batch_size, seq_len)

Returns:

Input with the rotary embeddings applied.

Shape: (batch_size, n_heads, seq_len, width_per_head)
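Rotary embeddings rotate pairs of dimensions by an angle proportional to the position, with \(\theta_i = \text{base}^{-2i/\text{width}}\). A sketch for one head vector, assuming consecutive dimensions are paired (other implementations pair the first and second halves); `rotary` and `dot` are hypothetical helpers. It shows the property that motivates the method: the dot product of a rotated query and key depends only on their relative position.

```python
import math
from typing import List


def rotary(vec: List[float], position: int, base: float = 10000.0) -> List[float]:
    # Rotate each pair (x[2i], x[2i+1]) by position * theta_i.
    width = len(vec)
    if width % 2 != 0:
        raise ValueError("rotary width must be even")
    out = []
    for i in range(width // 2):
        angle = position * base ** (-2 * i / width)
        cos, sin = math.cos(angle), math.sin(angle)
        x, y = vec[2 * i], vec[2 * i + 1]
        out.extend([x * cos - y * sin, x * sin + y * cos])
    return out


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))
```

Positions (5, 3) and (12, 10) have the same offset, so their rotated query/key dot products agree.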

class curated_transformers.layers.QueryKeyRotaryEmbeddings(*, base=10000, fraction, head_width, device=None)

Bases: Module

Rotary embeddings (Su et al., 2021) applied to query and key representations.

Construct a rotary embedding module.

Parameters:
  • base (int) – Base that determines the rotary embedding period.

  • fraction (float) – Fraction of hidden width to apply rotary embeddings to. Must be in [0,1].

  • head_width (int) – Width of the query and key heads.

  • device (Optional[device]) – Device on which the module is to be initialized.

forward(*, query, key, cache=None, positions=None)

Apply rotary embeddings to the query and key.

Parameters:
  • query (Tensor) –

    Query representations.

    Shape: (batch_size, head, seq_len, width_per_head)

  • key (Tensor) –

    Key representations.

    Shape: (batch_size, head, seq_len, width_per_head)

  • cache (Optional[KeyValueCache]) – Key/value cache to avoid recomputing key/value representations for tokens that were previously seen.

  • positions (Optional[Tensor]) –

    Input positions. Positions are needed to look up rotary embeddings. Normally, these positions are calculated automatically. But if the positions deviate for some reason, they can be provided through this argument.

    Shape: (batch_size, seq_len)

Return type:

Tuple[Tensor, Tensor]

Returns:

Query and key with the rotary embeddings applied.

Shape: (batch_size, head, seq_len, width_per_head)

Feed-forward Layers

class curated_transformers.layers.PointwiseFeedForward(*, activation, hidden_width, intermediate_width, use_bias, use_gate, device=None)

Bases: Module

Point-wise feed-forward layer (Vaswani et al., 2017).

This layer is applied pointwise, meaning that the same transformation is applied to each sequence element. This transformation is:

\[g(xW_1 + b_1)W_2 + b_2\]

\(W_1\) and \(b_1\) transform the input to an intermediate width, \(g\) is a non-linear activation function and \(W_2\) and \(b_2\) transform the output of the activation back to the input width.

Gated Linear Units (Dauphin et al., 2016; Shazeer, 2020) are also supported. Gating applies the following transformation:

\[(g(xW_g + b_g) * (xW_1 + b_1))W_2 + b_2\]

\(W_g\) and \(b_g\) are the weight matrix and bias of the gate's affine transformation, and \(*\) denotes elementwise multiplication.

Construct a pointwise feed-forward layer module.

Parameters:
  • activation (Module) – Activation used by the pointwise feed-forward layers. The hidden input shape must be the same as the output shape (as is typical for elementwise activations).

  • hidden_width (int) – The input and output width of the layer.

  • intermediate_width (int) – The width of the projection to which the non-linearity is applied.

  • use_bias (bool) – Use biases for linear layers.

  • use_gate (bool) – Use Gated Linear Units.

  • device (Optional[device]) – Device on which the module is to be initialized.

forward(input)

Apply the point-wise feed-forward layer to the input.

Parameters:

input (Tensor) –

Input.

Shape: (batch_size, seq_len, width)

Return type:

Tensor

Returns:

Layer output.

Shape: (batch_size, seq_len, width)
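Both transformations above can be sketched for a single sequence element with lists standing in for tensors; applying the layer pointwise amounts to `[pointwise_ffn(x, ...) for x in sequence]`. The names `affine` and `pointwise_ffn` are hypothetical, and this is an illustration of the formulas rather than the library's implementation.

```python
from typing import Callable, List, Optional

Vector = List[float]
Matrix = List[List[float]]


def affine(x: Vector, w: Matrix, b: Optional[Vector] = None) -> Vector:
    # x W (+ b), where w has shape (in_width, out_width).
    out = [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]
    return out if b is None else [o + bias for o, bias in zip(out, b)]


def pointwise_ffn(
    x: Vector,
    g: Callable[[float], float],
    w1: Matrix,
    w2: Matrix,
    b1: Optional[Vector] = None,
    b2: Optional[Vector] = None,
    wg: Optional[Matrix] = None,
    bg: Optional[Vector] = None,
) -> Vector:
    if wg is None:
        # Plain feed-forward: g(x W_1 + b_1) W_2 + b_2
        hidden = [g(h) for h in affine(x, w1, b1)]
    else:
        # Gated Linear Unit: (g(x W_g + b_g) * (x W_1 + b_1)) W_2 + b_2, * elementwise
        gate = [g(h) for h in affine(x, wg, bg)]
        hidden = [a * c for a, c in zip(gate, affine(x, w1, b1))]
    return affine(hidden, w2, b2)
```

Omitting the biases corresponds to use_bias=False, and passing a gate matrix corresponds to use_gate=True.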

Activations

class curated_transformers.layers.Activation(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: Enum

Activation functions.

GELU = 'gelu'

Gaussian Error Linear Unit (Hendrycks et al., 2016).

GELUFast = 'gelu_fast'

Gaussian Error Linear Unit (Hendrycks et al., 2016) approximation used by GPT-NeoX (Black et al., 2022).

GELUNew = 'gelu_new'

Gaussian Error Linear Unit (Hendrycks et al., 2016) approximation.

ReLU = 'relu'

Rectified Linear Unit (Fukushima, 1969).

SiLU = 'silu'

Sigmoid Linear Unit (Hendrycks et al., 2016).

property module: Type[Module]

Get the PyTorch module for the activation function.

class curated_transformers.layers.GELUFast(*args, **kwargs)

Bases: Module

GELU (Hendrycks et al., 2016) approximation used by GPT-NeoX (Black et al., 2022).


forward(input)

Apply the GELU activation on the input.

Parameters:

input (Tensor) –

Input tensor.

Shape: (batch_size, seq_len, width)

Return type:

Tensor

class curated_transformers.layers.GELUNew(*args, **kwargs)

Bases: Module

GELU (Hendrycks et al., 2016) approximation, called gelu_new in many transformer models.


forward(input)

Apply the GELU activation on the input.

Parameters:

input (Tensor) –

Input tensor.

Shape: (batch_size, seq_len, width)

Return type:

Tensor
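The activations listed above can be sketched as scalar functions. The exact formulas behind GELUFast and GELUNew are not given in this reference; the tanh form below is the widely used GELU approximation that `gelu_new` refers to in many codebases, and which variant each enum member computes is an assumption here.

```python
import math


def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # tanh approximation of GELU (Hendrycks et al., 2016).
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def silu(x: float) -> float:
    # Sigmoid Linear Unit: x * sigmoid(x).
    return x * sigmoid(x)


def relu(x: float) -> float:
    return max(0.0, x)
```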

Normalization

class curated_transformers.layers.RMSNorm(width, *, eps, device=None)

Bases: Module

Root Mean Square (RMS) normalization (Zhang et al., 2019).

Construct an RMS normalization module.

Parameters:
  • width (int) – The (hidden) width of the representations that RMS normalization will be applied to.

  • eps (float) – Epsilon to avoid division by zero.

  • device (Optional[device]) – Device on which the module is to be initialized.

forward(input)

Apply RMS normalization to a tensor.

Parameters:

input (Tensor) – The tensor to apply normalization to.

Return type:

Tensor

Returns:

Normalized tensor.
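RMS normalization divides a representation by its root mean square, \(\sqrt{\frac{1}{n}\sum_i x_i^2}\), without subtracting the mean as layer normalization does. A sketch for one vector; `rms_norm` is a hypothetical helper, placing eps inside the square root is an assumption, and the optional `weight` stands in for a learned per-dimension scale of the given width.

```python
import math
from typing import List, Optional


def rms_norm(x: List[float], eps: float, weight: Optional[List[float]] = None) -> List[float]:
    # Assumption: eps is added to the mean of squares before the square root.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    normed = [v / rms for v in x]
    if weight is not None:
        normed = [n * w for n, w in zip(normed, weight)]
    return normed
```

A constant vector such as [1.0, 1.0] is left unchanged, whereas layer normalization would map it to zeros.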

Model Outputs

These dataclasses encapsulate the outputs produced by the different modules.

class curated_transformers.models.ModelOutput(all_outputs)

Base class for model outputs.

Parameters:

all_outputs (List[Tensor]) – The first element is the output of the embedding layer. The remaining elements are the states of each hidden layer, in order.

property all_hidden_layer_states: List[Tensor]

Return the hidden representations of all the layers.

Returns:

Hidden representations of all the layers.

Shape: (batch_size, seq_len, width)

property embedding_layer: Tensor

Return the output of the embedding layer.

Returns:

Embedding layer output.

Shape: (batch_size, seq_len, width)

hidden_layer_states(idx)

Return the hidden representations of a given layer.

Parameters:

idx (int) – Layer index. Must be in [0, n_hidden_layers).

Return type:

Tensor

Returns:

Hidden representation of the layer.

Shape: (batch_size, seq_len, width)

property last_hidden_layer_state: Tensor

Return the hidden representation of the last layer.

Returns:

Hidden representation of the last layer.

Shape: (batch_size, seq_len, width)
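Based on the description of all_outputs above, the accessors map onto list indices as follows. `ToyModelOutput` is a hypothetical mirror for illustration, with strings standing in for tensors; it is not the library's class.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyModelOutput:
    # Index 0: embedding layer output; index 1 onward: hidden layer states.
    all_outputs: List[str]

    @property
    def embedding_layer(self) -> str:
        return self.all_outputs[0]

    def hidden_layer_states(self, idx: int) -> str:
        n_hidden_layers = len(self.all_outputs) - 1
        if not 0 <= idx < n_hidden_layers:
            raise ValueError(f"idx must be in [0, {n_hidden_layers})")
        # Hidden layer idx is stored after the embedding output.
        return self.all_outputs[idx + 1]

    @property
    def last_hidden_layer_state(self) -> str:
        return self.all_outputs[-1]
```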

class curated_transformers.models.ModelOutputWithCache(all_outputs, cache)

Bases: Generic[CacheT], ModelOutput

Output of decoder modules.

Parameters:

cache (Optional[List[TypeVar(CacheT, bound=CacheProtocol)]]) – Model cache. The cache can be used with future calls to a model to reuse computations for efficiency.

class curated_transformers.models.CausalLMOutputWithCache(all_outputs, cache, logits)

Bases: Generic[CacheT], ModelOutputWithCache[CacheT]

Output of causal language model modules.

Parameters:

logits (Tensor) – Logits of the distributions of predicted tokens.

Model Configs

These dataclasses encapsulate the configurable parameters of the Transformer model.

class curated_transformers.models.RotaryEmbeddingConfig(rotary_base, rotary_fraction)

Configuration options for rotary embeddings (Su et al., 2021).

Parameters:
  • rotary_base (int) – Base that determines the rotary embedding period.

  • rotary_fraction (float) – Fraction of hidden width to apply rotary embeddings to. Must be in [0,1].

class curated_transformers.models.TransformerAttentionLayerConfig(dropout_prob, hidden_width, n_query_heads, n_key_value_heads, rotary_embeddings, use_alibi, use_bias, use_parallel_attention)

Configuration options for self-attention.

Parameters:
  • dropout_prob (float) – Dropout probability to apply after attention.

  • hidden_width (int) – Hidden width of the transformer.

  • n_query_heads (int) – Number of query heads.

  • n_key_value_heads (int) – Number of key and value heads.

  • rotary_embeddings (Optional[RotaryEmbeddingConfig]) – Rotary embedding configuration.

  • use_alibi (bool) – Use ALiBi linear biases.

  • use_bias (bool) – Use bias in linear layers.

  • use_parallel_attention (bool) – Use parallel attention.

class curated_transformers.models.TransformerEmbeddingLayerConfig(dropout_prob, embedding_width, layer_norm_eps, n_positions, n_pieces, n_types)

Configuration options for embeddings.

Parameters:
  • dropout_prob (float) – Dropout probability for the embedding layer.

  • embedding_width (int) – Width of the embedding representations.

  • layer_norm_eps (float) – Epsilon for layer normalization.

  • n_positions (Optional[int]) – Maximum length of position embeddings.

  • n_pieces (int) – Vocabulary size (number of embeddings).

  • n_types (Optional[int]) – Token type vocabulary size (number of token type embeddings).

class curated_transformers.models.TransformerFeedForwardLayerConfig(activation, hidden_width, intermediate_width, use_bias, use_gate)

Configuration options for transformer feed-forward layers.

Parameters:
  • activation (Activation) – Activation in the feed-forward layer.

  • hidden_width (int) – Hidden width of the transformer.

  • intermediate_width (int) – Intermediate width in the feed-forward layer.

  • use_bias (bool) – Use bias in linear layers.

  • use_gate (bool) – Use Gated Linear Units.

class curated_transformers.models.TransformerLayerConfig(attention, dropout_prob, feedforward, layer_norm_eps, n_hidden_layers)

Configuration options for transformer layers.

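Parameters:
  • attention (TransformerAttentionLayerConfig) – Attention layer configuration.

  • dropout_prob (float) – Dropout probability for the transformer layers.

  • feedforward (TransformerFeedForwardLayerConfig) – Feed-forward layer configuration.

  • layer_norm_eps (float) – Epsilon for layer normalization.

  • n_hidden_layers (int) – Number of hidden layers.
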
class curated_transformers.models.TransformerConfig(embedding, layer, dtype)

Configuration options for a transformer model.

Parameters: