Enhance serve.py with fine-tuning job management, including job creation, status tracking, and a training process that runs in a separate thread. Update serve_test.py to include a test for the fine-tuning endpoints. Modify .gitignore to exclude model files. This update improves model training capabilities and API integration.
- .gitignore +1 -0
- serve.py +254 -9
- serve_test.py +64 -0
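For orientation before the diffs: the new routes follow the OpenAI fine-tuning API shape, so the stock openai Python client can drive them end to end. A minimal sketch, assuming a local Ray Serve deployment (the base_url and api_key values are placeholders, not part of this commit):

# Sketch: exercising the new endpoints with the stock OpenAI client.
# base_url/api_key are assumptions for a local deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

job = client.fine_tuning.jobs.create(          # POST /v1/fine_tuning/jobs
    training_file="training_data.json",
    model="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
)
client.fine_tuning.jobs.list()                 # GET  /v1/fine_tuning/jobs
client.fine_tuning.jobs.retrieve(job.id)       # GET  /v1/fine_tuning/jobs/{job_id}
client.fine_tuning.jobs.cancel(job.id)         # POST /v1/fine_tuning/jobs/{job_id}/cancel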
.gitignore CHANGED
@@ -2,6 +2,7 @@
 logs
 lora_model
 memory_snapshot.pickle
+models
 outputs
 __pycache__
 .pytest_cache
serve.py CHANGED
@@ -17,7 +17,21 @@ from unsloth.chat_templates import get_chat_template # noqa: E402
 
 # isort: on
 
-from fastapi import FastAPI, Request
+import asyncio
+import json
+import threading
+import uuid
+from datetime import datetime
+from typing import Dict, List, Optional
+
+from datasets import (
+    Dataset,
+    DatasetDict,
+    IterableDataset,
+    IterableDatasetDict,
+    load_dataset,
+)
+from fastapi import FastAPI, HTTPException, Request
 from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -25,12 +39,24 @@ from openai.types.chat.chat_completion_chunk import Choice as ChatCompletionChun
 from openai.types.chat.chat_completion_chunk import ChoiceDelta
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat.completion_create_params import CompletionCreateParams
+from openai.types.fine_tuning import FineTuningJob
+from peft import PeftModel
 from pydantic import TypeAdapter
 from ray import serve
+from smolagents import CodeAgent, LiteLLMModel, Model, TransformersModel, VLLMModel
+from smolagents.monitoring import LogLevel
 from sse_starlette import EventSourceResponse
 from starlette.responses import JSONResponse
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    DataCollatorForLanguageModeling,
+    Trainer,
+    TrainingArguments,
+)
 from transformers.generation.streamers import AsyncTextIteratorStreamer
 from transformers.image_utils import load_image
+from trl import SFTTrainer
 
 dtype = (
     None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
@@ -92,21 +118,30 @@ class ModelDeployment:
         model_name: str,
     ):
         self.model_name = model_name
+        self.fine_tuning_jobs: Dict[str, FineTuningJob] = {}
+        self.training_threads: Dict[str, threading.Thread] = {}
 
-        model, processor = FastModel.from_pretrained(
+        # Load base model and processor
+        self.model, self.processor = FastModel.from_pretrained(
             load_in_4bit=load_in_4bit,
             max_seq_length=max_seq_length,
             model_name=self.model_name,
         )
 
-        #
-
-
-
-
+        # Configure LoRA for fine-tuning
+        self.model = FastModel.get_peft_model(
+            self.model,
+            r=16, # LoRA rank
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+            lora_alpha=32,
+            lora_dropout=0.05,
+            bias="none",
+            use_gradient_checkpointing=True,
+            random_state=42,
+            use_rslora=False,
+        )
 
-        self.model = model
-        self.processor = processor
+        FastModel.for_inference(self.model) # Enable native 2x faster inference
 
     def reconfigure(self, config: Dict[str, Any]):
         print("=== reconfigure ===")
@@ -114,6 +149,216 @@ class ModelDeployment:
         print(config)
         # https://docs.ray.io/en/latest/serve/production-guide/config.html#dynamically-change-parameters-without-restarting-replicas-user-config
 
+    def _run_training(self, job_id: str, training_file: str, model_name: str):
+        """Run the training process in a separate thread."""
+        try:
+            # Update job status to queued
+            self.fine_tuning_jobs[job_id].status = "queued"
+
+            # Simulate file validation
+            time.sleep(2)
+
+            # Update job status to running
+            self.fine_tuning_jobs[job_id].status = "running"
+            self.fine_tuning_jobs[job_id].started_at = int(datetime.now().timestamp())
+
+            # Load and prepare dataset
+            dataset = load_dataset("json", data_files=training_file)
+
+            # Configure chat template
+            tokenizer = get_chat_template(
+                self.processor,
+                chat_template="chatml",
+                mapping={
+                    "role": "from",
+                    "content": "value",
+                    "user": "human",
+                    "assistant": "gpt",
+                },
+                map_eos_token=True,
+            )
+
+            # Format dataset
+            def formatting_prompts_func(examples):
+                convos = examples["conversations"]
+                texts = [
+                    tokenizer.apply_chat_template(
+                        convo, tokenize=False, add_generation_prompt=False
+                    )
+                    for convo in convos
+                ]
+                return {"text": texts}
+
+            dataset = dataset.map(formatting_prompts_func, batched=True)
+
+            # Configure training arguments
+            training_args = TrainingArguments(
+                output_dir=f"models/{job_id}",
+                num_train_epochs=3,
+                per_device_train_batch_size=4,
+                gradient_accumulation_steps=4,
+                learning_rate=2e-4,
+                fp16=True,
+                logging_steps=10,
+                save_strategy="epoch",
+                optim="adamw_torch",
+                warmup_ratio=0.1,
+                lr_scheduler_type="cosine",
+                weight_decay=0.01,
+            )
+
+            # Create data collator
+            data_collator = DataCollatorForLanguageModeling(
+                tokenizer=tokenizer,
+                mlm=False,
+            )
+
+            # Create trainer
+            trainer = SFTTrainer(
+                model=self.model,
+                tokenizer=tokenizer,
+                train_dataset=dataset["train"],
+                args=training_args,
+                data_collator=data_collator,
+                max_seq_length=max_seq_length,
+                packing=False,
+            )
+
+            # Train
+            trainer.train()
+
+            # Save model and adapter
+            output_dir = f"models/{job_id}"
+            os.makedirs(output_dir, exist_ok=True)
+
+            # Save the base model config and tokenizer
+            self.model.config.save_pretrained(output_dir)
+            tokenizer.save_pretrained(output_dir)
+
+            # Save the adapter weights
+            self.model.save_pretrained(output_dir)
+
+            # Save the merged model in 16-bit format
+            try:
+                # First try to merge and save in 16-bit
+                self.model.save_pretrained_merged(
+                    output_dir,
+                    tokenizer,
+                    save_method="merged_16bit",
+                )
+            except Exception as merge_error:
+                print(f"Failed to merge weights: {str(merge_error)}")
+                # If merging fails, just save the adapter weights
+                self.model.save_pretrained(output_dir)
+
+            # Update job status to succeeded
+            self.fine_tuning_jobs[job_id].status = "succeeded"
+            self.fine_tuning_jobs[job_id].finished_at = int(datetime.now().timestamp())
+            self.fine_tuning_jobs[job_id].trained_tokens = (
+                trainer.state.global_step * training_args.per_device_train_batch_size
+            )
+
+            # Add result files
+            result_files = [
+                f"{output_dir}/config.json",
+                f"{output_dir}/tokenizer.json",
+                f"{output_dir}/adapter_config.json",
+                f"{output_dir}/adapter_model.bin",
+            ]
+
+            # Add merged model files if they exist
+            if os.path.exists(f"{output_dir}/pytorch_model.bin"):
+                result_files.append(f"{output_dir}/pytorch_model.bin")
+
+            self.fine_tuning_jobs[job_id].result_files = result_files
+
+        except Exception as e:
+            # Update job status to failed
+            self.fine_tuning_jobs[job_id].status = "failed"
+            self.fine_tuning_jobs[job_id].finished_at = int(datetime.now().timestamp())
+            self.fine_tuning_jobs[job_id].error = str(e)
+            print(f"Training failed: {str(e)}")
+            import traceback
+
+            print(traceback.format_exc())
+
+    @app.post("/v1/fine_tuning/jobs")
+    async def create_fine_tuning_job(self, body: dict):
+        """Create a fine-tuning job."""
+        try:
+            # Validate required fields
+            if "training_file" not in body:
+                raise HTTPException(status_code=400, detail="training_file is required")
+            if "model" not in body:
+                raise HTTPException(status_code=400, detail="model is required")
+
+            # Generate job ID
+            job_id = f"ftjob-{uuid.uuid4().hex[:8]}"
+
+            # Create job object
+            job = FineTuningJob(
+                id=job_id,
+                object="fine_tuning.job",
+                created_at=int(datetime.now().timestamp()),
+                finished_at=None,
+                model=body["model"],
+                fine_tuned_model=None,
+                organization_id="org-123",
+                status="validating_files", # Start with validating_files
+                hyperparameters=body.get("hyperparameters", {}),
+                training_file=body["training_file"],
+                trained_tokens=None,
+                error=None,
+                result_files=[], # Required field
+                seed=42, # Required field
+            )
+
+            # Store job
+            self.fine_tuning_jobs[job_id] = job
+
+            # Start training in background thread
+            thread = threading.Thread(
+                target=self._run_training,
+                args=(job_id, body["training_file"], body["model"]),
+            )
+            thread.start()
+            self.training_threads[job_id] = thread
+
+            return job.model_dump()
+
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+    @app.get("/v1/fine_tuning/jobs")
+    async def list_fine_tuning_jobs(self):
+        """List all fine-tuning jobs."""
+        return {
+            "object": "list",
+            "data": [job.model_dump() for job in self.fine_tuning_jobs.values()],
+        }
+
+    @app.get("/v1/fine_tuning/jobs/{job_id}")
+    async def get_fine_tuning_job(self, job_id: str):
+        """Get a specific fine-tuning job."""
+        if job_id not in self.fine_tuning_jobs:
+            raise HTTPException(status_code=404, detail="Job not found")
+        return self.fine_tuning_jobs[job_id].model_dump()
+
+    @app.post("/v1/fine_tuning/jobs/{job_id}/cancel")
+    async def cancel_fine_tuning_job(self, job_id: str):
+        """Cancel a fine-tuning job."""
+        if job_id not in self.fine_tuning_jobs:
+            raise HTTPException(status_code=404, detail="Job not found")
+
+        job = self.fine_tuning_jobs[job_id]
+        if job.status not in ["created", "running"]:
+            raise HTTPException(status_code=400, detail="Job cannot be cancelled")
+
+        job.status = "cancelled"
+        job.finished_at = int(datetime.now().timestamp())
+
+        return job.model_dump()
+
     @app.post("/v1/chat/completions")
     async def create_chat_completion(self, body: dict, raw_request: Request):
         """Creates a model response for the given chat conversation. Learn more in the [text generation](/docs/guides/text-generation), [vision](/docs/guides/vision), and [audio](/docs/guides/audio) guides. Parameter support can differ depending on the model used to generate the response, particularly for newer reasoning models. Parameters that are only supported for reasoning models are noted below. For the current state of unsupported parameters in reasoning models, [refer to the reasoning guide](/docs/guides/reasoning).
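Two reading notes on _run_training: it calls time.sleep and os.makedirs, so serve.py must already import time and os above this diff's first hunk (not shown here); and load_dataset("json", data_files=...) must yield rows with a "conversations" column whose value is a full list of ShareGPT-style turns, to match the chatml mapping. A sketch of one plausible JSON-lines record (filename and content are illustrative, not from this commit):

# Illustrative training file: one conversation per JSON line, each a
# list of {"from", "value"} turns matching the chat-template mapping.
import json

record = {
    "conversations": [
        {"from": "human", "value": "What is the capital of France?"},
        {"from": "gpt", "value": "The capital of France is Paris."},
    ]
}
with open("training_data.jsonl", "w") as f:
    f.write(json.dumps(record) + "\n")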
serve_test.py CHANGED
@@ -1,4 +1,6 @@
 import json
+import os
+import time
 
 from openai import OpenAI
 
@@ -35,6 +37,68 @@ def test_chat_completion():
         print(traceback.format_exc())
 
 
+def test_fine_tuning():
+    try:
+        # Create a sample training file
+        training_data = {
+            "conversations": [
+                {
+                    "from": "human",
+                    "value": "What is the capital of France?",
+                },
+                {
+                    "from": "gpt",
+                    "value": "The capital of France is Paris.",
+                },
+            ]
+        }
+
+        training_file = "training_data.json"
+        with open(training_file, "w") as f:
+            json.dump(training_data, f)
+
+        print("\nCreating fine-tuning job...")
+        job = client.fine_tuning.jobs.create(
+            training_file=training_file,
+            model="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
+        )
+        print(f"Created job: {job.id}")
+
+        # Wait for job to start
+        print("\nWaiting for job to start...")
+        time.sleep(2)
+
+        # List jobs
+        print("\nListing fine-tuning jobs...")
+        jobs = client.fine_tuning.jobs.list()
+        print(f"Found {len(jobs.data)} jobs")
+
+        # Get job status
+        print("\nGetting job status...")
+        job = client.fine_tuning.jobs.retrieve(job.id)
+        print(f"Job status: {job.status}")
+
+        # Wait for job to complete or fail
+        print("\nWaiting for job to complete...")
+        while job.status in ["created", "running"]:
+            time.sleep(5)
+            job = client.fine_tuning.jobs.retrieve(job.id)
+            print(f"Job status: {job.status}")
+
+        # Clean up
+        os.remove(training_file)
+
+    except Exception as e:
+        print(f"Error occurred: {str(e)}")
+        import traceback
+
+        print("\nFull traceback:")
+        print(traceback.format_exc())
+
+
 if __name__ == "__main__":
     print("Testing chat completions endpoint...")
     test_chat_completion()
+
+    print("\nTesting fine-tuning endpoints...")
+    test_fine_tuning()
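One caveat in the completion-wait loop above: serve.py moves jobs through "validating_files" → "queued" → "running", but the loop only continues while the status is "created" or "running", so it can return while a job is still validating or queued. A more robust variant polls until a terminal status; a minimal sketch (wait_for_job is a hypothetical helper, not part of this commit):

# Hypothetical helper: poll until the job reaches a terminal status,
# using the lifecycle states that serve.py actually sets.
def wait_for_job(client, job_id, interval=5):
    terminal = {"succeeded", "failed", "cancelled"}
    while True:
        job = client.fine_tuning.jobs.retrieve(job_id)
        print(f"Job status: {job.status}")
        if job.status in terminal:
            return job
        time.sleep(interval)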