Utilities
Context Managers
- curated_transformers.layers.enable_torch_sdp(use_torch_sdp=True)
Enables Torch scaled dot product attention.
Torch provides an implementation of scaled dot product attention that has many optimizations. For instance, in some scenarios Flash Attention is applied (Dao et al., 2022). We do not use the Torch implementation by default, because it is still in beta.
This context manager enables use of the Torch implementation of scaled dot product attention.
with enable_torch_sdp():
    Y = bert_encoder(X)
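To illustrate the enable-then-restore behaviour that a context manager of this kind typically has, here is a minimal sketch in plain Python. It is not the library's implementation: the module-level switch `_USE_TORCH_SDP`, the `enable_flag` context manager, and the `attention_backend` helper are all hypothetical names chosen for the example.

```python
from contextlib import contextmanager

# Hypothetical module-level switch; the real library's internals may differ.
_USE_TORCH_SDP = False


@contextmanager
def enable_flag(use_torch_sdp: bool = True):
    """Temporarily set the switch and restore the previous value on exit."""
    global _USE_TORCH_SDP
    previous = _USE_TORCH_SDP
    _USE_TORCH_SDP = use_torch_sdp
    try:
        yield
    finally:
        # Restore the old value even if the body raises.
        _USE_TORCH_SDP = previous


def attention_backend() -> str:
    """Report which code path a hypothetical attention layer would take."""
    return "torch_sdp" if _USE_TORCH_SDP else "default"


before = attention_backend()
with enable_flag():
    inside = attention_backend()
after = attention_backend()
print(before, inside, after)
```

The point of the sketch is that the setting applies only inside the `with` block, so code outside the block keeps using the default attention path.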
- curated_transformers.util.use_nvtx_ranges_for_forward_pass(module)
Recursively applies NVTX ranges to the forward pass operation of the provided module. The ranges will be recorded during an Nsight profiling session.
- Parameters:
module (Module) – Top-level module to which the ranges are applied recursively.
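The recursive wrapping described above can be sketched without PyTorch or a GPU. The example below is an assumption-laden illustration, not the library's code: `ToyModule`, `wrap_forward`, and the `events` list are hypothetical, and the list stands in for the profiler. In PyTorch, the corresponding push and pop calls would be `torch.cuda.nvtx.range_push` and `torch.cuda.nvtx.range_pop`.

```python
from typing import List, Optional


class ToyModule:
    """Hypothetical stand-in for a module that owns child modules."""

    def __init__(self, name: str, children: Optional[List["ToyModule"]] = None):
        self.name = name
        self.children = children or []

    def forward(self, x: int) -> int:
        for child in self.children:
            x = child.forward(x)
        return x + 1


events: List[str] = []  # Stand-in for the profiler's range stack.


def wrap_forward(module: ToyModule) -> None:
    """Recursively wrap forward() so that each call is bracketed by a range."""
    original = module.forward

    def wrapped(x: int) -> int:
        events.append(f"push {module.name}")
        try:
            return original(x)
        finally:
            # Close the range even if the forward pass raises.
            events.append(f"pop {module.name}")

    module.forward = wrapped
    for child in module.children:
        wrap_forward(child)


model = ToyModule("encoder", [ToyModule("layer0"), ToyModule("layer1")])
wrap_forward(model)
result = model.forward(0)
print(events)
```

Because every submodule is wrapped, the recorded ranges nest in the same way as the module tree, which is what makes the resulting timeline readable in a profiler.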
Hugging Face
Loading Models from Hugging Face Hub
These mixin classes are used to implement support for loading models and tokenizers directly from Hugging Face Hub.
Attention
To download models hosted in private repositories, the user will first need to set up their authentication token using the Hugging Face Hub client.
- class curated_transformers.models.FromHFHub
Mixin class for downloading models from Hugging Face Hub.
A module using this mixin can implement the convert_hf_state_dict and from_hf_config methods. The mixin will then provide the from_hf_hub method to download a model from the Hugging Face Hub.
- abstract classmethod config_from_hf(hf_config)
Convert a Hugging Face model configuration to the module’s configuration.
- Parameters:
hf_config (Mapping[str, Any]) – The Hugging Face model configuration.
- Return type:
TypeVar(ConfigT, bound=TransformerConfig)
- Returns:
The converted Curated Transformer configuration.
- abstract classmethod config_to_hf(curated_config)
Convert the module’s configuration to a Hugging Face model configuration.
- Parameters:
curated_config (TypeVar(ConfigT, bound=TransformerConfig)) – The Curated Transformer model configuration.
- Return type:
Mapping[str, Any]
- Returns:
The converted Hugging Face configuration.
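As a rough illustration of what a pair of methods like config_from_hf and config_to_hf does, the sketch below converts a Hugging Face style mapping to a small dataclass and back. `ToyConfig`, its field names, and the `HF_TO_TOY` table are hypothetical and are not library classes. The three Hugging Face keys are ones commonly found in BERT-style configurations.

```python
from dataclasses import dataclass
from typing import Any, Dict, Mapping


@dataclass
class ToyConfig:
    """Hypothetical module configuration; not a library class."""

    width: int
    n_heads: int
    n_layers: int


# Hypothetical mapping from Hugging Face keys to ToyConfig fields.
HF_TO_TOY = {
    "hidden_size": "width",
    "num_attention_heads": "n_heads",
    "num_hidden_layers": "n_layers",
}


def config_from_hf(hf_config: Mapping[str, Any]) -> ToyConfig:
    """Build a ToyConfig from a Hugging Face style mapping."""
    missing = [key for key in HF_TO_TOY if key not in hf_config]
    if missing:
        raise ValueError(f"Missing keys in Hugging Face config: {missing}")
    return ToyConfig(**{field: hf_config[key] for key, field in HF_TO_TOY.items()})


def config_to_hf(config: ToyConfig) -> Dict[str, Any]:
    """Convert a ToyConfig back to a Hugging Face style mapping."""
    return {key: getattr(config, field) for key, field in HF_TO_TOY.items()}


hf_config = {"hidden_size": 768, "num_attention_heads": 12, "num_hidden_layers": 12}
toy = config_from_hf(hf_config)
round_trip = config_to_hf(toy)
print(toy, round_trip)
```

Keeping the key mapping in one table means the two directions cannot drift apart, which is one reasonable way to keep such a pair of conversions consistent.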
- classmethod convert_hf_state_dict(params)
Alias for
state_dict_from_hf().
- classmethod from_fsspec(*, fs, model_path, fsspec_args=None, device=None, quantization_config=None)
Construct a module and load its parameters from an fsspec filesystem.
- Parameters:
fs (AbstractFileSystem) – The filesystem to load the model from.
model_path (str) – The path of the model on the filesystem.
fsspec_args (Optional[FsspecArgs]) – Implementation-specific keyword arguments to pass to fsspec filesystem operations.
device (Optional[device]) – Device on which the model is initialized.
quantization_config (Optional[BitsAndBytesConfig]) – Configuration for loading quantized weights.
- Return type:
TypeVar(Self, bound=FromHFHub)
- Returns:
Module with the parameters loaded.
- abstract classmethod from_hf_config(*, hf_config, device=None)
Create the module from a JSON-deserialized Hugging Face model configuration.
- classmethod from_hf_hub(*, name, revision='main', device=None, quantization_config=None)
Construct a module and load its parameters from Hugging Face Hub.
- Parameters:
name (str) – Model name.
revision (str) – Model revision.
device (Optional[device]) – Device on which the model is initialized.
quantization_config (Optional[BitsAndBytesConfig]) – Configuration for loading quantized weights.
- Return type:
TypeVar(Self, bound=FromHFHub)
- Returns:
Module with the parameters loaded.
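The division of labour between the mixin and the module that uses it can be sketched as follows. This is a toy under stated assumptions, not the library's code: `FromHubSketch`, `ToyEncoder`, the `from_hub` method, and the in-memory `FAKE_HUB` dictionary are hypothetical, and no network access takes place.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Mapping, Tuple

# Hypothetical in-memory stand-in for a model repository: name -> (config, params).
FAKE_HUB: Dict[str, Tuple[Dict[str, Any], Dict[str, float]]] = {
    "toy-model": ({"hidden_size": 4}, {"encoder.weight": 0.5}),
}


class FromHubSketch(ABC):
    """Hypothetical mixin: subclasses supply the conversions, the mixin loads."""

    @classmethod
    @abstractmethod
    def from_hf_config(cls, *, hf_config: Mapping[str, Any]) -> "FromHubSketch":
        ...

    @classmethod
    @abstractmethod
    def state_dict_from_hf(cls, params: Mapping[str, float]) -> Dict[str, float]:
        ...

    @classmethod
    def from_hub(cls, *, name: str) -> "FromHubSketch":
        hf_config, hf_params = FAKE_HUB[name]  # 1. Fetch config and parameters.
        module = cls.from_hf_config(hf_config=hf_config)  # 2. Build the module.
        module.params = cls.state_dict_from_hf(hf_params)  # 3. Convert and load.
        return module


class ToyEncoder(FromHubSketch):
    def __init__(self, width: int):
        self.width = width
        self.params: Dict[str, float] = {}

    @classmethod
    def from_hf_config(cls, *, hf_config: Mapping[str, Any]) -> "ToyEncoder":
        return cls(width=hf_config["hidden_size"])

    @classmethod
    def state_dict_from_hf(cls, params: Mapping[str, float]) -> Dict[str, float]:
        # Toy conversion: rename a single prefix.
        return {key.replace("encoder.", "layers."): value for key, value in params.items()}


encoder = ToyEncoder.from_hub(name="toy-model")
print(encoder.width, encoder.params)
```

The concrete class only describes how to interpret a configuration and a state dict; the loading logic lives in one place and is shared by every class that uses the mixin.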
- classmethod from_hf_hub_to_cache(*, name, revision='main')
Download the model’s weights from Hugging Face Hub into the local Hugging Face cache directory. Subsequent loading of the model will read the weights from disk. If the weights are already cached, this is a no-op.
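The "no-op if already cached" behaviour can be illustrated with a small, self-contained sketch. The helper `fetch_to_cache`, the directory layout, and the file name `weights.bin` are hypothetical and do not reflect the actual layout of the Hugging Face cache. The download itself is simulated by writing a placeholder file.

```python
import tempfile
from pathlib import Path

download_count = 0  # Counts simulated downloads.


def fetch_to_cache(cache_dir: Path, name: str, revision: str = "main") -> Path:
    """Hypothetical helper: download once, then reuse the cached file."""
    global download_count
    target = cache_dir / name.replace("/", "--") / revision / "weights.bin"
    if target.exists():
        return target  # Already cached: nothing to do.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(b"fake weights")  # Stand-in for a network download.
    download_count += 1
    return target


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp)
    first = fetch_to_cache(cache, "org/toy-model")
    second = fetch_to_cache(cache, "org/toy-model")
    same_path = first == second

print(download_count, same_path)
```

Populating the cache ahead of time in this way is useful, for example, when the download should happen in a build step and later loads should not need network access.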
- classmethod from_repo(*, repo, device=None, quantization_config=None)
Construct and load a model from a repository.
- Parameters:
repo – The repository to load from.
device (Optional[device]) – Device on which to initialize the model.
quantization_config (Optional[BitsAndBytesConfig]) – Configuration for loading quantized weights.
- Return type:
TypeVar(Self, bound=FromHFHub)
- Returns:
Loaded model.
- abstract classmethod is_supported(config)
Check if the model with the given configuration is supported by this class.
- abstract classmethod state_dict_from_hf(params)
Convert a state dict of a Hugging Face model to a valid state dict for the module.
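Converting a state dict usually comes down to renaming parameter keys so that they match the module's own layout. The sketch below shows one way to express this with regular expressions. The rename rules, the target names such as `layers` and `mha`, and the `rename_keys` helper are hypothetical; real checkpoints need model-specific rules and may also need tensors to be split or merged.

```python
import re
from typing import Dict, List, Tuple

# Hypothetical rename rules (pattern, replacement), applied in order.
RENAMES: List[Tuple[str, str]] = [
    (r"^bert\.", ""),
    (r"^encoder\.layer\.(\d+)\.", r"layers.\1."),
    (r"attention\.self\.query", "mha.query"),
]


def rename_keys(params: Dict[str, object]) -> Dict[str, object]:
    """Return a new state dict whose keys follow the target naming scheme."""
    converted = {}
    for key, value in params.items():
        for pattern, replacement in RENAMES:
            key = re.sub(pattern, replacement, key)
        converted[key] = value
    return converted


# A string stands in for the tensor so that the example needs no dependencies.
hf_params = {"bert.encoder.layer.0.attention.self.query.weight": "W_q"}
converted = rename_keys(hf_params)
print(converted)
```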
- class curated_transformers.generation.FromHFHub
Mixin class for downloading generators from Hugging Face Hub.
It automatically infers the tokenizer and model type and loads the parameters and the configuration from the hub.
- abstract classmethod from_hf_hub(*, name, revision='main', device=None, quantization_config=None)
Construct a generator and load its parameters from Hugging Face Hub.
- Parameters:
name (str) – Model name.
revision (str) – Model revision.
device (Optional[device]) – Device on which to initialize the model.
quantization_config (Optional[BitsAndBytesConfig]) – Configuration for loading quantized weights.
- Return type:
TypeVar(Self, bound=FromHFHub)
- Returns:
Generator with the parameters loaded.
- abstract classmethod from_hf_hub_to_cache(*, name, revision='main')
Download the generator’s model and tokenizer from Hugging Face Hub into the local Hugging Face cache directory. Subsequent loading of the generator will load the model and the tokenizer from disk.
- class curated_transformers.tokenizers.FromHFHub
Mixin class for downloading tokenizers from Hugging Face Hub.
It directly queries the Hugging Face Hub to load the tokenizer from its configuration file.
- classmethod from_fsspec(*, fs, model_path, fsspec_args=None)
Construct a tokenizer and load its parameters from an fsspec filesystem.
- classmethod from_hf_hub(*, name, revision='main')
Construct a tokenizer and load its parameters from Hugging Face Hub.
- abstract classmethod from_hf_hub_to_cache(*, name, revision='main')
Download the tokenizer’s serialized model, configuration and vocab files from Hugging Face Hub into the local Hugging Face cache directory. Subsequent loading of the tokenizer will read the files from disk. If the files are already cached, this is a no-op.