zswzswzsw commited on
Commit
2a4552a
·
verified ·
1 Parent(s): ae40651

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. grpo_max_completion.py +248 -0
  2. grpo_offline_run.py +11 -3
grpo_max_completion.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Offline GRPO training script for decoder language models; completions longer than `max_completion_length` tokens are truncated before training.
18
+ CUDA_VISIBLE_DEVICES=1,2,3,4,5 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml grpo_max_completion.py config_grpo_offline.yaml
19
+ """
20
+
21
+ import logging
22
+ import random
23
+ import sys
24
+
25
+ import datasets
26
+ import torch
27
+ import transformers
28
+ from transformers import AutoModelForCausalLM, set_seed
29
+ from trl.data_utils import maybe_apply_chat_template
30
+ from datasets import load_dataset
31
+ from alignment import (
32
+ DataArguments,
33
+ H4ArgumentParser,
34
+ ModelArguments,
35
+ SFTConfig,
36
+ apply_chat_template,
37
+ decontaminate_humaneval,
38
+ get_checkpoint,
39
+ get_datasets,
40
+ get_kbit_device_map,
41
+ get_peft_config,
42
+ get_quantization_config,
43
+ get_tokenizer,
44
+ )
45
+ from trl import SFTTrainer, setup_chat_format
46
+ from trl_012_grpo.grpo_trainer import GRPOTrainer
47
+ from trl_012_grpo.grpo_config import GRPOConfig
48
+
49
+ logger = logging.getLogger(__name__)
50
+
51
+
52
+ def main():
53
+ parser = H4ArgumentParser((ModelArguments, DataArguments, GRPOConfig))
54
+ model_args, data_args, training_args = parser.parse()
55
+
56
+ # Set seed for reproducibility
57
+ set_seed(training_args.seed)
58
+
59
+ ###############
60
+ # Setup logging
61
+ ###############
62
+ logging.basicConfig(
63
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
64
+ datefmt="%Y-%m-%d %H:%M:%S",
65
+ handlers=[logging.StreamHandler(sys.stdout)],
66
+ )
67
+ log_level = training_args.get_process_log_level()
68
+ logger.setLevel(log_level)
69
+ datasets.utils.logging.set_verbosity(log_level)
70
+ transformers.utils.logging.set_verbosity(log_level)
71
+ transformers.utils.logging.enable_default_handler()
72
+ transformers.utils.logging.enable_explicit_format()
73
+
74
+ # Log on each process a small summary
75
+ logger.warning(
76
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
77
+ + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
78
+ )
79
+ logger.info(f"Model parameters {model_args}")
80
+ logger.info(f"Data parameters {data_args}")
81
+ logger.info(f"Training/evaluation parameters {training_args}")
82
+
83
+ # Check for last checkpoint
84
+ last_checkpoint = get_checkpoint(training_args)
85
+ if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
86
+ logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
87
+
88
+ ###############
89
+ # Load datasets
90
+ ###############
91
+ raw_datasets = load_dataset("json", data_files="/data01/swzhang/dataset/grpo_data_ori/grpo_del_lowscore/shuffle/grpo_test_shuffle.json")
92
+ eval_raw_datasets = load_dataset("json", data_files="/data01/swzhang/dataset/grpo_data_ori/grpo_del_lowscore/shuffle/grpo_test_shuffle.json")
93
+ logger.info(
94
+ f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
95
+ )
96
+ column_names = list(raw_datasets["train"].features)
97
+
98
+ ################
99
+ # Load tokenizer
100
+ ################
101
+ tokenizer = get_tokenizer(model_args, data_args)
102
+
103
+ #######################
104
+ # Load pretrained model
105
+ #######################
106
+ logger.info("*** Load pretrained model ***")
107
+ torch_dtype = (
108
+ model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
109
+ )
110
+ quantization_config = get_quantization_config(model_args)
111
+
112
+ model_kwargs = dict(
113
+ revision=model_args.model_revision,
114
+ trust_remote_code=model_args.trust_remote_code,
115
+ attn_implementation=model_args.attn_implementation,
116
+ torch_dtype=torch_dtype,
117
+ use_cache=False if training_args.gradient_checkpointing else True,
118
+ device_map=get_kbit_device_map() if quantization_config is not None else None,
119
+ quantization_config=quantization_config,
120
+ )
121
+
122
+ model = model_args.model_name_or_path
123
+ # For ChatML we need to add special tokens and resize the embedding layer
124
+ if "<|im_start|>" in tokenizer.chat_template and "gemma-tokenizer-chatml" not in tokenizer.name_or_path:
125
+ model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
126
+ model, tokenizer = setup_chat_format(model, tokenizer)
127
+ model_kwargs = None
128
+
129
+ #####################
130
+ # Apply chat template
131
+ #####################
132
def truncate_string(text, max_length, tokenizer):
    """Truncate ``text`` so that it spans at most ``max_length`` tokens.

    The text is encoded without special tokens, the token-id list is cut to
    ``max_length`` entries, and the remaining ids are decoded back to a string.

    Args:
        text (str): Input string.
        max_length (int | None): Maximum number of tokens to keep. ``None``
            means "no limit"; zero or a negative value yields an empty string.
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)`` and
            ``decode(ids, skip_special_tokens=...)``.

    Returns:
        str: ``text`` itself when it already fits, otherwise the decoded
        prefix of at most ``max_length`` tokens.
    """
    if max_length is None:
        return text
    if max_length <= 0:
        # A negative slice bound would drop tokens from the END of the text
        # instead of truncating it, so treat non-positive limits as "keep nothing".
        return ""

    input_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(input_ids) <= max_length:
        # Nothing to cut: return the original text untouched rather than a
        # decode(encode(text)) round-trip, which is not guaranteed to be lossless.
        return text

    truncated_text = tokenizer.decode(input_ids[:max_length], skip_special_tokens=True)

    # Debug-level logging instead of print(): this runs once per over-long
    # completion and would otherwise flood stdout with full texts.
    log = logging.getLogger(__name__)
    log.debug("截断前: %s", text)
    log.debug("截断后: %s", truncated_text)

    return truncated_text
156
def modify_completion(example):
    """Template the prompt and cap over-long completions of one dataset row.

    The raw ``example['prompt']`` is wrapped as a single user message and run
    through ``maybe_apply_chat_template``; the templated prompt replaces the
    original. Each completion whose paired ``example['length']`` entry exceeds
    ``training_args.max_completion_length`` is passed through
    ``truncate_string``; the others are kept unchanged.

    Args:
        example (dict): Row with ``prompt``, ``completion`` and ``length``
            keys. ``length`` is presumably the per-completion token count,
            computed ahead of time -- TODO confirm against the data pipeline.

    Returns:
        dict: The same ``example`` object, updated in place.
    """
    chat_input = {"prompt": [{"role": "user", "content": example['prompt']}]}
    templated = maybe_apply_chat_template(chat_input, tokenizer=tokenizer)
    example['prompt'] = templated['prompt']

    limit = training_args.max_completion_length
    example['completion'] = [
        truncate_string(candidate, limit, tokenizer) if token_count > limit else candidate
        for token_count, candidate in zip(example['length'], example['completion'])
    ]

    return example
169
+
170
+ raw_datasets = raw_datasets.map(modify_completion)
171
+ eval_raw_datasets = eval_raw_datasets.map(modify_completion)
172
+
173
+
174
+ train_dataset = raw_datasets["train"]
175
+ eval_dataset = eval_raw_datasets["train"]
176
+
177
+ ########################
178
+ # Initialize the Trainer
179
+ ########################
180
+
181
+ # 这里的reward function实际不会被用到
182
def reward_len(completions, **kwargs):
    """Reward completions by how close their length is to 20 characters.

    Args:
        completions: Iterable of completions. Each item is either a plain
            string or a conversational completion, i.e. a list of message
            dicts carrying their text under the ``"content"`` key.
        **kwargs: Extra keyword arguments; accepted and ignored.

    Returns:
        list[int]: ``-abs(20 - length)`` for every completion, so ``0`` is the
        best score and longer or shorter texts score lower.
    """
    target_length = 20

    def _char_count(completion):
        # For a conversational completion, len() would count messages rather
        # than characters, so measure the message contents instead.
        if isinstance(completion, (list, tuple)) and all(
            isinstance(message, dict) for message in completion
        ):
            return sum(len(message.get("content") or "") for message in completion)
        return len(completion)

    return [-abs(target_length - _char_count(completion)) for completion in completions]
184
+
185
+ training_args.model_init_kwargs = model_kwargs
186
+ trainer = GRPOTrainer(
187
+ model=model,
188
+ reward_funcs=reward_len,
189
+ args=training_args,
190
+ train_dataset=train_dataset,
191
+ eval_dataset=eval_dataset,
192
+ )
193
+
194
+ ###############
195
+ # Training loop
196
+ ###############
197
+ logger.info("*** Train ***")
198
+ checkpoint = None
199
+ if training_args.resume_from_checkpoint is not None:
200
+ checkpoint = training_args.resume_from_checkpoint
201
+ elif last_checkpoint is not None:
202
+ checkpoint = last_checkpoint
203
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
204
+ metrics = train_result.metrics
205
+ metrics["train_samples"] = len(train_dataset)
206
+ trainer.log_metrics("train", metrics)
207
+ trainer.save_metrics("train", metrics)
208
+ trainer.save_state()
209
+
210
+ ##################################
211
+ # Save model and create model card
212
+ ##################################
213
+ logger.info("*** Save model ***")
214
+ trainer.save_model(training_args.output_dir)
215
+ logger.info(f"Model saved to {training_args.output_dir}")
216
+
217
+ # Save everything else on main process
218
+ kwargs = {
219
+ "finetuned_from": model_args.model_name_or_path,
220
+ "dataset": list(data_args.dataset_mixer.keys()),
221
+ "dataset_tags": list(data_args.dataset_mixer.keys()),
222
+ "tags": ["alignment-handbook"],
223
+ }
224
+ if trainer.accelerator.is_main_process:
225
+ trainer.create_model_card(**kwargs)
226
+ # Restore k,v cache for fast inference
227
+ trainer.model.config.use_cache = True
228
+ trainer.model.config.save_pretrained(training_args.output_dir)
229
+
230
+ ##########
231
+ # Evaluate
232
+ ##########
233
+ if training_args.do_eval:
234
+ logger.info("*** Evaluate ***")
235
+ metrics = trainer.evaluate()
236
+ metrics["eval_samples"] = len(eval_dataset)
237
+ trainer.log_metrics("eval", metrics)
238
+ trainer.save_metrics("eval", metrics)
239
+
240
+ if training_args.push_to_hub is True:
241
+ logger.info("Pushing to hub...")
242
+ trainer.push_to_hub(**kwargs)
243
+
244
+ logger.info("*** Training complete ***")
245
+
246
+
247
+ if __name__ == "__main__":
248
+ main()
grpo_offline_run.py CHANGED
@@ -15,7 +15,7 @@
15
  # limitations under the License.
16
  """
