Update README.md
Browse files
README.md
CHANGED
@@ -79,6 +79,74 @@ generated_text = generate_text(prompt)
|
|
79 |
print("\nGenerated Text:")
|
80 |
print(generated_text)
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
# jsonlファイルの出力方法は以下の通りです。
|
84 |
import json
|
|
|
79 |
print("\nGenerated Text:")
|
80 |
print(generated_text)
|
81 |
|
82 |
+
# 量子化
|
83 |
+
|
84 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
85 |
+
from peft import PeftModel
|
86 |
+
import torch
|
87 |
+
from transformers import BitsAndBytesConfig
|
88 |
+
|
89 |
+
ベースモデル ID とアダプタファイルパス
|
90 |
+
base_model_id = "llm-jp/llm-jp-3-13b"
|
91 |
+
adapter_model_path = "path/to/"
|
92 |
+
|
93 |
+
デバイス設定
|
94 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
95 |
+
|
96 |
+
量子化の設定
|
97 |
+
bnb_config = BitsAndBytesConfig(
|
98 |
+
load_in_4bit=True, # 4-bit 量子化を有効化
|
99 |
+
bnb_4bit_use_double_quant=True,
|
100 |
+
bnb_4bit_quant_type="nf4", # 量子化スキーム
|
101 |
+
bnb_4bit_compute_dtype=torch.float16, # 推論時の計算精度
|
102 |
+
)
|
103 |
+
|
104 |
+
トークナイザーのロード
|
105 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
|
106 |
+
|
107 |
+
ベースモデルのロード(量子化設定を使用)
|
108 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
109 |
+
base_model_id,
|
110 |
+
quantization_config=bnb_config,
|
111 |
+
device_map="auto", # 自動的に GPU に割り当て
|
112 |
+
)
|
113 |
+
|
114 |
+
アダプタの読み込み
|
115 |
+
model = PeftModel.from_pretrained(base_model, adapter_model_path).to(device)
|
116 |
+
|
117 |
+
`pad_token_id` の設定(トークナイザーから取得)
|
118 |
+
model.config.pad_token_id = tokenizer.pad_token_id
|
119 |
+
|
120 |
+
推論関数
|
121 |
+
def generate_text(prompt, max_length=256, temperature=0.7):
|
122 |
+
# トークナイズして `attention_mask` を設定し、max_length を適用
|
123 |
+
inputs = tokenizer(
|
124 |
+
prompt,
|
125 |
+
return_tensors="pt",
|
126 |
+
padding=True,
|
127 |
+
truncation=True,
|
128 |
+
max_length=max_length # 最大トークン数を制限
|
129 |
+
).to(device)
|
130 |
+
|
131 |
+
outputs = model.generate(
|
132 |
+
inputs["input_ids"],
|
133 |
+
attention_mask=inputs["attention_mask"],
|
134 |
+
max_length=max_length,
|
135 |
+
temperature=temperature,
|
136 |
+
do_sample=True,
|
137 |
+
top_k=50,
|
138 |
+
top_p=0.9,
|
139 |
+
pad_token_id=tokenizer.pad_token_id # 安全な動作のため明示的に指定
|
140 |
+
)
|
141 |
+
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
142 |
+
|
143 |
+
テストプロンプト
|
144 |
+
prompt = "日本の経済について説明してください。"
|
145 |
+
print("Generating text...")
|
146 |
+
generated_text = generate_text(prompt, max_length=256) # 最大長さを明示的に指定
|
147 |
+
print("\nGenerated Text:")
|
148 |
+
print(generated_text)
|
149 |
+
|
150 |
|
151 |
# jsonlファイルの出力方法は以下の通りです。
|
152 |
import json
|