from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import bitsandbytes as bnb  # not referenced directly, but must be installed for 4-bit loading
# Define the model name and path
model_name = "nvidia/Llama-3.1-Nemotron-Nano-8B-v1"
# Configure quantization parameters
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # Load the model weights in 4-bit precision
    bnb_4bit_compute_dtype=torch.bfloat16,  # Use bfloat16 for computation
    bnb_4bit_quant_type="nf4",              # Use the "nf4" quantization type
    bnb_4bit_use_double_quant=True,         # Enable double quantization
    llm_int8_skip_modules=[                 # Modules to skip (left unquantized)
        # Note: some of these names (e.g. multi_modal_projector) come from
        # multimodal checkpoints; entries that do not match a module in this
        # text-only model simply have no effect.
        "lm_head",
        "multi_modal_projector",
        "merger",
        "modality_projection",
        "model.layers.1.mlp",
    ],
)
# Load the pre-trained model with the specified quantization configuration
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",  # Automatically allocate devices
)
# Load the tokenizer associated with the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Save the quantized model and tokenizer to a specified directory
model.save_pretrained("Llama-3.1-Nemotron-Nano-8B-v1-bnb-4bit")
tokenizer.save_pretrained("Llama-3.1-Nemotron-Nano-8B-v1-bnb-4bit")
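# Optional sanity check (added sketch, not part of the original script):
# reload the saved 4-bit checkpoint and run a short generation. This assumes a
# transformers/bitsandbytes combination that supports 4-bit serialization; the
# prompt is purely illustrative.
save_dir = "Llama-3.1-Nemotron-Nano-8B-v1-bnb-4bit"
reloaded_model = AutoModelForCausalLM.from_pretrained(save_dir, device_map="auto")
reloaded_tokenizer = AutoTokenizer.from_pretrained(save_dir)
inputs = reloaded_tokenizer("Hello, how are you?", return_tensors="pt").to(reloaded_model.device)
outputs = reloaded_model.generate(**inputs, max_new_tokens=32)
print(reloaded_tokenizer.decode(outputs[0], skip_special_tokens=True))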