aifeifei798 committed on
Commit 035246c · verified · 1 Parent(s): 701871d

Update bit4-chat.py

Files changed (1)
  1. bit4-chat.py +38 -38
bit4-chat.py CHANGED
@@ -1,38 +1,38 @@
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
- import torch
-
- # Configure quantization parameters
- quantization_config = BitsAndBytesConfig(
-     load_in_4bit=True,  # Load the model weights in 4-bit precision
-     bnb_4bit_compute_dtype=torch.bfloat16,  # Use bfloat16 for computation
-     bnb_4bit_quant_type="nf4",  # Use "nf4" quantization type
-     bnb_4bit_use_double_quant=True,  # Enable double quantization
- )
-
- # Define the model name and path for the quantized model
- model_name = "./Llama-3.1-Nemotron-Nano-8B-v1-bnb-4bit"
-
- # Load the quantized model with the specified configuration
- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     quantization_config=quantization_config,
-     device_map="auto"  # Automatically allocate devices
- )
-
- # Load the tokenizer associated with the model
- tokenizer = AutoTokenizer.from_pretrained(model_name)
-
- # Determine the device where the model is located
- device = model.device
-
- # Prepare input text and move it to the same device as the model
- input_text = "Once upon a time"
- inputs = tokenizer(input_text, return_tensors="pt").to(device)
-
- # Perform inference
- with torch.no_grad():
-     outputs = model.generate(**inputs, max_length=50)
-
- # Decode the generated text
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
- print(generated_text)
 
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+ import torch
+
+ # Configure quantization parameters
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,  # Load the model weights in 4-bit precision
+     bnb_4bit_compute_dtype=torch.bfloat16,  # Use bfloat16 for computation
+     bnb_4bit_quant_type="nf4",  # Use "nf4" quantization type
+     bnb_4bit_use_double_quant=True,  # Enable double quantization
+ )
+
+ # Define the model name and path for the quantized model
+ model_name = "nvidia/Llama-3.1-Nemotron-Nano-8B-v1-bnb-4bit"
+
+ # Load the quantized model with the specified configuration
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     quantization_config=quantization_config,
+     device_map="auto"  # Automatically allocate devices
+ )
+
+ # Load the tokenizer associated with the model
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ # Determine the device where the model is located
+ device = model.device
+
+ # Prepare input text and move it to the same device as the model
+ input_text = "Once upon a time"
+ inputs = tokenizer(input_text, return_tensors="pt").to(device)
+
+ # Perform inference
+ with torch.no_grad():
+     outputs = model.generate(**inputs, max_length=50)
+
+ # Decode the generated text
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+ print(generated_text)
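
The script in this commit runs plain text completion on the prompt "Once upon a time". Since the file is named bit4-chat.py, a chat-style call is the natural next step; the sketch below is a hypothetical variant that is not part of this commit. It assumes the tokenizer ships a chat template (as Llama-3.1-based models typically do) and reuses the model, tokenizer, and device objects defined above.

# Hypothetical chat-style variant (sketch only, not part of this commit).
# Assumes a chat template is available on the tokenizer and reuses the
# model, tokenizer, and device objects from the script above.
messages = [{"role": "user", "content": "Tell me a short story."}]
chat_inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(device)

with torch.no_grad():
    chat_outputs = model.generate(chat_inputs, max_new_tokens=100)

# Decode only the newly generated tokens, skipping the prompt
print(tokenizer.decode(chat_outputs[0][chat_inputs.shape[-1]:], skip_special_tokens=True))

Using max_new_tokens instead of max_length here is an assumption on my part; it bounds only the generated continuation rather than prompt plus continuation, which is usually what a chat loop wants.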