Compute dtype (precision) and storage dtype

Various compute precisions are supported. Below is a quick recap of the current cases.

How to configure

It's important to note that compute precision does not necessarily reflect the dtype of the model parameters. With that in mind, compute precision is configured by setting the compute_dtype field. From this field, together with other optimization settings (or specific cases), the storage_dtype computed field is deduced. This is separate from the quantization logic configured via quant_layers and quant_type. If such quantization is enabled, the precision setting still applies to non-quantized components.

Note: the compute_dtype field accepts both str and torch.dtype inputs. A str input is validated to the corresponding torch.dtype via a custom mapping (see eole.config.common.RunningConfig.compute_dtype).
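The str-to-dtype validation described in the note can be pictured with the minimal sketch below. The names DTYPE_ALIASES and to_torch_dtype are hypothetical and used for illustration only; the aliases are the ones listed on this page, and the actual mapping in eole may contain more entries or behave differently.

```python
import torch

# Hypothetical alias table for illustration; not the actual eole mapping.
DTYPE_ALIASES = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "int8": torch.int8,
}


def to_torch_dtype(value):
    """Return a torch.dtype from either a str alias or a torch.dtype."""
    if isinstance(value, torch.dtype):
        # Already a dtype: pass it through unchanged.
        return value
    if isinstance(value, str):
        if value not in DTYPE_ALIASES:
            raise ValueError(f"Unsupported compute_dtype: {value!r}")
        return DTYPE_ALIASES[value]
    raise TypeError(f"Expected str or torch.dtype, got {type(value).__name__}")


print(to_torch_dtype("fp16"))
print(to_torch_dtype(torch.bfloat16))
```

Accepting both input types lets a YAML config use short aliases while Python callers pass torch.dtype objects directly.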

Available modes

Full precision

compute_dtype: {fp32, torch.float32} Standard float precision.

Note: flash attention is not compatible with float32 precision.

Half precision

compute_dtype: {fp16, torch.float16}

In most cases, the main model storage_dtype will be torch.float32, and some parameters will be automatically cast to torch.float16 by torch Automatic Mixed Precision (AMP).

Note: this means that checkpoints will be stored in torch.float32 in the AMP case.
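The gap between storage and compute dtype under AMP can be observed with a small standalone sketch. It uses plain torch.autocast on a toy torch.nn.Linear layer rather than eole's training loop, and it falls back to bfloat16 on CPU (an assumption made only so the snippet runs without a GPU, since float16 autocast primarily targets CUDA devices).

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 on GPU; bfloat16 is a CPU-only fallback so the sketch runs anywhere.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

# Toy model standing in for the real one; parameters are created in float32.
model = torch.nn.Linear(8, 4).to(device)
x = torch.randn(2, 8, device=device)

# Inside the autocast context, eligible ops run in the lower precision.
with torch.autocast(device_type=device, dtype=amp_dtype):
    y = model(x)

param_dtype = model.weight.dtype                   # storage dtype of the parameter
state_dtype = model.state_dict()["weight"].dtype   # what a checkpoint would contain
output_dtype = y.dtype                             # dtype produced by the autocast op

print(param_dtype, state_dtype, output_dtype)
```

The parameters and the state dict stay in torch.float32, which is why checkpoints saved in this mode are float32 even though the matrix multiplications run in lower precision.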

An exception is the fusedadam optimizer, which is more efficient with the legacy apex implementation. This relies on the legacy FP16_Optimizer, which requires switching the model to torch.float16 upstream.

BFloat16

compute_dtype: {bf16, torch.bfloat16}

See the bfloat16 floating-point format for details on how it differs from float16.
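As a quick illustration of the trade-off between the two 16-bit formats, the sketch below queries torch.finfo for each dtype. It relies only on standard PyTorch calls and is independent of eole.

```python
import torch

fp16 = torch.finfo(torch.float16)
bf16 = torch.finfo(torch.bfloat16)

# bfloat16 keeps the float32 exponent width, so its range is far larger,
# at the cost of fewer mantissa bits (coarser precision, larger eps).
print("float16  max:", fp16.max, "eps:", fp16.eps)
print("bfloat16 max:", bf16.max, "eps:", bf16.eps)

wider_range = bf16.max > fp16.max
coarser_precision = bf16.eps > fp16.eps
```

In practice, bfloat16 is much less prone to overflow than float16 but represents each value less precisely.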

For now, the logic is the same as the torch.float16 case with torch.amp. This is experimental and has not been extensively tested. Some specific implementations might be explored, e.g. this adapted AdamW implementation.

Int8

compute_dtype: {int8, torch.int8}

This specific setting is only valid for CPU prediction, to enable Dynamic Quantization.

In that case, storage_dtype will initially be torch.float32, and the model will then be quantized to torch.qint8 with torch.quantization.quantize_dynamic.
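The float32-then-quantize flow can be sketched on a toy model as follows. The model, the choice of torch.nn.Linear as the only quantized layer type, and the input shape are assumptions for illustration; the arguments eole actually passes to quantize_dynamic may differ. Running this requires a PyTorch CPU build with a quantized backend available.

```python
import torch

# Toy float32 model standing in for the real one (hypothetical architecture).
model = torch.nn.Sequential(
    torch.nn.Linear(16, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 4),
)
model.eval()

# Dynamic quantization: weights are converted to qint8 ahead of time,
# while activations are quantized on the fly at inference.
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

with torch.no_grad():
    out = quantized(torch.randn(2, 16))

original_dtype = model[0].weight.dtype        # the source model is left untouched
quantized_dtype = quantized[0].weight().dtype  # quantized modules expose weight() as a method
print(original_dtype, quantized_dtype, out.shape)
```

By default quantize_dynamic returns a converted copy, so the original float32 model remains available for comparison.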