Molchevsky committed on
Commit
15746d5
·
verified ·
1 Parent(s): f00944d

Upload llama_finetuning.py

Files changed (1)
  1. llama_finetuning.py +419 -0
llama_finetuning.py ADDED
@@ -0,0 +1,419 @@
+ import os
+ import json
+ import torch
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TrainingArguments,
+     Trainer,
+     BitsAndBytesConfig,
+     DataCollatorForLanguageModeling
+ )
+ from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
+ from datasets import Dataset
+ import warnings
+ import glob
+
+ # Suppress warnings
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
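+ # Compatibility note: Trainer(processing_class=...) and the newer TrainingArguments fields
+ # used below (torch_empty_cache_steps, eval_do_concat_batches) only exist in recent
+ # transformers releases; on older versions, pass tokenizer= to Trainer and drop those fields.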
+ def load_jsonl_data(data_dir):
+     """Load conversation data from all JSONL files in the specified directory"""
+     conversations = []
+
+     # Find all JSONL files in the directory
+     jsonl_files = glob.glob(os.path.join(data_dir, "*.jsonl"))
+
+     if not jsonl_files:
+         print(f"⚠️ No JSONL files found in {data_dir}")
+         return []
+
+     print(f"Found {len(jsonl_files)} JSONL files:")
+     for file in jsonl_files:
+         print(f" • {os.path.basename(file)}")
+
+     # Load data from each file
+     for file_path in jsonl_files:
+         try:
+             with open(file_path, 'r', encoding='utf-8') as f:
+                 for line_num, line in enumerate(f, 1):
+                     line = line.strip()
+                     if not line:
+                         continue
+
+                     try:
+                         data = json.loads(line)
+                         if 'messages' in data:
+                             conversations.append(data['messages'])
+                         else:
+                             print(f"⚠️ Skipping line {line_num} in {file_path}: no 'messages' field")
+                     except json.JSONDecodeError as e:
+                         print(f"⚠️ Skipping invalid JSON on line {line_num} in {file_path}: {e}")
+
+         except Exception as e:
+             print(f"❌ Error reading file {file_path}: {e}")
+
+     print(f"Loaded {len(conversations)} conversations from {data_dir}")
+     return conversations
+
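+ # Illustrative input line (assumed shape, inferred from the loader above, not taken from the
+ # actual dataset): each JSONL line holds one conversation under a "messages" key, e.g.
+ # {"messages": [{"role": "system", "content": "You are a helpful assistant."},
+ #               {"role": "user", "content": "What is LoRA?"},
+ #               {"role": "assistant", "content": "A parameter-efficient fine-tuning method."}]}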
+ def format_conversation_for_training(messages):
+     """
+     Format a conversation with system, user, and assistant messages for Llama training
+
+     Args:
+         messages: List of message dictionaries with 'role' and 'content' keys
+
+     Returns:
+         Formatted string ready for training
+     """
+     formatted_parts = ["<|begin_of_text|>"]
+
+     for message in messages:
+         role = message.get('role', '').lower()
+         content = message.get('content', '').strip()
+
+         if not content:
+             continue
+
+         if role == 'system':
+             formatted_parts.append(f"<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>")
+         elif role == 'user':
+             formatted_parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>")
+         elif role == 'assistant':
+             formatted_parts.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>")
+         else:
+             print(f"⚠️ Unknown role '{role}', skipping message")
+
+     return "".join(formatted_parts)
+
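+ # Illustrative output (assumed example, not from the dataset): a single user/assistant exchange
+ # is rendered as one unbroken string of the form
+ # <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHello!<|eot_id|>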
+ def tokenize_function(examples, tokenizer, max_length=1024):
+     """Tokenize the conversation examples"""
+     # Tokenize inputs
+     tokenized = tokenizer(
+         examples["text"],
+         truncation=True,
+         padding="max_length",
+         max_length=max_length,
+         return_tensors=None  # Don't return tensors here, let the collator handle it
+     )
+
+     # For causal language modeling, labels are the same as input_ids
+     tokenized["labels"] = tokenized["input_ids"].copy()
+
+     return tokenized
+
+ def prepare_dataset(conversations, tokenizer, max_length=1024):
+     """Prepare dataset for training from conversation data"""
+     formatted_texts = []
+
+     print("📝 Processing conversations...")
+     for i, messages in enumerate(conversations):
+         if not messages:
+             print(f"⚠️ Skipping empty conversation {i+1}")
+             continue
+
+         # Validate conversation structure
+         has_system = any(msg.get('role') == 'system' for msg in messages)
+         has_user = any(msg.get('role') == 'user' for msg in messages)
+         has_assistant = any(msg.get('role') == 'assistant' for msg in messages)
+
+         if not (has_user and has_assistant):
+             print(f"⚠️ Skipping conversation {i+1}: missing user or assistant message")
+             continue
+
+         if not has_system:
+             print(f"⚠️ Conversation {i+1} has no system message")
+
+         # Format the conversation
+         formatted_text = format_conversation_for_training(messages)
+
+         if len(formatted_text.strip()) > 0:
+             formatted_texts.append(formatted_text)
+         else:
+             print(f"⚠️ Skipping empty formatted conversation {i+1}")
+
+     if not formatted_texts:
+         raise ValueError("No valid conversations found! Please check your JSONL files.")
+
+     print(f"✅ Successfully processed {len(formatted_texts)} conversations")
+
+     # Show a sample formatted conversation
+     if formatted_texts:
+         print("\n📋 Sample formatted conversation:")
+         print("-" * 80)
+         sample = formatted_texts[0]
+         print(sample[:500] + "..." if len(sample) > 500 else sample)
+         print("-" * 80)
+
+     # Create Hugging Face dataset
+     dataset = Dataset.from_dict({"text": formatted_texts})
+
+     # Tokenize the dataset
+     tokenized_dataset = dataset.map(
+         lambda examples: tokenize_function(examples, tokenizer, max_length),
+         batched=True,
+         remove_columns=dataset.column_names,
+         desc="Tokenizing conversations"
+     )
+
+     return tokenized_dataset
+
+ def setup_model_and_tokenizer(model_path):
+     """Setup model with quantization and tokenizer"""
+
+     # Quantization config for 4-bit training
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16,
+     )
+
+     # Load tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_path,
+         trust_remote_code=True,
+         padding_side="right"
+     )
+
+     # Add pad token if it doesn't exist
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+         tokenizer.pad_token_id = tokenizer.eos_token_id
+
+     # Load model with quantization
+     try:
+         # Try to use Flash Attention 2 if available and compatible
+         model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             quantization_config=bnb_config,
+             device_map="auto",
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+             use_cache=False,  # Disable cache for training
+             attn_implementation="flash_attention_2" if torch.cuda.get_device_capability()[0] >= 8 else "eager"
+         )
+         print("✅ Using Flash Attention 2 for better performance!")
+     except Exception as e:
+         print(f"⚠️ Flash Attention 2 not available ({str(e)}), using standard attention")
+         # Fallback to standard attention
+         model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             quantization_config=bnb_config,
+             device_map="auto",
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+             use_cache=False,  # Disable cache for training
+         )
+
+     # Prepare model for k-bit training
+     model = prepare_model_for_kbit_training(model)
+
+     return model, tokenizer
+
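+ # Note: attn_implementation="flash_attention_2" additionally requires the separate flash-attn
+ # package and an Ampere-or-newer GPU (compute capability >= 8, as checked above); the
+ # try/except falls back to standard attention when either requirement is not met.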
+ def setup_lora_config():
+     """Setup LoRA configuration for Llama 3.2"""
+     lora_config = LoraConfig(
+         task_type=TaskType.CAUSAL_LM,
+         r=16,  # Rank - can be increased for potentially better results
+         lora_alpha=32,  # LoRA scaling parameter
+         lora_dropout=0.1,  # LoRA dropout
+         target_modules=[
+             "q_proj",
+             "k_proj",
+             "v_proj",
+             "o_proj",
+             "gate_proj",
+             "up_proj",
+             "down_proj"
+         ],
+         bias="none",
+         inference_mode=False,
+     )
+     return lora_config
236
+ def main():
237
+ # Configuration
238
+ MODEL_PATH = "llama-3.2-3b" # Path to your base model directory
239
+ QA_DATA_PATH = "./new_qa_pairs/" # Path to your JSONL data directory
240
+ OUTPUT_DIR = "llama-3.2-3b-finetuned" # Output directory for the fine-tuned model
241
+
242
+ # Check CUDA availability
243
+ if not torch.cuda.is_available():
244
+ print("❌ CUDA is not available. Please check your installation.")
245
+ return
246
+
247
+ print(f"πŸš€ Starting Llama 3.2 Fine-tuning")
248
+ print(f"Using GPU: {torch.cuda.get_device_name()}")
249
+ print(f"CUDA Version: {torch.version.cuda}")
250
+ print(f"PyTorch Version: {torch.__version__}")
251
+
252
+ # Check if data directory exists
253
+ if not os.path.exists(QA_DATA_PATH):
254
+ print(f"❌ Data directory not found: {QA_DATA_PATH}")
255
+ print("Please create the directory and add your JSONL files.")
256
+ return
257
+
258
+ # Load conversation data
259
+ print(f"\nπŸ“š Loading conversation data from {QA_DATA_PATH}...")
260
+ conversations = load_jsonl_data(QA_DATA_PATH)
261
+
262
+ if len(conversations) == 0:
263
+ print("❌ No valid conversations found. Please check your JSONL files.")
264
+ return
265
+
+     # Setup model and tokenizer
+     print(f"\n🧠 Loading model and tokenizer from {MODEL_PATH}...")
+     model, tokenizer = setup_model_and_tokenizer(MODEL_PATH)
+
+     # Prepare dataset
+     print(f"\n🔧 Preparing dataset...")
+     dataset = prepare_dataset(conversations, tokenizer, max_length=1024)  # Increased for system messages
+
+     # Split dataset (90% train, 10% eval)
+     dataset = dataset.train_test_split(test_size=0.1, seed=42)
+     train_dataset = dataset['train']
+     eval_dataset = dataset['test']
+
+     print(f"\n📊 Dataset Statistics:")
+     print(f" • Total conversations: {len(conversations)}")
+     print(f" • Training samples: {len(train_dataset)}")
+     print(f" • Evaluation samples: {len(eval_dataset)}")
+
+     # Setup LoRA
+     print(f"\n🎯 Setting up LoRA...")
+     lora_config = setup_lora_config()
+     model = get_peft_model(model, lora_config)
+     model.print_trainable_parameters()
+
+     # Data collator - handles dynamic padding and label preparation
+     data_collator = DataCollatorForLanguageModeling(
+         tokenizer=tokenizer,
+         mlm=False,  # We're doing causal language modeling, not masked LM
+         pad_to_multiple_of=8,
+         return_tensors="pt"
+     )
+
+     # Training arguments - updated for latest API
+     training_args = TrainingArguments(
+         output_dir=OUTPUT_DIR,
+         num_train_epochs=3,
+         per_device_train_batch_size=1,  # Small batch size for 8GB GPU
+         per_device_eval_batch_size=1,
+         gradient_accumulation_steps=8,  # Effective batch size = 1 * 8 = 8
+         warmup_steps=100,
+         learning_rate=2e-4,
+         weight_decay=0.01,
+         fp16=False,
+         bf16=True,  # Use bfloat16 for better stability
+         logging_steps=10,
+         eval_steps=100,
+         save_steps=200,
+         eval_strategy="steps",  # Updated parameter name
+         save_strategy="steps",
+         load_best_model_at_end=True,
+         metric_for_best_model="eval_loss",
+         greater_is_better=False,
+         report_to="none",  # Disable wandb/tensorboard logging
+         dataloader_pin_memory=True,
+         remove_unused_columns=False,
+         optim="paged_adamw_8bit",  # Memory-efficient optimizer
+         lr_scheduler_type="cosine",
+         max_grad_norm=1.0,
+         dataloader_num_workers=0,  # Avoid multiprocessing issues
+         group_by_length=False,  # Disable grouping for stability
+         ddp_find_unused_parameters=False,  # For better performance
+         save_total_limit=3,  # Keep only 3 checkpoints
+         prediction_loss_only=False,
+         include_inputs_for_metrics=False,
+         seed=42,
+         data_seed=42,
+         # New parameters in latest version
+         eval_do_concat_batches=False,  # Better for memory
+         torch_empty_cache_steps=50,  # Clear cache every 50 steps
+         gradient_checkpointing=True,  # Enable gradient checkpointing for memory efficiency
+         gradient_checkpointing_kwargs={"use_reentrant": False},  # Use non-reentrant checkpointing (recommended)
+     )
+
+     # Initialize trainer
+     print(f"\n🏃 Initializing trainer...")
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=eval_dataset,
+         processing_class=tokenizer,  # Updated parameter name from tokenizer
+         data_collator=data_collator,
+     )
+
+     # Print training info
+     total_steps = len(train_dataset) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * training_args.num_train_epochs
+     print(f"\n📈 Training Configuration:")
+     print(f" • Total training steps: {total_steps}")
+     print(f" • Warmup steps: {training_args.warmup_steps}")
+     print(f" • Learning rate: {training_args.learning_rate}")
+     print(f" • Batch size (effective): {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
+     print(f" • Save every: {training_args.save_steps} steps")
+     print(f" • Eval every: {training_args.eval_steps} steps")
+
+     # Start training
+     print(f"\n🚀 Starting training...")
+     print("=" * 60)
+     trainer.train()
+
+     # Save the fine-tuned model
+     print(f"\n💾 Saving model...")
+     trainer.save_model()
+
+     # Save tokenizer separately to ensure compatibility
+     tokenizer.save_pretrained(OUTPUT_DIR)
+
+     print(f"\n✅ Fine-tuning completed!")
+     print(f"📁 Model saved to: {OUTPUT_DIR}")
+
+     # Test the model with a sample conversation
+     print(f"\n🧪 Testing the model with a sample...")
+
+     # Set model to eval mode
+     model.eval()
+
+     # Use first conversation as test
+     if conversations:
+         test_conversation = conversations[0]
+
+         # Extract system message and user question
+         system_msg = next((msg['content'] for msg in test_conversation if msg['role'] == 'system'), "")
+         user_msg = next((msg['content'] for msg in test_conversation if msg['role'] == 'user'), "")
+         expected_response = next((msg['content'] for msg in test_conversation if msg['role'] == 'assistant'), "")
+
+         if system_msg and user_msg:
+             # Format input for testing
+             test_input = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_msg}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+
+             # Tokenize and generate
+             inputs = tokenizer(test_input, return_tensors="pt").to(model.device)
+
+             with torch.no_grad():
+                 outputs = model.generate(
+                     **inputs,
+                     max_new_tokens=150,
+                     temperature=0.7,
+                     do_sample=True,
+                     pad_token_id=tokenizer.eos_token_id,
+                     eos_token_id=tokenizer.eos_token_id,
+                     repetition_penalty=1.1,
+                 )
+
+             # Decode only the newly generated tokens (slicing the decoded string by
+             # len(test_input) is unreliable once special tokens are stripped)
+             generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+             generated_answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
+
+             print(f"\n📋 Test Results:")
+             print(f"System: {system_msg[:100]}{'...' if len(system_msg) > 100 else ''}")
+             print(f"Question: {user_msg}")
+             print(f"Generated: {generated_answer}")
+             print(f"Expected: {expected_response}")
+             print("=" * 60)
+
+ if __name__ == "__main__":
+     main()