mjschock committed
Commit 145385b · unverified · 1 parent: 4395ceb

Enhance serve.py with fine-tuning job management: job creation, status tracking, and a training loop run in a background thread. Add a fine-tuning test to serve_test.py. Update .gitignore to exclude generated model files.

Files changed (3)
  1. .gitignore +1 -0
  2. serve.py +254 -9
  3. serve_test.py +64 -0
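
The new endpoints mirror the OpenAI fine-tuning API, so the stock openai client can drive them end to end. A minimal sketch, assuming the Ray Serve app from serve.py is reachable at http://localhost:8000/v1 (the base URL and API key are placeholders, not part of this commit):

    # Sketch only: base_url and api_key are placeholders for a local deployment.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

    # POST /v1/fine_tuning/jobs - the server treats training_file as a local path
    job = client.fine_tuning.jobs.create(
        training_file="training_data.json",
        model="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
    )

    # GET /v1/fine_tuning/jobs/{job_id} - poll for status transitions
    job = client.fine_tuning.jobs.retrieve(job.id)

    # GET /v1/fine_tuning/jobs - list every job held by the replica
    jobs = client.fine_tuning.jobs.list()

    # POST /v1/fine_tuning/jobs/{job_id}/cancel - only while not in a terminal state
    cancelled = client.fine_tuning.jobs.cancel(job.id)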
.gitignore CHANGED
@@ -2,6 +2,7 @@
 logs
 lora_model
 memory_snapshot.pickle
+models
 outputs
 __pycache__
 .pytest_cache
serve.py CHANGED
@@ -17,7 +17,23 @@ from unsloth.chat_templates import get_chat_template  # noqa: E402
 
 # isort: on
 
-from fastapi import FastAPI, Request
+import asyncio
+import json
+import os  # used by _run_training for makedirs/path checks
+import threading
+import time  # used by _run_training for the simulated validation delay
+import uuid
+from datetime import datetime
+from typing import Dict, List, Optional
+
+from datasets import (
+    Dataset,
+    DatasetDict,
+    IterableDataset,
+    IterableDatasetDict,
+    load_dataset,
+)
+from fastapi import FastAPI, HTTPException, Request
 from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -25,12 +41,24 @@ from openai.types.chat.chat_completion_chunk import Choice as ChatCompletionChunk
 from openai.types.chat.chat_completion_chunk import ChoiceDelta
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat.completion_create_params import CompletionCreateParams
+from openai.types.fine_tuning import FineTuningJob
+from peft import PeftModel
 from pydantic import TypeAdapter
 from ray import serve
+from smolagents import CodeAgent, LiteLLMModel, Model, TransformersModel, VLLMModel
+from smolagents.monitoring import LogLevel
 from sse_starlette import EventSourceResponse
 from starlette.responses import JSONResponse
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    DataCollatorForLanguageModeling,
+    Trainer,
+    TrainingArguments,
+)
 from transformers.generation.streamers import AsyncTextIteratorStreamer
 from transformers.image_utils import load_image
+from trl import SFTTrainer
 
 dtype = (
     None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
@@ -92,21 +120,30 @@ class ModelDeployment:
         model_name: str,
     ):
         self.model_name = model_name
+        self.fine_tuning_jobs: Dict[str, FineTuningJob] = {}
+        self.training_threads: Dict[str, threading.Thread] = {}
 
-        model, processor = FastModel.from_pretrained(
+        # Load base model and processor
+        self.model, self.processor = FastModel.from_pretrained(
             load_in_4bit=load_in_4bit,
             max_seq_length=max_seq_length,
             model_name=self.model_name,
         )
 
-        # with open("chat_template.txt", "r") as f:
-        #     processor.chat_template = f.read()
-        #     processor.tokenizer.chat_template = processor.chat_template
-
-        FastModel.for_inference(model)  # Enable native 2x faster inference
+        # Configure LoRA for fine-tuning
+        self.model = FastModel.get_peft_model(
+            self.model,
+            r=16,  # LoRA rank
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+            lora_alpha=32,
+            lora_dropout=0.05,
+            bias="none",
+            use_gradient_checkpointing=True,
+            random_state=42,
+            use_rslora=False,
+        )
 
-        self.model = model
-        self.processor = processor
+        FastModel.for_inference(self.model)  # Enable native 2x faster inference
 
     def reconfigure(self, config: Dict[str, Any]):
         print("=== reconfigure ===")
@@ -114,6 +151,220 @@ class ModelDeployment:
         print(config)
         # https://docs.ray.io/en/latest/serve/production-guide/config.html#dynamically-change-parameters-without-restarting-replicas-user-config
 
+    def _run_training(self, job_id: str, training_file: str, model_name: str):
+        """Run the training process in a separate thread."""
+        try:
+            # Update job status to queued
+            self.fine_tuning_jobs[job_id].status = "queued"
+
+            # Simulate file validation
+            time.sleep(2)
+
+            # Update job status to running
+            self.fine_tuning_jobs[job_id].status = "running"
+            self.fine_tuning_jobs[job_id].started_at = int(datetime.now().timestamp())
+
+            # Load and prepare dataset
+            dataset = load_dataset("json", data_files=training_file)
+
+            # Configure chat template
+            tokenizer = get_chat_template(
+                self.processor,
+                chat_template="chatml",
+                mapping={
+                    "role": "from",
+                    "content": "value",
+                    "user": "human",
+                    "assistant": "gpt",
+                },
+                map_eos_token=True,
+            )
+
+            # Format dataset
+            def formatting_prompts_func(examples):
+                convos = examples["conversations"]
+                texts = [
+                    tokenizer.apply_chat_template(
+                        convo, tokenize=False, add_generation_prompt=False
+                    )
+                    for convo in convos
+                ]
+                return {"text": texts}
+
+            dataset = dataset.map(formatting_prompts_func, batched=True)
+
+            # Configure training arguments
+            training_args = TrainingArguments(
+                output_dir=f"models/{job_id}",
+                num_train_epochs=3,
+                per_device_train_batch_size=4,
+                gradient_accumulation_steps=4,
+                learning_rate=2e-4,
+                fp16=True,
+                logging_steps=10,
+                save_strategy="epoch",
+                optim="adamw_torch",
+                warmup_ratio=0.1,
+                lr_scheduler_type="cosine",
+                weight_decay=0.01,
+            )
+
+            # Create data collator
+            data_collator = DataCollatorForLanguageModeling(
+                tokenizer=tokenizer,
+                mlm=False,
+            )
+
+            # Create trainer
+            trainer = SFTTrainer(
+                model=self.model,
+                tokenizer=tokenizer,
+                train_dataset=dataset["train"],
+                args=training_args,
+                data_collator=data_collator,
+                max_seq_length=max_seq_length,
+                packing=False,
+            )
+
+            # Train
+            trainer.train()
+
+            # Save model and adapter
+            output_dir = f"models/{job_id}"
+            os.makedirs(output_dir, exist_ok=True)
+
+            # Save the base model config and tokenizer
+            self.model.config.save_pretrained(output_dir)
+            tokenizer.save_pretrained(output_dir)
+
+            # Save the adapter weights
+            self.model.save_pretrained(output_dir)
+
+            # Save the merged model in 16-bit format
+            try:
+                # First try to merge and save in 16-bit
+                self.model.save_pretrained_merged(
+                    output_dir,
+                    tokenizer,
+                    save_method="merged_16bit",
+                )
+            except Exception as merge_error:
+                print(f"Failed to merge weights: {str(merge_error)}")
+                # If merging fails, just save the adapter weights
+                self.model.save_pretrained(output_dir)
+
+            # Update job status to succeeded
+            self.fine_tuning_jobs[job_id].status = "succeeded"
+            self.fine_tuning_jobs[job_id].finished_at = int(datetime.now().timestamp())
+            # Rough proxy: optimizer steps x per-device batch size (examples, not tokens)
+            self.fine_tuning_jobs[job_id].trained_tokens = (
+                trainer.state.global_step * training_args.per_device_train_batch_size
+            )
+
+            # Add result files
+            result_files = [
+                f"{output_dir}/config.json",
+                f"{output_dir}/tokenizer.json",
+                f"{output_dir}/adapter_config.json",
+                f"{output_dir}/adapter_model.bin",
+            ]
+
+            # Add merged model files if they exist
+            if os.path.exists(f"{output_dir}/pytorch_model.bin"):
+                result_files.append(f"{output_dir}/pytorch_model.bin")
+
+            self.fine_tuning_jobs[job_id].result_files = result_files
+
+        except Exception as e:
+            # Update job status to failed
+            self.fine_tuning_jobs[job_id].status = "failed"
+            self.fine_tuning_jobs[job_id].finished_at = int(datetime.now().timestamp())
+            self.fine_tuning_jobs[job_id].error = str(e)
+            print(f"Training failed: {str(e)}")
+            import traceback
+
+            print(traceback.format_exc())
+
+    @app.post("/v1/fine_tuning/jobs")
+    async def create_fine_tuning_job(self, body: dict):
+        """Create a fine-tuning job."""
+        try:
+            # Validate required fields
+            if "training_file" not in body:
+                raise HTTPException(status_code=400, detail="training_file is required")
+            if "model" not in body:
+                raise HTTPException(status_code=400, detail="model is required")
+
+            # Generate job ID
+            job_id = f"ftjob-{uuid.uuid4().hex[:8]}"
+
+            # Create job object
+            job = FineTuningJob(
+                id=job_id,
+                object="fine_tuning.job",
+                created_at=int(datetime.now().timestamp()),
+                finished_at=None,
+                model=body["model"],
+                fine_tuned_model=None,
+                organization_id="org-123",
+                status="validating_files",  # Start with validating_files
+                hyperparameters=body.get("hyperparameters", {}),
+                training_file=body["training_file"],
+                trained_tokens=None,
+                error=None,
+                result_files=[],  # Required field
+                seed=42,  # Required field
+            )
+
+            # Store job
+            self.fine_tuning_jobs[job_id] = job
+
+            # Start training in background thread
+            thread = threading.Thread(
+                target=self._run_training,
+                args=(job_id, body["training_file"], body["model"]),
+            )
+            thread.start()
+            self.training_threads[job_id] = thread
+
+            return job.model_dump()
+
+        except HTTPException:
+            # Re-raise validation errors instead of converting them to 500s
+            raise
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.get("/v1/fine_tuning/jobs")
+    async def list_fine_tuning_jobs(self):
+        """List all fine-tuning jobs."""
+        return {
+            "object": "list",
+            "data": [job.model_dump() for job in self.fine_tuning_jobs.values()],
+        }
+
+    @app.get("/v1/fine_tuning/jobs/{job_id}")
+    async def get_fine_tuning_job(self, job_id: str):
+        """Get a specific fine-tuning job."""
+        if job_id not in self.fine_tuning_jobs:
+            raise HTTPException(status_code=404, detail="Job not found")
+        return self.fine_tuning_jobs[job_id].model_dump()
+
+    @app.post("/v1/fine_tuning/jobs/{job_id}/cancel")
+    async def cancel_fine_tuning_job(self, job_id: str):
+        """Cancel a fine-tuning job."""
+        if job_id not in self.fine_tuning_jobs:
+            raise HTTPException(status_code=404, detail="Job not found")
+
+        job = self.fine_tuning_jobs[job_id]
+        # Only jobs that have not reached a terminal state can be cancelled
+        if job.status not in ["validating_files", "queued", "running"]:
+            raise HTTPException(status_code=400, detail="Job cannot be cancelled")
+
+        job.status = "cancelled"
+        job.finished_at = int(datetime.now().timestamp())
+
+        return job.model_dump()
+
     @app.post("/v1/chat/completions")
     async def create_chat_completion(self, body: dict, raw_request: Request):
         """Creates a model response for the given chat conversation. Learn more in the [text generation](/docs/guides/text-generation), [vision](/docs/guides/vision), and [audio](/docs/guides/audio) guides. Parameter support can differ depending on the model used to generate the response, particularly for newer reasoning models. Parameters that are only supported for reasoning models are noted below. For the current state of unsupported parameters in reasoning models, [refer to the reasoning guide](/docs/guides/reasoning).
serve_test.py CHANGED
@@ -1,4 +1,6 @@
 import json
+import os
+import time
 
 from openai import OpenAI
 
@@ -35,6 +37,73 @@ def test_chat_completion():
         print(traceback.format_exc())
 
 
+def test_fine_tuning():
+    try:
+        # Create a sample training file: a list of records, each holding one
+        # "conversations" list, matching what the server's load_dataset("json", ...)
+        # call and formatting function expect
+        training_data = [
+            {
+                "conversations": [
+                    {
+                        "from": "human",
+                        "value": "What is the capital of France?",
+                    },
+                    {
+                        "from": "gpt",
+                        "value": "The capital of France is Paris.",
+                    },
+                ]
+            }
+        ]
+
+        training_file = "training_data.json"
+        with open(training_file, "w") as f:
+            json.dump(training_data, f)
+
+        print("\nCreating fine-tuning job...")
+        job = client.fine_tuning.jobs.create(
+            training_file=training_file,
+            model="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
+        )
+        print(f"Created job: {job.id}")
+
+        # Wait for job to start
+        print("\nWaiting for job to start...")
+        time.sleep(2)
+
+        # List jobs
+        print("\nListing fine-tuning jobs...")
+        jobs = client.fine_tuning.jobs.list()
+        print(f"Found {len(jobs.data)} jobs")
+
+        # Get job status
+        print("\nGetting job status...")
+        job = client.fine_tuning.jobs.retrieve(job.id)
+        print(f"Job status: {job.status}")
+
+        # Wait for job to complete or fail (poll all non-terminal statuses)
+        print("\nWaiting for job to complete...")
+        while job.status in ["validating_files", "queued", "running"]:
+            time.sleep(5)
+            job = client.fine_tuning.jobs.retrieve(job.id)
+            print(f"Job status: {job.status}")
+
+        # Clean up
+        os.remove(training_file)
+
+    except Exception as e:
+        print(f"Error occurred: {str(e)}")
+        import traceback
+
+        print("\nFull traceback:")
+        print(traceback.format_exc())
+
+
 if __name__ == "__main__":
     print("Testing chat completions endpoint...")
     test_chat_completion()
+
+    print("\nTesting fine-tuning endpoints...")
+    test_fine_tuning()
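
The polling loop in test_fine_tuning exits only when the job leaves the non-terminal statuses, so a job that never reaches a terminal state would hang the test. A hedged variant with a deadline (helper name and timeout values are illustrative, reusing the same client):

    # Illustrative helper: poll a job with a deadline instead of looping forever.
    import time

    TERMINAL_STATUSES = {"succeeded", "failed", "cancelled"}

    def wait_for_job(client, job_id, timeout_s=600, poll_s=5):
        deadline = time.monotonic() + timeout_s
        while time.monotonic() < deadline:
            job = client.fine_tuning.jobs.retrieve(job_id)
            if job.status in TERMINAL_STATUSES:
                return job
            time.sleep(poll_s)
        raise TimeoutError(f"Job {job_id} not finished after {timeout_s}s")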