17
  Supervised fine-tuning script for decoder language models.
18
- CUDA_VISIBLE_DEVICES=1,2,3,4,5 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml grpo_offline_run.py config_grpo_offline.yaml
19
  """
20
 
21
  import logging
@@ -27,7 +27,7 @@ import torch
27
  import transformers
28
  from transformers import AutoModelForCausalLM, set_seed
29
  from trl.data_utils import maybe_apply_chat_template
30
- from datasets import load_dataset
31
  from alignment import (
32
  DataArguments,
33
  H4ArgumentParser,
@@ -88,7 +88,15 @@ def main():
88
  ###############
89
  # Load datasets
90
  ###############
91
- raw_datasets = load_dataset("json", data_files="/data01/swzhang/dataset/grpo_data_ori/grpo_del_lowscore/shuffle/grpo_test_shuffle.json")
 
 
 
 
 
 
 
 
92
  eval_raw_datasets = load_dataset("json", data_files="/data01/swzhang/dataset/grpo_data_ori/grpo_del_lowscore/shuffle/grpo_test_shuffle.json")
93
  logger.info(
94
  f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
 
15
  # limitations under the License.
16
  """
17
  Supervised fine-tuning script for decoder language models.
18
+ CUDA_VISIBLE_DEVICES=2,3,4,5 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml grpo_offline_run.py config_grpo_offline.yaml
19
  """
20
 
21
  import logging
 
27
  import transformers
28
  from transformers import AutoModelForCausalLM, set_seed
29
  from trl.data_utils import maybe_apply_chat_template
30
+ from datasets import load_dataset, Features, Value, Sequence
31
  from alignment import (
32
  DataArguments,
33
  H4ArgumentParser,
 
88
  ###############
89
  # Load datasets
90
  ###############
91
+ features = Features({
92
+ "prompt": Value("large_string"), # prompt 字段可能较长,使用 large_string
93
+ "completion": Sequence(feature=Value("large_string")), # completion 是字符串列表,使用 list<large_string>
94
+ "reward": Sequence(feature=Value("float32")), # reward is a list of floats (float32)
95
+ "length": Sequence(feature=Value("int32")), # length 是整数列表
96
+ "instruction_len": Value("int32"), # instruction_len 是整数
97
+ "del_score": Value("float32") # del_score 是浮点数
98
+ })
99
+ raw_datasets = load_dataset("json", data_files="/data01/swzhang/dataset/grpo_data_ori/grpo_del_lowscore/shuffle/grpo_train_shuffle.json",features=features)
100
  eval_raw_datasets = load_dataset("json", data_files="/data01/swzhang/dataset/grpo_data_ori/grpo_del_lowscore/shuffle/grpo_test_shuffle.json")
101
  logger.info(
102
  f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"