---
library_name: transformers
datasets:
- Sunbird/salt
---
# 8-Bit Quantized NLLB Model
This is an 8-bit quantized version of the Sunbird NLLB (No Language Left Behind) model, published as [translate-nllb-1.3b-salt-8bit](https://huggingface.co/Sunbird/translate-nllb-1.3b-salt-8bit). The quantization reduces the model size and accelerates inference while maintaining a balance between efficiency and translation quality.
## Model Overview
The model has been quantized to 8 bits to:
- **Reduce Memory Footprint:** Facilitate deployment on devices with limited resources.
- **Improve Inference Speed:** Enable faster translation with minimal compromise on performance.
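To see why 8-bit storage shrinks the memory footprint, here is a simplified sketch of absmax (absolute-maximum) quantization, the basic idea behind int8 weight quantization. This is an illustration in plain Python, not the actual vector-wise scheme used by `bitsandbytes`:

```python
def quantize_absmax(weights):
    # Map float weights into the signed 8-bit range [-127, 127],
    # using the absolute maximum of the tensor as the scale.
    scale = 127.0 / max(abs(w) for w in weights)
    q = [round(w * scale) for w in weights]
    return q, scale

def dequantize(q, scale):
    # Recover approximate float weights from the int8 values.
    return [v / scale for v in q]

weights = [0.5, -1.2, 0.03, 0.9]
q, scale = quantize_absmax(weights)
recovered = dequantize(q, scale)
print(q)  # [53, -127, 3, 95]
```

Each quantized weight fits in one byte instead of the four bytes of a float32, at the cost of a small rounding error that is recovered only approximately on dequantization.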
## How to Use
Below is an example script demonstrating how to load the 8-bit quantized model, perform translation, and decode the output:
Make sure to install `bitsandbytes` first:
```bash
pip install -U bitsandbytes
```
```python
import torch
import transformers

# Load the 8-bit quantized model and the tokenizer.
# device_map="auto" places the quantized weights automatically, so do
# not call .to(device) on the model afterwards: moving an 8-bit
# bitsandbytes model raises an error.
model_8bit = transformers.M2M100ForConditionalGeneration.from_pretrained(
    "Sunbird/translate-nllb-1.3b-salt-8bit",
    device_map="auto",
)
tokenizer = transformers.NllbTokenizer.from_pretrained(
    "Sunbird/translate-nllb-1.3b-salt"
)

# Define the text and language parameters
text = 'Where is the hospital?'
source_language = 'eng'
target_language = 'lug'

# Token IDs for the SALT languages supported by this model
language_tokens = {
    'eng': 256047,
    'ach': 256111,
    'lgg': 256008,
    'lug': 256110,
    'nyn': 256002,
    'teo': 256006,
}

# Tokenize the input, move it to the model's device, and overwrite the
# first input token with the source-language token
inputs = tokenizer(text, return_tensors="pt").to(model_8bit.device)
inputs['input_ids'][0][0] = language_tokens[source_language]

# Generate the translation with beam search, forcing the decoder to
# start with the target-language token
translated_tokens = model_8bit.generate(
    **inputs,
    forced_bos_token_id=language_tokens[target_language],
    max_length=100,
    num_beams=5,
)

# Decode and print the translated result
result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
print(result)
# Expected output: "Eddwaliro liri ludda wa?"
```