Usage

Installation

To use Curated Transformers, first install it using pip:

(.venv) $ pip install curated-transformers

If support for quantization is required, use the quantization variant to automatically install the necessary dependencies:

(.venv) $ pip install curated-transformers[quantization]

CUDA Support

The default Linux build of PyTorch is built with CUDA 11.7 support. You should explicitly install a CUDA build in the following cases:

  • If you want to use Curated Transformers on Windows.

  • If you want to use Curated Transformers on Linux with Ada-generation GPUs.

    The standard PyTorch build supports Ada GPUs, but you can get considerable performance improvements by installing PyTorch with CUDA 11.8 support.

In both cases, you can install PyTorch with:

(.venv) $ pip install torch --index-url https://download.pytorch.org/whl/cu118
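To check which build you currently have, you can use a small helper like the one below. It is a sketch and not part of Curated Transformers. It relies only on `torch.version.cuda`, `torch.cuda.is_available()` and `torch.cuda.get_device_capability()`, and it assumes that Ada-generation GPUs are the ones reporting compute capability 8.9.

```python
import importlib.util


def describe_torch_build() -> str:
    """Summarize which CUDA version the installed PyTorch build targets."""
    if importlib.util.find_spec("torch") is None:
        return "PyTorch is not installed"

    import torch

    # torch.version.cuda is a string such as "11.8", or None for CPU-only builds.
    cuda_version = torch.version.cuda
    if cuda_version is None:
        return f"PyTorch {torch.__version__} (CPU-only build)"

    summary = f"PyTorch {torch.__version__} built with CUDA {cuda_version}"
    if torch.cuda.is_available():
        # Ada-generation GPUs (e.g. the RTX 40 series) report compute capability 8.9.
        if torch.cuda.get_device_capability(0) == (8, 9):
            summary += "; Ada-generation GPU detected"
    return summary


print(describe_torch_build())
```

If the output reports CUDA 11.7 and an Ada GPU, reinstalling with the cu118 index URL above is the case this section describes.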

Text Generation Using Causal LMs

Curated Transformers provides infrastructure to perform open-ended text generation using decoder-only causal language models. The Generator class wraps a CausalLMModule and its corresponding tokenizer. It provides a generic interface to generate outputs from the wrapped module in an auto-regressive fashion. GeneratorConfig specifies the parameters used by the generator, such as stopping conditions and sampling parameters.
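The auto-regressive loop that a generator runs can be sketched in a few lines of plain Python. This illustrates the idea only. `ToyModel`, `greedy_generate` and `counting_model` are hypothetical names. The real Generator works on batched tensors and a CausalLMModule, not on Python lists. The two stopping conditions shown here, a maximum number of new tokens and an end-of-sequence id, are assumed examples of the stopping conditions mentioned above.

```python
from typing import Callable, List, Sequence

# Hypothetical toy "model": maps the ids seen so far to one score (logit) per vocabulary entry.
ToyModel = Callable[[Sequence[int]], List[float]]


def greedy_generate(
    model: ToyModel,
    prompt_ids: List[int],
    eos_id: int,
    max_new_tokens: int,
) -> List[int]:
    """Auto-regressively extend prompt_ids, always picking the highest-scoring token."""
    ids = list(prompt_ids)
    generated: List[int] = []
    for _ in range(max_new_tokens):  # stopping condition 1: length limit
        logits = model(ids)
        next_id = max(range(len(logits)), key=logits.__getitem__)
        if next_id == eos_id:  # stopping condition 2: end-of-sequence token
            break
        generated.append(next_id)
        ids.append(next_id)  # feed the prediction back in as input
    return generated


# Hypothetical 4-token vocabulary: predicts (last id + 1), and EOS (id 0) after id 3.
def counting_model(ids: Sequence[int]) -> List[float]:
    nxt = ids[-1] + 1 if ids[-1] < 3 else 0
    return [1.0 if i == nxt else 0.0 for i in range(4)]


print(greedy_generate(counting_model, [1], eos_id=0, max_new_tokens=10))  # [2, 3]
```

Each generated token is fed back in as input, so the model is called once per new token. Avoiding redundant work across those calls is what the key-value cache described at the end of this page is for.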

The AutoGenerator class can be used to directly load a supported causal language model and generate text with it.

import torch
from curated_transformers.generation import (
   AutoGenerator,
   GreedyGeneratorConfig,
   SampleGeneratorConfig,
)

generator = AutoGenerator.from_hf_hub(
   name="databricks/dolly-v2-3b", device=torch.device("cuda", index=0)
)

sample_config = SampleGeneratorConfig(temperature=1.0, top_k=2)
greedy_config = GreedyGeneratorConfig()

prompts = [
   "To which David Bowie song do these lyrics belong: \"Oh man, look at those cavemen go! It's the freakiest show\"?",
   "What is spaCy?"
]
sample_outputs = generator(prompts, config=sample_config)
greedy_outputs = generator(prompts, config=greedy_config)

print(f"Sampling outputs: {sample_outputs}")
print(f"Greedy outputs: {greedy_outputs}")
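To make the two sampling settings above concrete: temperature rescales the logits before the softmax, and top_k restricts sampling to the k highest-scoring pieces. With top_k=2, only the two best candidates can ever be drawn. The sketch below implements standard temperature plus top-k sampling over a plain list of logits. `sample_top_k` is a hypothetical helper and not Curated Transformers' implementation, which operates on batched tensors. Treating `top_k <= 0` as "no filtering" is a convention of this sketch only.

```python
import math
import random
from typing import List, Optional


def sample_top_k(
    logits: List[float],
    temperature: float = 1.0,
    top_k: int = 0,
    rng: Optional[random.Random] = None,
) -> int:
    """Pick a token id from logits using temperature scaling and top-k filtering."""
    rng = rng or random.Random()
    # Rank candidate ids from highest to lowest logit; top_k <= 0 keeps all of them.
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    candidates = ranked[:top_k] if top_k > 0 else ranked
    # Temperature < 1 sharpens the distribution, > 1 flattens it.
    scaled = [logits[i] / temperature for i in candidates]
    # Softmax weights over the surviving candidates (subtract the max for stability).
    max_scaled = max(scaled)
    weights = [math.exp(s - max_scaled) for s in scaled]
    return rng.choices(candidates, weights=weights, k=1)[0]


logits = [0.1, 2.0, 1.5, -1.0]
rng = random.Random(0)
# With top_k=2, only ids 1 and 2 can be drawn.
print([sample_top_k(logits, temperature=1.0, top_k=2, rng=rng) for _ in range(5)])
# top_k=1 always returns the single best id, i.e. greedy decoding.
print(sample_top_k(logits, top_k=1))  # 1
```

This is also why greedy decoding can be seen as the limiting case of sampling with top_k=1.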

For more information about the different configs and generators supported by Curated Transformers, see Generation.

Loading a Model

Hugging Face Hub

Curated Transformers can load model weights directly from the Hugging Face Model Hub. Every model class provides a from_hf_hub method that loads pre-trained model parameters from the Hub.

import torch
from curated_transformers.models import BERTEncoder
from curated_transformers.models import GPTNeoXDecoder

encoder = BERTEncoder.from_hf_hub(
   name="bert-base-uncased",
   revision="main",
   device=torch.device("cuda", index=0),
)

decoder = GPTNeoXDecoder.from_hf_hub(name="databricks/dolly-v2-3b", revision="main")

The AutoEncoder, AutoDecoder and AutoCausalLM classes can be used to automatically infer the model architecture.

from curated_transformers.models import (
   AutoCausalLM,
   AutoDecoder,
   AutoEncoder,
)

encoder = AutoEncoder.from_hf_hub(
   name="bert-base-uncased",
   revision="main",
)

decoder = AutoDecoder.from_hf_hub(name="databricks/dolly-v2-3b", revision="main")

lm = AutoCausalLM.from_hf_hub(name="databricks/dolly-v2-3b", revision="main")
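This page does not describe how the Auto classes infer the architecture. One plausible mechanism, shown here purely as an assumption, is to read the `model_type` field of the checkpoint's `config.json` and look it up in a registry of supported classes. That field is `"bert"` for bert-base-uncased and `"gpt_neox"` for databricks/dolly-v2-3b. The registries and `infer_class_name` below are hypothetical illustrations, not Curated Transformers' internals.

```python
import json
from typing import Dict

# Hypothetical registries for illustration only; the real Auto classes keep their own.
ENCODER_CLASSES: Dict[str, str] = {"bert": "BERTEncoder"}
DECODER_CLASSES: Dict[str, str] = {"gpt_neox": "GPTNeoXDecoder"}


def infer_class_name(config_json: str, registry: Dict[str, str]) -> str:
    """Return the class name registered for the checkpoint's model_type."""
    model_type = json.loads(config_json)["model_type"]
    try:
        return registry[model_type]
    except KeyError:
        raise ValueError(f"Unsupported model type: {model_type!r}") from None


# Trimmed-down config.json contents of the two checkpoints used above.
print(infer_class_name('{"model_type": "bert"}', ENCODER_CLASSES))  # BERTEncoder
print(infer_class_name('{"model_type": "gpt_neox"}', DECODER_CLASSES))  # GPTNeoXDecoder
```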

fsspec filesystem

Curated Transformers also supports loading models from fsspec filesystems. This makes it possible to load local models, or to load models from cloud services without using any local storage. A model can be loaded from an fsspec filesystem using the from_fsspec method.

import torch
from curated_transformers.models import BERTEncoder
from curated_transformers.repository import FsspecArgs
from fsspec.implementations.local import LocalFileSystem
from huggingface_hub import HfFileSystem

encoder = BERTEncoder.from_fsspec(
   fs=LocalFileSystem(),
   model_path="/srv/models/bert-base-uncased",
   device=torch.device("cuda", index=0),
)

# Pass additional arguments to the specific fsspec implementation.
encoder = BERTEncoder.from_fsspec(
   fs=HfFileSystem(),
   model_path="bert-base-uncased",
   fsspec_args=FsspecArgs(revision="a265f773a47193eed794233aa2a0f0bb6d3eaa63"),
   device=torch.device("cuda", index=0),
)
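Before pointing from_fsspec at a local directory such as /srv/models/bert-base-uncased, it can help to check that the directory looks like a Hugging Face checkpoint. The helper below is a hypothetical sanity check and is not part of Curated Transformers. It assumes the usual layout: a `config.json` plus weights in `pytorch_model*.bin` or `model*.safetensors` files. It does not check tokenizer files.

```python
from pathlib import Path
from typing import List

# Assumed layout of a Hugging Face-style checkpoint directory (an assumption, not a spec).
CONFIG_FILE = "config.json"
WEIGHT_PATTERNS = ("pytorch_model*.bin", "model*.safetensors")


def missing_model_files(model_dir: str) -> List[str]:
    """List what seems to be missing from a local checkpoint directory."""
    root = Path(model_dir)
    if not root.is_dir():
        return [f"{model_dir} is not a directory"]
    missing = []
    if not (root / CONFIG_FILE).is_file():
        missing.append(CONFIG_FILE)
    # Accept both single-file and sharded weights.
    if not any(any(root.glob(pattern)) for pattern in WEIGHT_PATTERNS):
        missing.append("model weights (*.bin or *.safetensors)")
    return missing


print(missing_model_files("/srv/models/bert-base-uncased"))
```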

Quantization

Curated Transformers implements dynamic 8-bit and 4-bit quantization of models by leveraging the bitsandbytes library. When loading models using the from_hf_hub method, an optional BitsAndBytesConfig instance can be passed to the method to opt into dynamic quantization of model parameters. Quantization requires the model to be loaded onto a CUDA GPU, so the device argument must also be passed to the method.

import torch
from curated_transformers.generation import AutoGenerator
from curated_transformers.quantization.bnb import BitsAndBytesConfig, Dtype4Bit

generator_8bit = AutoGenerator.from_hf_hub(
   name="databricks/dolly-v2-3b",
   device=torch.device("cuda", index=0),
   quantization_config=BitsAndBytesConfig.for_8bit(
      outlier_threshold=6.0, finetunable=False
   ),
)

generator_4bit = AutoGenerator.from_hf_hub(
   name="databricks/dolly-v2-3b",
   device=torch.device("cuda", index=0),
   quantization_config=BitsAndBytesConfig.for_4bit(
      quantization_dtype=Dtype4Bit.FP4,
      compute_dtype=torch.bfloat16,
      double_quantization=True,
   ),
)
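The core idea behind 8-bit quantization can be illustrated with absmax quantization. Scale a group of values so that the largest magnitude maps to 127, round to integers, and divide by the same scale to recover approximations. This is a simplified sketch, not bitsandbytes' implementation. bitsandbytes quantizes vector-wise on the GPU, and in 8-bit mode it handles activation outliers above outlier_threshold separately in higher precision. The 4-bit data types and double quantization are not covered here.

```python
from typing import List, Tuple


def quantize_absmax_int8(values: List[float]) -> Tuple[List[int], float]:
    """Map floats to int8 codes in [-127, 127] using a single absmax scale."""
    absmax = max(abs(v) for v in values)
    scale = 127.0 / absmax if absmax > 0 else 1.0
    codes = [max(-127, min(127, round(v * scale))) for v in values]
    return codes, scale


def dequantize(codes: List[int], scale: float) -> List[float]:
    """Recover approximate floats from the int8 codes."""
    return [c / scale for c in codes]


weights = [0.5, -1.0, 0.25, 0.01]
codes, scale = quantize_absmax_int8(weights)
restored = dequantize(codes, scale)
print(codes)  # [64, -127, 32, 1]
# The rounding error per value is at most half a quantization step (0.5 / scale).
print(max(abs(w - r) for w, r in zip(weights, restored)))
```

Each weight then occupies one byte instead of two for float16 or bfloat16, plus a small overhead for the stored scales.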

Loading a Tokenizer

To train or run inference on the models, one has to tokenize the inputs with a compatible tokenizer. Curated Transformers supports tokenizers implemented by the Hugging Face tokenizers library and certain model-specific tokenizers that are implemented using the Curated Tokenizers library. The Tokenizer class encapsulates the former and the LegacyTokenizer class the latter.

In both cases, one can use the AutoTokenizer class to automatically infer the correct tokenizer type and construct a Curated Transformers tokenizer that implements the TokenizerBase interface.

from curated_transformers.tokenizers import AutoTokenizer

tokenizer = AutoTokenizer.from_hf_hub(
   name="bert-base-uncased",
   revision="main",
)
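For bert-base-uncased, the tokenizer splits words into WordPiece pieces. The sketch below shows the greedy longest-match-first splitting that WordPiece applies to a single word. It uses a tiny made-up vocabulary and skips the normalization and pre-tokenization steps that BERT's tokenizer applies first, such as lowercasing and punctuation splitting. It is not how Curated Transformers or the Hugging Face tokenizers library are implemented. `wordpiece` and `toy_vocab` are hypothetical.

```python
from typing import List, Set


def wordpiece(word: str, vocab: Set[str], unk: str = "[UNK]") -> List[str]:
    """Split one word greedily into the longest pieces found in vocab."""
    pieces: List[str] = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, shrinking until a vocab hit.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece matched: the whole word becomes unknown
        pieces.append(piece)
        start = end
    return pieces


toy_vocab = {"space", "shut", "##tle", "snail", "shell"}
print([p for w in "space shuttle snail shell".split() for p in wordpiece(w, toy_vocab)])
# ['space', 'shut', '##tle', 'snail', 'shell']
```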

Text Encoding

Note

Currently, Curated Transformers only supports inference with models.

In addition to text generation, one can also run inference on the inputs to produce their dense representations.

import torch
from curated_transformers.models import AutoEncoder
from curated_transformers.tokenizers import AutoTokenizer

device = torch.device("cpu")
encoder = AutoEncoder.from_hf_hub(
   name="bert-base-uncased", revision="main", device=device
)
# Set module state to evaluation mode.
encoder.eval()

tokenizer = AutoTokenizer.from_hf_hub(
   name="bert-base-uncased",
   revision="main",
)

input_pieces = tokenizer(
   [
      "Straight jacket fitting a little too tight",
      "Space shuttle, snail shell, merry go round, conveyor belt!",
   ]
)

# Don't allocate gradients since we're only running inference.
with torch.no_grad():
   ids = input_pieces.padded_tensor(pad_left=True, device=device)
   # The attention mask must assume the same padding side as the piece ids.
   attention_mask = input_pieces.attention_mask(pad_left=True, device=device)
   model_output = encoder(ids, attention_mask)

# [batch, seq_len, width]
last_hidden_repr = model_output.last_hidden_layer_state
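When a batch is padded, the attention mask marks which positions hold real pieces, so it must place padding on the same side as the padded ids. The plain-Python sketch below roughly mirrors what padding plus mask construction compute. `pad_batch` is a hypothetical helper. The library returns tensors and an attention-mask object rather than lists, and the ids used here are only examples.

```python
from typing import List, Tuple


def pad_batch(
    batch: List[List[int]], padding_id: int = 0, pad_left: bool = False
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Pad piece ids to a rectangle and build the matching attention mask."""
    width = max(len(ids) for ids in batch)
    padded, mask = [], []
    for ids in batch:
        n_pad = width - len(ids)
        if pad_left:
            padded.append([padding_id] * n_pad + ids)
            mask.append([False] * n_pad + [True] * len(ids))
        else:
            padded.append(ids + [padding_id] * n_pad)
            mask.append([True] * len(ids) + [False] * n_pad)
    return padded, mask


ids, mask = pad_batch([[101, 7, 102], [101, 8, 9, 10, 102]], pad_left=True)
print(ids)   # [[0, 0, 101, 7, 102], [101, 8, 9, 10, 102]]
print(mask)  # [[False, False, True, True, True], [True, True, True, True, True]]
```

A right-padded mask paired with left-padded ids would mark the padding as real pieces and hide real pieces from attention.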

The ModelOutput instance returned by the encoder contains all of the transformer’s outputs, i.e., the hidden representations of all transformer layers and the output of the embedding layer. Decoder models (DecoderModule) and causal language models (CausalLMModule) produce additional outputs, such as the key-value cache used during attention calculation (ModelOutputWithCache) and logits (CausalLMOutputWithCache).
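A key-value cache is useful for decoders because, during auto-regressive generation, each new position only needs its own key and value and can reuse those of earlier positions. The sketch below shows this for a single attention head over plain Python lists. `attend`, `KVCache` and the toy vectors are hypothetical, simplified stand-ins and not the per-layer tensor structure stored in ModelOutputWithCache.

```python
import math
from typing import List

Vector = List[float]


def attend(query: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Scaled dot-product attention of one query over the given keys/values."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    max_score = max(scores)
    weights = [math.exp(s - max_score) for s in scores]
    total = sum(weights)
    width = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(width)]


class KVCache:
    """Keeps the keys and values of earlier positions so they are not recomputed."""

    def __init__(self) -> None:
        self.keys: List[Vector] = []
        self.values: List[Vector] = []

    def step(self, query: Vector, key: Vector, value: Vector) -> Vector:
        # Append only the newest position, then attend over everything seen so far.
        self.keys.append(key)
        self.values.append(value)
        return attend(query, self.keys, self.values)


# Toy per-position (query, key, value) vectors; a real model computes them with projections.
qs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
ks = [[0.5, 0.1], [0.2, 0.9], [0.3, 0.3]]
vs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

cache = KVCache()
incremental = [cache.step(q, k, v) for q, k, v in zip(qs, ks, vs)]
# Recomputing causal attention from scratch at every position gives the same result.
full = [attend(qs[i], ks[: i + 1], vs[: i + 1]) for i in range(len(qs))]
print(all(math.isclose(a, b) for x, y in zip(incremental, full) for a, b in zip(x, y)))  # True
```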