Generation

Models

These classes provide the interface for performing text generation using causal LMs.

class curated_transformers.generation.Generator(model)

Bases: Generic[CacheT]

Generator base class for causal language models.

Construct a generator.

Parameters:

model (CausalLMModule[Any, TypeVar(CacheT, bound= CacheProtocol)]) – The causal language model to generate with.

__call__(*, attention_mask, ids, config)

Alias for generate().

Return type:

Iterator[Tuple[Tensor, Tensor]]

generate(*, attention_mask, ids, config)

Generate text, starting from the given piece identifiers.

The generator returns an iterator over tuples. Each tuple contains:
  1. A tensor with sequence identifiers.

  2. A tensor with the next piece identifiers.

The sequence identifiers are numbered from 0 to batch_size - 1 and are necessary because some sequences may finish generation earlier than others. The sequence identifiers allow the caller to map the generated pieces back to the original input sequences.

Parameters:
  • ids (Tensor) –

    Batch of piece identifiers to start generating from.

    Shape: (batch_size, seq_len)

  • attention_mask (AttentionMask) – Attention mask that masks out pieces that should not be attended to.

  • config (GeneratorConfig) – Generator configuration.

Return type:

Iterator[Tuple[Tensor, Tensor]]

Returns:

An iterator over tuples. Each tuple contains a tensor with the sequence identifiers and a tensor with the next piece identifiers.

Shape: (batch_unfinished,)
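The bookkeeping described above can be illustrated with a minimal, library-free sketch. It assumes a hypothetical trace in which each step is a pair of plain lists rather than tensors, and the helper name collect_pieces is made up for illustration; it is not part of the library.

```python
from typing import Dict, Iterable, List, Tuple


def collect_pieces(
    steps: Iterable[Tuple[List[int], List[int]]], batch_size: int
) -> Dict[int, List[int]]:
    """Group generated piece ids by the input sequence they belong to.

    Each step is a (sequence_ids, piece_ids) pair, mirroring the tuples
    described above but using plain lists instead of tensors.
    """
    outputs: Dict[int, List[int]] = {seq_id: [] for seq_id in range(batch_size)}
    for seq_ids, piece_ids in steps:
        for seq_id, piece_id in zip(seq_ids, piece_ids):
            outputs[seq_id].append(piece_id)
    return outputs


# Hypothetical trace for a batch of three prompts. Sequence 1 finishes
# after the first step and sequence 0 after the second, so later steps
# carry fewer entries.
steps = [
    ([0, 1, 2], [11, 21, 31]),
    ([0, 2], [12, 32]),
    ([2], [33]),
]
outputs = collect_pieces(steps, batch_size=3)
```

Because every step carries its own sequence identifiers, the caller never has to track which sequences have already finished.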

class curated_transformers.generation.StringGenerator(tokenizer, generator)

Bases: Generic[CacheT]

Generator wrapper that takes textual input and outputs generated strings. It wraps a generator and uses a tokenizer to split the input into pieces and decode the output pieces.

Construct a string generator.

Parameters:
  • tokenizer (TokenizerBase) – Tokenizer for piece processing.

  • generator (Generator[TypeVar(CacheT, bound= CacheProtocol)]) – Generator to wrap.

__call__(prompts, config)

Alias for generate().

Return type:

List[str]

generate(prompts, config)

Generate text using the given prompts.

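Parameters:
  • prompts (List[str]) – Prompts to generate from.

  • config (GeneratorConfig) – Generator configuration.
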
Return type:

List[str]

Returns:

Strings generated for the prompts.

class curated_transformers.generation.GeneratorWrapper

Bases: ABC

Model-specific wrapper for curated_transformers.generation.Generator.

__call__(prompts, config)

Alias for generate().

Return type:

List[str]

abstract generate(prompts, config)

Generate text using the given prompts.

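Parameters:
  • prompts (List[str]) – Prompts to generate from.

  • config (GeneratorConfig) – Generator configuration.
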
Return type:

List[str]

Returns:

Strings generated for the prompts.

class curated_transformers.generation.DefaultGenerator(tokenizer, causal_lm, default_config=None)

Bases: Generic[CacheT], GeneratorWrapper, FromHF

Generator wrapper for models that do not need specific prompting.

Construct a generic generator.

Parameters:
  • tokenizer (TokenizerBase) – A tokenizer.

  • causal_lm (CausalLMModule[Any, TypeVar(CacheT, bound= CacheProtocol)]) – A causal language model.

  • default_config (Optional[GeneratorConfig]) – Configuration to use as a default when the configuration provided to the generate method is underspecified. For instance, if the end-of-sequence identifier is None in the generation configuration, it will be taken from the default configuration.
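The fallback behaviour described for default_config can be sketched as follows. This is only an illustration under stated assumptions: SketchConfig and fill_from_default are hypothetical names, and the library may merge configurations differently.

```python
from dataclasses import dataclass, fields, replace
from typing import Optional


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical stand-in for a generator configuration."""

    eos_id: Optional[int] = None
    max_generated_pieces: Optional[int] = None


def fill_from_default(config: SketchConfig, default: SketchConfig) -> SketchConfig:
    """Return a copy of config in which None fields are taken from default."""
    updates = {
        field.name: getattr(default, field.name)
        for field in fields(config)
        if getattr(config, field.name) is None
    }
    return replace(config, **updates)


default = SketchConfig(eos_id=2, max_generated_pieces=128)
# The caller only sets the piece limit; eos_id falls back to the default.
merged = fill_from_default(SketchConfig(max_generated_pieces=16), default)
```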

__call__(prompts, config)

Alias for generate().

Return type:

List[str]

classmethod from_hf_hub(*, name, revision='main', device=None, quantization_config=None)

Construct a generator and load its parameters from Hugging Face Hub.

Parameters:
  • name (str) – Model name.

  • revision (str) – Model revision.

  • device (Optional[device]) – Device on which to initialize the model.

  • quantization_config (Optional[BitsAndBytesConfig]) – Configuration for loading quantized weights.

Return type:

TypeVar(Self, bound= DefaultGenerator)

Returns:

Generator with the parameters loaded.

classmethod from_hf_hub_to_cache(*, name, revision='main')

Download the generator’s model and tokenizer from Hugging Face Hub into the local Hugging Face cache directory. Subsequent loading of the generator will load the model and the tokenizer from disk.

Parameters:
  • name (str) – Model name.

  • revision (str) – Model revision.

generate(prompts, config)

Generate text using the given prompts.

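Parameters:
  • prompts (List[str]) – Prompts to generate from.

  • config (GeneratorConfig) – Generator configuration.
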
Return type:

List[str]

Returns:

Strings generated for the prompts.

preprocess_prompts(prompts)

Prepare a list of prompts for generation.

Parameters:

prompts (List[str]) – The prompts to prepare.

Return type:

List[InputChunks]

Returns:

Prepared prompts.

class curated_transformers.generation.DollyV2Generator(tokenizer, causal_lm)

Bases: DefaultGenerator

Generator for Dolly v2 model variants.

Construct a Dolly v2 generator.

Parameters:
  • tokenizer – A Dolly v2 tokenizer.

  • causal_lm – A Dolly v2 causal language model.

__call__(prompts, config)

Alias for generate().

Return type:

List[str]

classmethod from_hf_hub(*, name, revision='main', device=None, quantization_config=None)

Construct a generator and load its parameters from Hugging Face Hub.

Parameters:
  • name (str) – Model name.

  • revision (str) – Model revision.

  • device (Optional[device]) – Device on which to initialize the model.

  • quantization_config (Optional[BitsAndBytesConfig]) – Configuration for loading quantized weights.

Return type:

TypeVar(Self, bound= DefaultGenerator)

Returns:

Generator with the parameters loaded.

classmethod from_hf_hub_to_cache(*, name, revision='main')

Download the generator’s model and tokenizer from Hugging Face Hub into the local Hugging Face cache directory. Subsequent loading of the generator will load the model and the tokenizer from disk.

Parameters:
  • name (str) – Model name.

  • revision (str) – Model revision.

generate(prompts, config)

Generate text using the given prompts.

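Parameters:
  • prompts (List[str]) – Prompts to generate from.

  • config (GeneratorConfig) – Generator configuration.
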
Return type:

List[str]

Returns:

Strings generated for the prompts.

preprocess_prompts(prompts)

Prepare a list of prompts for generation.

Parameters:

prompts (List[str]) – The prompts to prepare.

Return type:

List[InputChunks]

Returns:

Prepared prompts.

class curated_transformers.generation.FalconGenerator(tokenizer, causal_lm)

Bases: DefaultGenerator, FromHF

Generator for Falcon model variants.

Construct a Falcon generator.

Parameters:
  • tokenizer – A Falcon tokenizer.

  • causal_lm – A Falcon causal language model.

__call__(prompts, config)

Alias for generate().

Return type:

List[str]

classmethod from_hf_hub(*, name, revision='main', device=None, quantization_config=None)

Construct a generator and load its parameters from Hugging Face Hub.

Parameters:
  • name (str) – Model name.

  • revision (str) – Model revision.

  • device (Optional[device]) – Device on which to initialize the model.

  • quantization_config (Optional[BitsAndBytesConfig]) – Configuration for loading quantized weights.

Return type:

TypeVar(Self, bound= DefaultGenerator)

Returns:

Generator with the parameters loaded.

classmethod from_hf_hub_to_cache(*, name, revision='main')

Download the generator’s model and tokenizer from Hugging Face Hub into the local Hugging Face cache directory. Subsequent loading of the generator will load the model and the tokenizer from disk.

Parameters:
  • name (str) – Model name.

  • revision (str) – Model revision.

generate(prompts, config)

Generate text using the given prompts.

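Parameters:
  • prompts (List[str]) – Prompts to generate from.

  • config (GeneratorConfig) – Generator configuration.
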
Return type:

List[str]

Returns:

Strings generated for the prompts.

preprocess_prompts(prompts)

Prepare a list of prompts for generation.

Parameters:

prompts (List[str]) – The prompts to prepare.

Return type:

List[InputChunks]

Returns:

Prepared prompts.

class curated_transformers.generation.LlamaGenerator(tokenizer, causal_lm)

Bases: DefaultGenerator, FromHF

Generator for Llama and Llama 2 model variants.

Construct a Llama generator.

Parameters:
  • tokenizer – A Llama tokenizer.

  • causal_lm – A Llama causal language model.

__call__(prompts, config)

Alias for generate().

Return type:

List[str]

classmethod from_hf_hub(*, name, revision='main', device=None, quantization_config=None)

Construct a generator and load its parameters from Hugging Face Hub.

Parameters:
  • name (str) – Model name.

  • revision (str) – Model revision.

  • device (Optional[device]) – Device on which to initialize the model.

  • quantization_config (Optional[BitsAndBytesConfig]) – Configuration for loading quantized weights.

Return type:

TypeVar(Self, bound= DefaultGenerator)

Returns:

Generator with the parameters loaded.

classmethod from_hf_hub_to_cache(*, name, revision='main')

Download the generator’s model and tokenizer from Hugging Face Hub into the local Hugging Face cache directory. Subsequent loading of the generator will load the model and the tokenizer from disk.

Parameters:
  • name (str) – Model name.

  • revision (str) – Model revision.

generate(prompts, config)

Generate text using the given prompts.

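Parameters:
  • prompts (List[str]) – Prompts to generate from.

  • config (GeneratorConfig) – Generator configuration.
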
Return type:

List[str]

Returns:

Strings generated for the prompts.

preprocess_prompts(prompts)

Prepare a list of prompts for generation.

Parameters:

prompts (List[str]) – The prompts to prepare.

Return type:

List[InputChunks]

Returns:

Prepared prompts.

class curated_transformers.generation.MPTGenerator(tokenizer, causal_lm)

Bases: DefaultGenerator, FromHF

Generator for MPT model variants.

Construct an MPT generator.

Parameters:
  • tokenizer (Tokenizer) – An MPT tokenizer.

  • causal_lm (MPTCausalLM) – An MPT causal language model.

__call__(prompts, config)

Alias for generate().

Return type:

List[str]

classmethod from_hf_hub(*, name, revision='main', device=None, quantization_config=None)

Construct a generator and load its parameters from Hugging Face Hub.

Parameters:
  • name (str) – Model name.

  • revision (str) – Model revision.

  • device (Optional[device]) – Device on which to initialize the model.

  • quantization_config (Optional[BitsAndBytesConfig]) – Configuration for loading quantized weights.

Return type:

TypeVar(Self, bound= DefaultGenerator)

Returns:

Generator with the parameters loaded.

classmethod from_hf_hub_to_cache(*, name, revision='main')

Download the generator’s model and tokenizer from Hugging Face Hub into the local Hugging Face cache directory. Subsequent loading of the generator will load the model and the tokenizer from disk.

Parameters:
  • name (str) – Model name.

  • revision (str) – Model revision.

generate(prompts, config)

Generate text using the given prompts.

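Parameters:
  • prompts (List[str]) – Prompts to generate from.

  • config (GeneratorConfig) – Generator configuration.
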
Return type:

List[str]

Returns:

Strings generated for the prompts.

preprocess_prompts(prompts)

Prepare a list of prompts for generation.

Parameters:

prompts (List[str]) – The prompts to prepare.

Return type:

List[InputChunks]

Returns:

Prepared prompts.

Downloading

Each generator type provides a from_hf_hub class method that loads a model from Hugging Face Hub. To load a generator without committing to a specific generator type, use the AutoGenerator class. It also provides a from_hf_hub method and tries to infer the correct generator type automatically.

class curated_transformers.generation.AutoGenerator

Causal LM generator loaded from the Hugging Face Model Hub.

Attention

This class can currently only be used with the following models:

  • Models based on Dolly v2 (contain dolly-v2 in the name).

  • Models based on Falcon (contain falcon in the name).

  • Models based on Llama (contain llama in the name).

  • Models based on MPT (contain mpt in the name).
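The name-based selection listed above could look roughly like the sketch below. The helper guess_generator_class and the model names in the usage lines are hypothetical, and case-insensitive substring matching is an assumption; the library may match names differently.

```python
from typing import Optional

# Substrings taken from the list above, mapped to the generator class
# documented for each model family.
NAME_PATTERNS = {
    "dolly-v2": "DollyV2Generator",
    "falcon": "FalconGenerator",
    "llama": "LlamaGenerator",
    "mpt": "MPTGenerator",
}


def guess_generator_class(name: str) -> Optional[str]:
    """Return the generator class name for a model name, or None.

    Assumption: matching is a case-insensitive substring test.
    """
    lowered = name.lower()
    for pattern, class_name in NAME_PATTERNS.items():
        if pattern in lowered:
            return class_name
    return None


# Hypothetical model names, used only to exercise the lookup.
falcon_class = guess_generator_class("some-org/falcon-7b")
unknown_class = guess_generator_class("some-org/bert-base")
```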

classmethod from_fsspec(*, fs, model_path, fsspec_args=None, device=None, quantization_config=None)

Construct a module and load its parameters from a fsspec filesystem.

Parameters:
  • fs (AbstractFileSystem) – The filesystem to load the model from.

  • model_path (str) – The path of the model on the filesystem.

  • fsspec_args (Optional[FsspecArgs]) – Implementation-specific keyword arguments to pass to fsspec filesystem operations.

  • device (Optional[device]) – Device on which the model is initialized.

  • quantization_config (Optional[BitsAndBytesConfig]) – Configuration for loading quantized weights.

Return type:

TypeVar(ModelT)

Returns:

Module with the parameters loaded.

classmethod from_hf_hub(*, name, revision='main', device=None, quantization_config=None)

Construct and load a model or a generator from Hugging Face Hub.

Parameters:
  • name (str) – Model name.

  • revision (str) – Model revision.

  • device (Optional[device]) – Device on which to initialize the model.

  • quantization_config (Optional[BitsAndBytesConfig]) – Configuration for loading quantized weights.

Return type:

GeneratorWrapper

Returns:

Loaded model or generator.
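Putting the pieces together, a call sequence based only on the signatures documented in this section might look like the sketch below. The function name generate_with_auto_generator and the limit of 32 pieces are illustrative choices. The library import is placed inside the function so that the definition itself has no dependencies; calling it requires curated-transformers to be installed and downloads the model.

```python
from typing import List


def generate_with_auto_generator(name: str, prompts: List[str]) -> List[str]:
    """Load a generator by model name and generate greedily for each prompt."""
    # Imported lazily: defining this function does not need the library,
    # but calling it does.
    from curated_transformers.generation import (
        AutoGenerator,
        GreedyGeneratorConfig,
    )

    # from_hf_hub takes keyword-only arguments, as documented above.
    generator = AutoGenerator.from_hf_hub(name=name)
    config = GreedyGeneratorConfig(max_generated_pieces=32)
    # __call__ is documented as an alias for generate().
    return generator(prompts, config)
```

The name passed in has to belong to one of the supported model families, since AutoGenerator infers the generator type from it.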

classmethod from_hf_hub_to_cache(*, name, revision='main')

Download the model’s weights from Hugging Face Hub into the local Hugging Face cache directory. Subsequent loading of the model will read the weights from disk. If the weights are already cached, this is a no-op.

Parameters:
  • name (str) – Model name.

  • revision (str) – Model revision.

abstract classmethod from_repo(*, repo, device=None, quantization_config=None)

Construct and load a model or a generator from a repository.

Parameters:
  • repo – The repository to load from.

  • device (Optional[device]) – Device on which to initialize the model.

  • quantization_config (Optional[BitsAndBytesConfig]) – Configuration for loading quantized weights.

Return type:

TypeVar(ModelT)

Returns:

Loaded model or generator.

Configuration

These classes represent the different parameters used by generators.

class curated_transformers.generation.GeneratorConfig(masked_pieces=None, eos_id=None, max_generated_pieces=None)

Configuration of the generator.

Parameters:
  • masked_pieces (Optional[Set[int]]) – Vocabulary pieces that should be masked out.

  • eos_id (Optional[int]) – End-of-sequence identifier that should end the generation of a sequence when predicted. When this value is set to None, it is the responsibility of the generator to set it.

  • max_generated_pieces (Optional[int]) – The maximum number of generation steps. This condition is a no-op for values less than 1. When this value is set to None, it is the responsibility of the generator to set it.

abstract logits_transform()

Get the logits transform for the configuration.

Return type:

LogitsTransform

Returns:

Logits transform. Usually multiple composed transforms.

stop_condition()

Get the stop condition for the configuration.

Return type:

StopCondition

Returns:

Stop condition. Usually multiple composed conditions.

class curated_transformers.generation.GreedyGeneratorConfig(masked_pieces=None, eos_id=None, max_generated_pieces=None)

Bases: GeneratorConfig

Configuration for greedy generation.

Greedy generation always selects the highest-probability piece, leading to deterministic generation.
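As a minimal illustration of greedy selection (a sketch on plain Python lists, not the library's tensor-based code), picking the largest logit is enough, because softmax preserves the ordering of the logits. The helper name greedy_pick is hypothetical.

```python
from typing import List


def greedy_pick(logits: List[float]) -> int:
    """Return the index of the largest logit (ties go to the lowest index)."""
    return max(range(len(logits)), key=lambda i: logits[i])


logits = [0.1, 2.5, -1.0, 2.4]
choice = greedy_pick(logits)
```

Calling the function again with the same logits always returns the same index, which is why greedy generation is deterministic.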

logits_transform()

Get the logits transform for the configuration.

Return type:

LogitsTransform

Returns:

Logits transform. Usually multiple composed transforms.

class curated_transformers.generation.SampleGeneratorConfig(masked_pieces=None, eos_id=None, max_generated_pieces=None, temperature=1.0, top_k=0, top_p=1.0)

Bases: GeneratorConfig

Configuration for generation with sampling.

Sampling-based generation samples pieces from probability distributions. Generation is non-deterministic as a result, but provides more varied output.

Parameters:
  • temperature (float) –

    Softmax temperature. For a temperature T:

    • T = 1: the distribution is not changed.

    • T < 1: the entropy of the distribution is decreased.

    • T > 1: the entropy of the distribution is increased.

  • top_k (int) – Sample from the top-k highest-probability pieces. top_k < 1 disables top-k filtering.

  • top_p (float) – Sample from the smallest set of highest-probability pieces whose cumulative probability is >= p. top_p = 1.0 disables top-p filtering.
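The core of sampling-based generation can be sketched without the library as drawing an index in proportion to its softmax probability. The helper sample_piece is a hypothetical name, and the sketch leaves out temperature, top-k and top-p filtering.

```python
import math
import random
from typing import List


def sample_piece(logits: List[float], rng: random.Random) -> int:
    """Draw one piece index with probability proportional to exp(logit)."""
    largest = max(logits)
    # Unnormalized softmax weights; random.choices normalizes them.
    weights = [math.exp(z - largest) for z in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


logits = [2.0, 1.0, 0.0]
rng = random.Random(1234)  # Seeded only to make the sketch repeatable.
draws = [sample_piece(logits, rng) for _ in range(1000)]
```

Without a fixed seed, repeated runs produce different draws, which is the non-determinism mentioned above.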

logits_transform()

Get the logits transform for the configuration.

Return type:

LogitsTransform

Returns:

Logits transform. Usually multiple composed transforms.

class curated_transformers.generation.StopCondition

Base class for generation stop conditions.

abstract update_completed(*, state, completed_exclude, completed_include)

Update completed sequences according to the stop condition.

Parameters:
  • completed_exclude (Tensor) – Output tensor marking which sequences are completed and for which the last generated piece should not be emitted.

  • completed_include (Tensor) – Output tensor marking which sequences are completed and for which the last generated piece should be emitted.

class curated_transformers.generation.CompoundStopCondition(iterable=(), /)

Bases: List[StopCondition], StopCondition

Sequentially apply multiple stop conditions.

update_completed(*, state, completed_exclude, completed_include)

Update completed sequences according to the stop condition.

Parameters:
  • completed_exclude (Tensor) – Output tensor marking which sequences are completed and for which the last generated piece should not be emitted.

  • completed_include (Tensor) – Output tensor marking which sequences are completed and for which the last generated piece should be emitted.

class curated_transformers.generation.EndOfSequenceCondition(eos_id)

Bases: StopCondition

Stop when the end-of-sequence piece is predicted.

Construct the stop condition.

Parameters:

eos_id (int) – End-of-sequence identifier that marks the end of a generated sequence.

update_completed(*, state, completed_exclude, completed_include)

Update completed sequences according to the stop condition.

Parameters:
  • completed_exclude (Tensor) – Output tensor marking which sequences are completed and for which the last generated piece should not be emitted.

  • completed_include (Tensor) – Output tensor marking which sequences are completed and for which the last generated piece should be emitted.

class curated_transformers.generation.MaxGeneratedPiecesCondition(max_generated_pieces)

Bases: StopCondition

Stop after generating a maximum number of pieces.

Construct the stop condition.

Parameters:

max_generated_pieces (int) – The maximum number of generated pieces. This condition is a no-op for values less than 1.

update_completed(*, state, completed_exclude, completed_include)

Update completed sequences according to the stop condition.

Parameters:
  • completed_exclude (Tensor) – Output tensor marking which sequences are completed and for which the last generated piece should not be emitted.

  • completed_include (Tensor) – Output tensor marking which sequences are completed and for which the last generated piece should be emitted.
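How the stop conditions above might interact can be sketched with plain Python lists in place of tensors. All class names below are hypothetical stand-ins, as is SketchState. Which of the two output lists each condition writes to is an assumption: here the end-of-sequence piece is not emitted, while the piece that reaches the length limit is.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SketchState:
    """Hypothetical stand-in for the generator state."""

    last_pieces: List[int]  # Most recent piece of each unfinished sequence.
    num_generated: int  # Number of generation steps taken so far.


class EndOfSequenceSketch:
    def __init__(self, eos_id: int):
        self.eos_id = eos_id

    def update_completed(self, *, state, completed_exclude, completed_include):
        # Assumption: the end-of-sequence piece itself is not emitted.
        for i, piece in enumerate(state.last_pieces):
            if piece == self.eos_id:
                completed_exclude[i] = True


class MaxGeneratedPiecesSketch:
    def __init__(self, max_generated_pieces: int):
        self.max_generated_pieces = max_generated_pieces

    def update_completed(self, *, state, completed_exclude, completed_include):
        if self.max_generated_pieces < 1:
            return  # No-op, as documented for values less than 1.
        if state.num_generated >= self.max_generated_pieces:
            # Assumption: the piece that reaches the limit is still emitted.
            for i in range(len(completed_include)):
                completed_include[i] = True


class CompoundSketch(list):
    """Apply each contained condition in order."""

    def update_completed(self, *, state, completed_exclude, completed_include):
        for condition in self:
            condition.update_completed(
                state=state,
                completed_exclude=completed_exclude,
                completed_include=completed_include,
            )


conditions = CompoundSketch([EndOfSequenceSketch(eos_id=2), MaxGeneratedPiecesSketch(5)])
state = SketchState(last_pieces=[7, 2, 9], num_generated=3)
exclude = [False, False, False]
include = [False, False, False]
conditions.update_completed(
    state=state, completed_exclude=exclude, completed_include=include
)
```

Both lists are updated in place, which mirrors the description of completed_exclude and completed_include as output tensors.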

class curated_transformers.generation.LogitsTransform

A logits transform changes the logits of a softmax distribution in some way. For instance, TemperatureTransform changes the temperature of the softmax distribution.

class curated_transformers.generation.CompoundLogitsTransform(initlist=None)

Bases: UserList, LogitsTransform

Sequentially apply multiple logits transforms.
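Sequential application can be sketched as ordinary function composition. The helper compose and the two toy transforms are hypothetical and operate on plain lists; they only illustrate that each transform receives the output of the previous one.

```python
from typing import Callable, List

Transform = Callable[[List[float]], List[float]]


def compose(transforms: List[Transform]) -> Transform:
    """Build a transform that applies the given transforms in list order."""

    def apply(logits: List[float]) -> List[float]:
        for transform in transforms:
            logits = transform(logits)
        return logits

    return apply


def halve(logits: List[float]) -> List[float]:
    return [z / 2.0 for z in logits]


def add_one(logits: List[float]) -> List[float]:
    return [z + 1.0 for z in logits]


halve_then_add = compose([halve, add_one])
add_then_halve = compose([add_one, halve])
```

The two compositions give different results for the same input, so the order of the transforms in the list matters.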

class curated_transformers.generation.TopKTransform(k)

Bases: LogitsTransform

Set the probability of non-top-k classes to zero. The probability of the classes that are zeroed out is redistributed across the top-k classes.

Construct a top-k logits transform.

Parameters:

k (int) – The value of k in top-k. The transform is a no-op for values less than 1.
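A minimal sketch of top-k filtering on a plain list of logits is given below. The function names are hypothetical and the library operates on tensors. Setting a logit to negative infinity gives the class zero probability, and the softmax then renormalizes over the remaining classes.

```python
import math
from typing import List


def top_k_filter(logits: List[float], k: int) -> List[float]:
    """Set all logits below the k-th largest to -inf. No-op for k < 1."""
    if k < 1 or k >= len(logits):
        return list(logits)
    threshold = sorted(logits, reverse=True)[k - 1]
    # Logits tied with the k-th largest value are kept as well.
    return [z if z >= threshold else float("-inf") for z in logits]


def softmax(logits: List[float]) -> List[float]:
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5, -1.0]
filtered = top_k_filter(logits, k=2)
probs = softmax(filtered)
```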

class curated_transformers.generation.TopPTransform(p)

Bases: LogitsTransform

Keep the smallest possible set of most probable vocabulary items, such that their cumulative probability is >= p. Sampling using the top-p transform is also known as nucleus sampling (Holtzman et al., 2019). The probability of the items that are masked out is redistributed across the top-p items.

Construct a top-p logits transform.

Parameters:

p (float) – The value of p in top-p. The transform is a no-op for p = 1.0.
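The selection rule can be sketched on a plain list of probabilities. This is a simplification: the library's transform works on logits, and the function name top_p_filter is hypothetical.

```python
from typing import List


def top_p_filter(probs: List[float], p: float) -> List[float]:
    """Keep the smallest set of most probable items whose cumulative
    probability is >= p and renormalize over the kept items."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = set()
    cumulative = 0.0
    for i in order:
        kept.add(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    total = sum(probs[i] for i in kept)
    return [probs[i] / total if i in kept else 0.0 for i in range(len(probs))]


probs = [0.5, 0.3, 0.15, 0.05]
# The two most probable items already cover 0.8 >= 0.7, so the rest is dropped.
nucleus = top_p_filter(probs, p=0.7)
```

With p = 1.0 every item is kept, so the distribution is unchanged apart from floating-point rounding.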

class curated_transformers.generation.TemperatureTransform(temperature=1.0)

Bases: LogitsTransform

Apply temperature to the softmax distribution. Given the temperature T and logits z(y|x):

\[p(y|x) = softmax(z(y|x)/T)\]

For a temperature T:

  • T = 1: the distribution is not changed.

  • T < 1: the entropy of the distribution is decreased.

  • T > 1: the entropy of the distribution is increased.

Create a temperature transform with a given temperature.

Parameters:

temperature (float) – The temperature. Must be greater than zero.
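The effect of the temperature on entropy can be checked numerically with a small sketch on plain lists. The function names are hypothetical; the formula is the one given above.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Compute softmax(z / T) for a list of logits z."""
    if temperature <= 0.0:
        raise ValueError("temperature must be greater than zero")
    scaled = [z / temperature for z in logits]
    largest = max(scaled)
    exps = [math.exp(z - largest) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: List[float]) -> float:
    """Shannon entropy in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


logits = [2.0, 1.0, 0.0]
sharp = entropy(softmax_with_temperature(logits, 0.5))
base = entropy(softmax_with_temperature(logits, 1.0))
flat = entropy(softmax_with_temperature(logits, 2.0))
```

For these logits the entropy grows with the temperature, matching the three cases listed above.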

class curated_transformers.generation.VocabMaskTransform(pieces_to_mask)

Bases: LogitsTransform

Set the probability of specific vocabulary pieces to zero.

Construct a mask logits transform.

Parameters:

pieces_to_mask (Iterable[int]) – Identifiers of the vocabulary pieces that need to be masked. An empty iterable results in a no-op.
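Masking can be sketched on a plain list of logits by setting the masked entries to negative infinity, which gives them zero probability after the softmax. The function name mask_pieces is hypothetical.

```python
from typing import Iterable, List


def mask_pieces(logits: List[float], pieces_to_mask: Iterable[int]) -> List[float]:
    """Return a copy of logits with the given piece ids set to -inf."""
    masked = list(logits)
    for piece_id in pieces_to_mask:
        masked[piece_id] = float("-inf")
    return masked


logits = [0.5, 1.5, 2.0, -0.3]
masked = mask_pieces(logits, {1, 3})
```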