Upload folder using huggingface_hub
Browse files- grpo_max_completion.py +248 -0
- grpo_offline_run.py +11 -3
grpo_max_completion.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Supervised fine-tuning script for decoder language models.
|
18 |
+
CUDA_VISIBLE_DEVICES=1,2,3,4,5 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml grpo_offline_run.py config_grpo_offline.yaml
|
19 |
+
"""
|
20 |
+
|
21 |
+
import logging
|
22 |
+
import random
|
23 |
+
import sys
|
24 |
+
|
25 |
+
import datasets
|
26 |
+
import torch
|
27 |
+
import transformers
|
28 |
+
from transformers import AutoModelForCausalLM, set_seed
|
29 |
+
from trl.data_utils import maybe_apply_chat_template
|
30 |
+
from datasets import load_dataset
|
31 |
+
from alignment import (
|
32 |
+
DataArguments,
|
33 |
+
H4ArgumentParser,
|
34 |
+
ModelArguments,
|
35 |
+
SFTConfig,
|
36 |
+
apply_chat_template,
|
37 |
+
decontaminate_humaneval,
|
38 |
+
get_checkpoint,
|
39 |
+
get_datasets,
|
40 |
+
get_kbit_device_map,
|
41 |
+
get_peft_config,
|
42 |
+
get_quantization_config,
|
43 |
+
get_tokenizer,
|
44 |
+
)
|
45 |
+
from trl import SFTTrainer, setup_chat_format
|
46 |
+
from trl_012_grpo.grpo_trainer import GRPOTrainer
|
47 |
+
from trl_012_grpo.grpo_config import GRPOConfig
|
48 |
+
|
49 |
+
logger = logging.getLogger(__name__)
|
50 |
+
|
51 |
+
|
52 |
+
def main():
|
53 |
+
parser = H4ArgumentParser((ModelArguments, DataArguments, GRPOConfig))
|
54 |
+
model_args, data_args, training_args = parser.parse()
|
55 |
+
|
56 |
+
# Set seed for reproducibility
|
57 |
+
set_seed(training_args.seed)
|
58 |
+
|
59 |
+
###############
|
60 |
+
# Setup logging
|
61 |
+
###############
|
62 |
+
logging.basicConfig(
|
63 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
64 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
65 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
66 |
+
)
|
67 |
+
log_level = training_args.get_process_log_level()
|
68 |
+
logger.setLevel(log_level)
|
69 |
+
datasets.utils.logging.set_verbosity(log_level)
|
70 |
+
transformers.utils.logging.set_verbosity(log_level)
|
71 |
+
transformers.utils.logging.enable_default_handler()
|
72 |
+
transformers.utils.logging.enable_explicit_format()
|
73 |
+
|
74 |
+
# Log on each process a small summary
|
75 |
+
logger.warning(
|
76 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
77 |
+
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
78 |
+
)
|
79 |
+
logger.info(f"Model parameters {model_args}")
|
80 |
+
logger.info(f"Data parameters {data_args}")
|
81 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
82 |
+
|
83 |
+
# Check for last checkpoint
|
84 |
+
last_checkpoint = get_checkpoint(training_args)
|
85 |
+
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
86 |
+
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
|
87 |
+
|
88 |
+
###############
|
89 |
+
# Load datasets
|
90 |
+
###############
|
91 |
+
raw_datasets = load_dataset("json", data_files="/data01/swzhang/dataset/grpo_data_ori/grpo_del_lowscore/shuffle/grpo_test_shuffle.json")
|
92 |
+
eval_raw_datasets = load_dataset("json", data_files="/data01/swzhang/dataset/grpo_data_ori/grpo_del_lowscore/shuffle/grpo_test_shuffle.json")
|
93 |
+
logger.info(
|
94 |
+
f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
|
95 |
+
)
|
96 |
+
column_names = list(raw_datasets["train"].features)
|
97 |
+
|
98 |
+
################
|
99 |
+
# Load tokenizer
|
100 |
+
################
|
101 |
+
tokenizer = get_tokenizer(model_args, data_args)
|
102 |
+
|
103 |
+
#######################
|
104 |
+
# Load pretrained model
|
105 |
+
#######################
|
106 |
+
logger.info("*** Load pretrained model ***")
|
107 |
+
torch_dtype = (
|
108 |
+
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
109 |
+
)
|
110 |
+
quantization_config = get_quantization_config(model_args)
|
111 |
+
|
112 |
+
model_kwargs = dict(
|
113 |
+
revision=model_args.model_revision,
|
114 |
+
trust_remote_code=model_args.trust_remote_code,
|
115 |
+
attn_implementation=model_args.attn_implementation,
|
116 |
+
torch_dtype=torch_dtype,
|
117 |
+
use_cache=False if training_args.gradient_checkpointing else True,
|
118 |
+
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
119 |
+
quantization_config=quantization_config,
|
120 |
+
)
|
121 |
+
|
122 |
+
model = model_args.model_name_or_path
|
123 |
+
# For ChatML we need to add special tokens and resize the embedding layer
|
124 |
+
if "<|im_start|>" in tokenizer.chat_template and "gemma-tokenizer-chatml" not in tokenizer.name_or_path:
|
125 |
+
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
|
126 |
+
model, tokenizer = setup_chat_format(model, tokenizer)
|
127 |
+
model_kwargs = None
|
128 |
+
|
129 |
+
#####################
|
130 |
+
# Apply chat template
|
131 |
+
#####################
|
132 |
+
def truncate_string(text, max_length, tokenizer):
|
133 |
+
"""
|
134 |
+
将字符串转换为 ID 列表,截断超过 max_length 的部分,再将剩余的 ID 转回字符串。
|
135 |
+
|
136 |
+
Args:
|
137 |
+
text (str): 输入的字符串
|
138 |
+
max_length (int): 最大允许的长度
|
139 |
+
tokenizer: 用于转换的 tokenizer
|
140 |
+
|
141 |
+
Returns:
|
142 |
+
str: 截断后的字符串
|
143 |
+
"""
|
144 |
+
# 将字符串转换为 ID 列表
|
145 |
+
input_ids = tokenizer.encode(text, add_special_tokens=False)
|
146 |
+
|
147 |
+
# 截断 ID 列表
|
148 |
+
truncated_ids = input_ids[:max_length]
|
149 |
+
|
150 |
+
# 将截断后的 ID 列表转回字符串
|
151 |
+
truncated_text = tokenizer.decode(truncated_ids, skip_special_tokens=True)
|
152 |
+
print('截断前:',text)
|
153 |
+
print('截断后: ',truncated_text)
|
154 |
+
|
155 |
+
return truncated_text
|
156 |
+
def modify_completion(example):
|
157 |
+
# 将 completion 转换为列表
|
158 |
+
example['prompt'] = \
|
159 |
+
maybe_apply_chat_template({"prompt": [{"role": "user", "content": example['prompt']}]}, tokenizer=tokenizer)[
|
160 |
+
'prompt']
|
161 |
+
new_completions = []
|
162 |
+
for length,completion in zip(example['length'],example['completion']):
|
163 |
+
if length>training_args.max_completion_length:
|
164 |
+
completion = truncate_string(completion,training_args.max_completion_length,tokenizer)
|
165 |
+
new_completions.append(completion)
|
166 |
+
example['completion'] = new_completions
|
167 |
+
|
168 |
+
return example
|
169 |
+
|
170 |
+
raw_datasets = raw_datasets.map(modify_completion)
|
171 |
+
eval_raw_datasets = eval_raw_datasets.map(modify_completion)
|
172 |
+
|
173 |
+
|
174 |
+
train_dataset = raw_datasets["train"]
|
175 |
+
eval_dataset = eval_raw_datasets["train"]
|
176 |
+
|
177 |
+
########################
|
178 |
+
# Initialize the Trainer
|
179 |
+
########################
|
180 |
+
|
181 |
+
# 这里的reward function实际不会被用到
|
182 |
+
def reward_len(completions, **kwargs):
|
183 |
+
return [-abs(20 - len(completion)) for completion in completions]
|
184 |
+
|
185 |
+
training_args.model_init_kwargs = model_kwargs
|
186 |
+
trainer = GRPOTrainer(
|
187 |
+
model=model,
|
188 |
+
reward_funcs=reward_len,
|
189 |
+
args=training_args,
|
190 |
+
train_dataset=train_dataset,
|
191 |
+
eval_dataset=eval_dataset,
|
192 |
+
)
|
193 |
+
|
194 |
+
###############
|
195 |
+
# Training loop
|
196 |
+
###############
|
197 |
+
logger.info("*** Train ***")
|
198 |
+
checkpoint = None
|
199 |
+
if training_args.resume_from_checkpoint is not None:
|
200 |
+
checkpoint = training_args.resume_from_checkpoint
|
201 |
+
elif last_checkpoint is not None:
|
202 |
+
checkpoint = last_checkpoint
|
203 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
204 |
+
metrics = train_result.metrics
|
205 |
+
metrics["train_samples"] = len(train_dataset)
|
206 |
+
trainer.log_metrics("train", metrics)
|
207 |
+
trainer.save_metrics("train", metrics)
|
208 |
+
trainer.save_state()
|
209 |
+
|
210 |
+
##################################
|
211 |
+
# Save model and create model card
|
212 |
+
##################################
|
213 |
+
logger.info("*** Save model ***")
|
214 |
+
trainer.save_model(training_args.output_dir)
|
215 |
+
logger.info(f"Model saved to {training_args.output_dir}")
|
216 |
+
|
217 |
+
# Save everything else on main process
|
218 |
+
kwargs = {
|
219 |
+
"finetuned_from": model_args.model_name_or_path,
|
220 |
+
"dataset": list(data_args.dataset_mixer.keys()),
|
221 |
+
"dataset_tags": list(data_args.dataset_mixer.keys()),
|
222 |
+
"tags": ["alignment-handbook"],
|
223 |
+
}
|
224 |
+
if trainer.accelerator.is_main_process:
|
225 |
+
trainer.create_model_card(**kwargs)
|
226 |
+
# Restore k,v cache for fast inference
|
227 |
+
trainer.model.config.use_cache = True
|
228 |
+
trainer.model.config.save_pretrained(training_args.output_dir)
|
229 |
+
|
230 |
+
##########
|
231 |
+
# Evaluate
|
232 |
+
##########
|
233 |
+
if training_args.do_eval:
|
234 |
+
logger.info("*** Evaluate ***")
|
235 |
+
metrics = trainer.evaluate()
|
236 |
+
metrics["eval_samples"] = len(eval_dataset)
|
237 |
+
trainer.log_metrics("eval", metrics)
|
238 |
+
trainer.save_metrics("eval", metrics)
|
239 |
+
|
240 |
+
if training_args.push_to_hub is True:
|
241 |
+
logger.info("Pushing to hub...")
|
242 |
+
trainer.push_to_hub(**kwargs)
|
243 |
+
|
244 |
+
logger.info("*** Training complete ***")
|
245 |
+
|
246 |
+
|
247 |
+
if __name__ == "__main__":
|
248 |
+
main()
|
grpo_offline_run.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15 |
# limitations under the License.
|
16 |
"""
|
17 |
Supervised fine-tuning script for decoder language models.
|
18 |
-
CUDA_VISIBLE_DEVICES=
|
19 |
"""
|
20 |
|
21 |
import logging
|
@@ -27,7 +27,7 @@ import torch
|
|
27 |
import transformers
|
28 |
from transformers import AutoModelForCausalLM, set_seed
|
29 |
from trl.data_utils import maybe_apply_chat_template
|
30 |
-
from datasets import load_dataset
|
31 |
from alignment import (
|
32 |
DataArguments,
|
33 |
H4ArgumentParser,
|
@@ -88,7 +88,15 @@ def main():
|
|
88 |
###############
|
89 |
# Load datasets
|
90 |
###############
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
eval_raw_datasets = load_dataset("json", data_files="/data01/swzhang/dataset/grpo_data_ori/grpo_del_lowscore/shuffle/grpo_test_shuffle.json")
|
93 |
logger.info(
|
94 |
f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
|
|
|
15 |
# limitations under the License.
|
16 |
"""
|
17 |
Supervised fine-tuning script for decoder language models.
|
18 |
+
CUDA_VISIBLE_DEVICES=2,3,4,5 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml grpo_offline_run.py config_grpo_offline.yaml
|
19 |
"""
|
20 |
|
21 |
import logging
|
|
|
27 |
import transformers
|
28 |
from transformers import AutoModelForCausalLM, set_seed
|
29 |
from trl.data_utils import maybe_apply_chat_template
|
30 |
+
from datasets import load_dataset, Features, Value, Sequence
|
31 |
from alignment import (
|
32 |
DataArguments,
|
33 |
H4ArgumentParser,
|
|
|
88 |
###############
|
89 |
# Load datasets
|
90 |
###############
|
91 |
+
features = Features({
|
92 |
+
"prompt": Value("large_string"), # prompt 字段可能较长,使用 large_string
|
93 |
+
"completion": Sequence(feature=Value("large_string")), # completion 是字符串列表,使用 list<large_string>
|
94 |
+
"reward": Sequence(feature=Value("float32")), # reward 是整数列表
|
95 |
+
"length": Sequence(feature=Value("int32")), # length 是整数列表
|
96 |
+
"instruction_len": Value("int32"), # instruction_len 是整数
|
97 |
+
"del_score": Value("float32") # del_score 是浮点数
|
98 |
+
})
|
99 |
+
raw_datasets = load_dataset("json", data_files="/data01/swzhang/dataset/grpo_data_ori/grpo_del_lowscore/shuffle/grpo_train_shuffle.json",features=features)
|
100 |
eval_raw_datasets = load_dataset("json", data_files="/data01/swzhang/dataset/grpo_data_ori/grpo_del_lowscore/shuffle/grpo_test_shuffle.json")
|
101 |
logger.info(
|
102 |
f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
|