|
--- |
|
library_name: transformers |
|
datasets: |
|
- Sunbird/salt |
|
--- |
|
|
|
# 8-Bit Quantized NLLB Model |
|
|
|
This is an 8-bit quantized version of the Sunbird NLLB (No Language Left Behind) model, built upom [translate-nllb-1.3b-salt-8bit](https://huggingface.co/Sunbird/translate-nllb-1.3b-salt-8bit) implementation. The quantization reduces the model size and accelerates inference while maintaining a balance between efficiency and translation quality. |
|
|
|
## Model Overview |
|
|
|
The model has been quantized to 8 bits to: |
|
- **Reduce Memory Footprint:** Facilitate deployment on devices with limited resources. |
|
- **Improve Inference Speed:** Enable faster translation with minimal compromise on performance. |
|
|
|
## How to Use |
|
|
|
Below is an example script demonstrating how to load the 8-bit quantized model, perform translation, and decode the output: |
|
|
|
make sure to install bits and bytes |
|
|
|
``` |
|
pip install -U bitsandbytes |
|
``` |
|
|
|
```python |
|
import torch |
|
import transformers |
|
|
|
# Load the 8-bit quantized model and tokenizer |
|
model_8bit = transformers.M2M100ForConditionalGeneration.from_pretrained( |
|
"Sunbird/translate-nllb-1.3b-salt-8bit", |
|
device_map="auto" |
|
) |
|
tokenizer = transformers.NllbTokenizer.from_pretrained("Sunbird/translate-nllb-1.3b-salt") |
|
|
|
# Define the text and language parameters |
|
text = 'Where is the hospital?' |
|
source_language = 'eng' |
|
target_language = 'lug' |
|
|
|
# Mapping for language tokens |
|
language_tokens = { |
|
'eng': 256047, |
|
'ach': 256111, |
|
'lgg': 256008, |
|
'lug': 256110, |
|
'nyn': 256002, |
|
'teo': 256006, |
|
} |
|
|
|
# Prepare device and tokenize the input text |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
inputs = tokenizer(text, return_tensors="pt").to(device) |
|
inputs['input_ids'][0][0] = language_tokens[source_language] |
|
|
|
# Generate the translation with beam search |
|
translated_tokens = model_8bit.to(device).generate( |
|
**inputs, |
|
forced_bos_token_id=language_tokens[target_language], |
|
max_length=100, |
|
num_beams=5, |
|
) |
|
|
|
# Decode and print the translated result |
|
result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0] |
|
print(result) |
|
# Expected output: "Eddwaliro liri ludda wa?" |
|
|