import asyncio
import json
import logging
import os
import threading
import time
import uuid
from datetime import datetime
from pprint import pprint
from threading import Thread
from typing import Any, Dict, List, Optional

# isort: off
from unsloth import (
    FastLanguageModel,
    FastModel,
    FastVisionModel,
    is_bfloat16_supported,
)  # noqa: E402
from unsloth.chat_templates import get_chat_template  # noqa: E402

# isort: on

from datasets import (
    Dataset,
    DatasetDict,
    IterableDataset,
    IterableDatasetDict,
    load_dataset,
)
from fastapi import FastAPI, HTTPException, Request
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as ChatCompletionChunkChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.completion_create_params import CompletionCreateParams
from openai.types.fine_tuning import FineTuningJob
from peft import PeftModel
from pydantic import TypeAdapter
from ray import serve
from smolagents import CodeAgent, LiteLLMModel, Model, TransformersModel, VLLMModel
from smolagents.monitoring import LogLevel
from sse_starlette import EventSourceResponse
from starlette.responses import JSONResponse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from transformers.generation.streamers import AsyncTextIteratorStreamer
from transformers.image_utils import load_image
from trl import SFTTrainer

dtype = (
    None  # None for auto detection. Float16 for Tesla T4/V100; bfloat16 for Ampere+.
)
load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.
max_seq_length = 2048  # Supports RoPE Scaling internally, so choose any!
# max_seq_length = 4096  # Choose any! We auto support RoPE Scaling internally!
logger = logging.getLogger("ray.serve")

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

app = FastAPI()

# middlewares = [
#     middleware
#     for middleware in ConnexionMiddleware.default_middlewares
#     if middleware is not SecurityMiddleware
# ]
# connexion_app = AsyncApp(import_name=__name__, middlewares=middlewares)
# connexion_app.add_api(
#     # "api/openai/v1/openapi/openapi.yaml",
#     "api/v1/openapi/openapi.yaml",
#     # base_path="/openai/v1",
#     base_path="/v1",
#     pythonic_params=True,
#     resolver_error=501,
# )
# # fastapi_app.mount("/api", ConnexionMiddleware(app=connexion_app, import_name=__name__))
# # app.mount("/api", ConnexionMiddleware(app=connexion_app, import_name=__name__))
# app.mount(
#     "/",
#     ConnexionMiddleware(
#         app=connexion_app,
#         import_name=__name__,
#         # middlewares=middlewares,
#     ),
# )


@serve.deployment(
    autoscaling_config={
        # https://docs.ray.io/en/latest/serve/advanced-guides/advanced-autoscaling.html#required-define-upper-and-lower-autoscaling-limits
        "max_replicas": 1,
        "min_replicas": 1,  # TODO: set to 0
        "target_ongoing_requests": 2,  # https://docs.ray.io/en/latest/serve/advanced-guides/advanced-autoscaling.html#target-ongoing-requests-default-2
    },
    max_ongoing_requests=5,  # https://docs.ray.io/en/latest/serve/advanced-guides/advanced-autoscaling.html#max-ongoing-requests-default-5
    ray_actor_options={"num_gpus": 1},
)
@serve.ingress(app)
class ModelDeployment:
    def __init__(
        self,
        model_name: str,
    ):
        self.model_name = model_name
        self.fine_tuning_jobs: Dict[str, FineTuningJob] = {}
        self.training_threads: Dict[str, threading.Thread] = {}

        # Load base model and processor
        self.model, self.processor = FastModel.from_pretrained(
            load_in_4bit=load_in_4bit,
            max_seq_length=max_seq_length,
            model_name=self.model_name,
        )

        # Configure LoRA for fine-tuning
        self.model = FastModel.get_peft_model(
            self.model,
            r=16,  # LoRA rank
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            use_gradient_checkpointing=True,
            random_state=42,
            use_rslora=False,
        )

        FastModel.for_inference(self.model)  # Enable native 2x faster inference

    def reconfigure(self, config: Dict[str, Any]):
        print("=== reconfigure ===")
        print("config:")
        print(config)
        # https://docs.ray.io/en/latest/serve/production-guide/config.html#dynamically-change-parameters-without-restarting-replicas-user-config

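    # NOTE: `_run_training` below assumes `training_file` is a path to a local
    # JSON/JSONL file in ShareGPT-style format, matching the chat-template
    # mapping it configures ("from"/"value" keys, "human"/"gpt" roles), e.g.:
    #
    #   {"conversations": [{"from": "human", "value": "Hi"},
    #                      {"from": "gpt", "value": "Hello!"}]}
    #
    # This is inferred from the formatting code; no validation is done here.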
    def _run_training(self, job_id: str, training_file: str, model_name: str):
        """Run the training process in a separate thread."""
        try:
            # Update job status to queued
            self.fine_tuning_jobs[job_id].status = "queued"

            # Simulate file validation
            time.sleep(2)

            # Update job status to running
            self.fine_tuning_jobs[job_id].status = "running"
            self.fine_tuning_jobs[job_id].started_at = int(datetime.now().timestamp())

            # Load and prepare dataset
            dataset = load_dataset("json", data_files=training_file)

            # Configure chat template
            tokenizer = get_chat_template(
                self.processor,
                chat_template="chatml",
                mapping={
                    "role": "from",
                    "content": "value",
                    "user": "human",
                    "assistant": "gpt",
                },
                map_eos_token=True,
            )

            # Format dataset
            def formatting_prompts_func(examples):
                convos = examples["conversations"]
                texts = [
                    tokenizer.apply_chat_template(
                        convo, tokenize=False, add_generation_prompt=False
                    )
                    for convo in convos
                ]
                return {"text": texts}

            dataset = dataset.map(formatting_prompts_func, batched=True)

            # Configure training arguments
            training_args = TrainingArguments(
                output_dir=f"models/{job_id}",
                num_train_epochs=3,
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                learning_rate=2e-4,
                fp16=not is_bfloat16_supported(),
                bf16=is_bfloat16_supported(),
                logging_steps=10,
                save_strategy="epoch",
                optim="adamw_torch",
                warmup_ratio=0.1,
                lr_scheduler_type="cosine",
                weight_decay=0.01,
            )

            # Create data collator
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=tokenizer,
                mlm=False,
            )

            # Create trainer
            trainer = SFTTrainer(
                model=self.model,
                tokenizer=tokenizer,
                train_dataset=dataset["train"],
                args=training_args,
                data_collator=data_collator,
                max_seq_length=max_seq_length,
                packing=False,
            )

            # Switch the model back into training mode before fine-tuning
            # (inference mode was enabled in __init__).
            FastModel.for_training(self.model)

            # Train
            trainer.train()

            # Re-enable fast inference once training is done
            FastModel.for_inference(self.model)

            # Save model and adapter
            output_dir = f"models/{job_id}"
            os.makedirs(output_dir, exist_ok=True)

            # Save the base model config and tokenizer
            self.model.config.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)

            # Save the adapter weights
            self.model.save_pretrained(output_dir)

            # Save the merged model in 16-bit format
            try:
                # First try to merge and save in 16-bit
                self.model.save_pretrained_merged(
                    output_dir,
                    tokenizer,
                    save_method="merged_16bit",
                )
            except Exception as merge_error:
                print(f"Failed to merge weights: {str(merge_error)}")
                # If merging fails, just save the adapter weights
                self.model.save_pretrained(output_dir)

            # Update job status to succeeded
            self.fine_tuning_jobs[job_id].status = "succeeded"
            self.fine_tuning_jobs[job_id].finished_at = int(datetime.now().timestamp())
            # Rough proxy for trained tokens (steps * batch size), not an exact count
            self.fine_tuning_jobs[job_id].trained_tokens = (
                trainer.state.global_step * training_args.per_device_train_batch_size
            )

            # Add result files
            result_files = [
                f"{output_dir}/config.json",
                f"{output_dir}/tokenizer.json",
                f"{output_dir}/adapter_config.json",
                f"{output_dir}/adapter_model.bin",
            ]

            # Add merged model files if they exist
            if os.path.exists(f"{output_dir}/pytorch_model.bin"):
                result_files.append(f"{output_dir}/pytorch_model.bin")

            self.fine_tuning_jobs[job_id].result_files = result_files

        except Exception as e:
            # Update job status to failed
            self.fine_tuning_jobs[job_id].status = "failed"
            self.fine_tuning_jobs[job_id].finished_at = int(datetime.now().timestamp())
            self.fine_tuning_jobs[job_id].error = str(e)
            print(f"Training failed: {str(e)}")
            import traceback

            print(traceback.format_exc())

    @app.post("/v1/fine_tuning/jobs")
    async def create_fine_tuning_job(self, body: dict):
        """Create a fine-tuning job."""
        try:
            # Validate required fields
            if "training_file" not in body:
                raise HTTPException(status_code=400, detail="training_file is required")
            if "model" not in body:
                raise HTTPException(status_code=400, detail="model is required")

            # Generate job ID
            job_id = f"ftjob-{uuid.uuid4().hex[:8]}"

            # Create job object
            job = FineTuningJob(
                id=job_id,
                object="fine_tuning.job",
                created_at=int(datetime.now().timestamp()),
                finished_at=None,
                model=body["model"],
                fine_tuned_model=None,
                organization_id="org-123",
                status="validating_files",  # Start with validating_files
                hyperparameters=body.get("hyperparameters", {}),
                training_file=body["training_file"],
                trained_tokens=None,
                error=None,
                result_files=[],  # Required field
                seed=42,  # Required field
            )

            # Store job
            self.fine_tuning_jobs[job_id] = job

            # Start training in background thread
            thread = threading.Thread(
                target=self._run_training,
                args=(job_id, body["training_file"], body["model"]),
            )
            thread.start()
            self.training_threads[job_id] = thread

            return job.model_dump()

        except HTTPException:
            # Preserve intentional 4xx responses instead of converting them to 500s
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

    @app.get("/v1/fine_tuning/jobs")
    async def list_fine_tuning_jobs(self):
        """List all fine-tuning jobs."""
        return {
            "object": "list",
            "data": [job.model_dump() for job in self.fine_tuning_jobs.values()],
        }

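    # Job status lifecycle used by this deployment:
    #   validating_files -> queued -> running -> succeeded | failed
    # (or -> cancelled via the cancel endpoint below).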
    @app.get("/v1/fine_tuning/jobs/{job_id}")
    async def get_fine_tuning_job(self, job_id: str):
        """Get a specific fine-tuning job."""
        if job_id not in self.fine_tuning_jobs:
            raise HTTPException(status_code=404, detail="Job not found")

        return self.fine_tuning_jobs[job_id].model_dump()

    @app.post("/v1/fine_tuning/jobs/{job_id}/cancel")
    async def cancel_fine_tuning_job(self, job_id: str):
        """Cancel a fine-tuning job."""
        if job_id not in self.fine_tuning_jobs:
            raise HTTPException(status_code=404, detail="Job not found")

        job = self.fine_tuning_jobs[job_id]
        if job.status not in ["validating_files", "queued", "running"]:
            raise HTTPException(status_code=400, detail="Job cannot be cancelled")

        job.status = "cancelled"
        job.finished_at = int(datetime.now().timestamp())

        return job.model_dump()

    @app.post("/v1/chat/completions")
    async def create_chat_completion(self, body: dict, raw_request: Request):
        """Creates a model response for the given chat conversation.

        Learn more in the [text generation](/docs/guides/text-generation),
        [vision](/docs/guides/vision), and [audio](/docs/guides/audio) guides.
        Parameter support can differ depending on the model used to generate the
        response, particularly for newer reasoning models. Parameters that are
        only supported for reasoning models are noted below. For the current
        state of unsupported parameters in reasoning models,
        [refer to the reasoning guide](/docs/guides/reasoning).

        :param body: the Chat Completions request body
        :type body: dict | bytes

        :rtype: Union[CreateChatCompletionResponse,
                      Tuple[CreateChatCompletionResponse, int],
                      Tuple[CreateChatCompletionResponse, int, Dict[str, str]]]
        """
        print("=== create_chat_completion ===")
        print("body:")
        pprint(body)

        ta = TypeAdapter(CompletionCreateParams)
        print("ta.validate_python...")
        pprint(ta.validate_python(body))

        max_new_tokens = body.get("max_completion_tokens", body.get("max_tokens"))
        messages = body.get("messages")
        model_name = body.get("model")
        stream = body.get("stream", False)
        temperature = body.get("temperature")
        tools = body.get("tools")

        # Collect images and normalize message content for the chat template.
        images = []
        for message in messages:
            content = message.get("content")
            if not isinstance(content, list):
                continue  # plain string content needs no conversion
            for part in content:
                if isinstance(part, dict) and part.get("type") == "image_url":
                    # Load the image and rewrite the part into the
                    # {"type": "image"} form expected by the processor.
                    image_url = part["image_url"]["url"]
                    images.append(load_image(image_url))
                    part["type"] = "image"
                    del part["image_url"]
            # Flatten text-only content to a plain string.
            if all(
                isinstance(part, dict) and part.get("type") == "text"
                for part in content
            ):
                message["content"] = "\n".join(part["text"] for part in content)
        images = images if images else None

        if model_name != self.model_name:
            # adapter_path = model_name
            # self.model.load_adapter(adapter_path)
            return JSONResponse(content={"error": "Model not found"}, status_code=404)

        prompt = self.processor.apply_chat_template(
            add_generation_prompt=True,
            conversation=messages,
            # documents=documents,
            tools=tools,
            tokenize=False,  # Return string instead of token IDs
        )

        print("prompt:")
        print(prompt)

        if images:
            inputs = self.processor(text=prompt, images=images, return_tensors="pt")
        else:
            inputs = self.processor(text=prompt, return_tensors="pt")

        inputs = inputs.to(self.model.device)
        input_ids = inputs.input_ids

        class GeneratorThread(Thread):
            """Thread to generate completions in the background."""

            def __init__(self, model, **generation_kwargs):
                super().__init__()
                self.chat_completion = None
                self.generated_ids = []
                self.generation_kwargs = generation_kwargs
                self.model = model

            def run(self):
                import torch
                import torch._dynamo.config

                try:
                    try:
                        self.generated_ids = self.model.generate(
                            **self.generation_kwargs
                        )
                    except torch._dynamo.exc.BackendCompilerFailed as e:
                        print(e)
                        print("Disabling dynamo...")
                        torch._dynamo.config.disable = True
                        self.generated_ids = self.model.generate(
                            **self.generation_kwargs
                        )
                except Exception as e:
                    print(e)
                    print("Warning: Exception in GeneratorThread")
                    self.generated_ids = []

            def join(self, timeout=None):
                super().join(timeout)
                return self.generated_ids

        decode_kwargs = dict(skip_special_tokens=True)

        streamer = (
            AsyncTextIteratorStreamer(
                self.processor,
                skip_prompt=True,
                **decode_kwargs,
            )
            if stream
            else None
        )

        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            temperature=temperature,
            use_cache=True,
        )

        thread = GeneratorThread(self.model, **generation_kwargs)
        thread.start()

        if stream:

            async def event_publisher():
                i = 0

                try:
                    async for new_text in streamer:
                        print("new_text:")
                        print(new_text)

                        choices: List[ChatCompletionChunkChoice] = [
                            ChatCompletionChunkChoice(
                                delta=ChoiceDelta(
                                    content=new_text,
                                    function_call=None,
                                    refusal=None,
                                    role="assistant",
                                    tool_calls=None,
                                ),
                                finish_reason=None,
                                index=0,
                                logprobs=None,
                            )
                        ]

                        chat_completion_chunk = ChatCompletionChunk(
                            choices=choices,
                            created=int(time.time()),
                            id=str(i),
                            model=model_name,
                            object="chat.completion.chunk",
                            service_tier=None,
                            system_fingerprint=None,
                            usage=None,
                        )

                        yield chat_completion_chunk.model_dump_json()

                        i += 1

                except asyncio.CancelledError as e:
                    print("Disconnected from client (via refresh/close)")
                    raise e

                except Exception as e:
                    print(f"Exception: {e}")
                    raise e

            return EventSourceResponse(event_publisher())

        generated_ids = thread.join()

        input_length = input_ids.shape[1]

        batch_decoded_outputs = self.processor.batch_decode(
            generated_ids[:, input_length:],
            skip_special_tokens=True,
        )

        choices: List[ChatCompletionChoice] = []

        for i, response in enumerate(batch_decoded_outputs):
            print("response:")
            print(response)

            # try:
            #     response = json.loads(response)
            #     finish_reason: str = response.get("finish_reason")
            #     tool_calls_json = response.get("tool_calls")
            #     tool_calls: List[ToolCall] = []
            #     for tool_call_json in tool_calls_json:
            #         tool_call = ToolCall(
            #             function=FunctionToolCallArguments(
            #                 arguments=tool_call_json.get("arguments"),
            #                 name=tool_call_json.get("name"),
            #             ),
            #             id=tool_call_json.get("id"),
            #             type="function",
            #         )
            #         tool_calls.append(tool_call)
            #     message: ChatMessage = ChatMessage(
            #         role="assistant",
            #         tool_calls=tool_calls,
            #     )
            # except json.JSONDecodeError:
            #     finish_reason: str = "stop"
            #     message: ChatMessage = ChatMessage(
            #         role="assistant",
            #         content=response,
            #     )

            message = ChatCompletionMessage(
                audio=None,
                content=response,
                refusal=None,
                role="assistant",
                tool_calls=None,
            )

            choices.append(
                ChatCompletionChoice(
                    index=i,
                    finish_reason="stop",
                    logprobs=None,
                    message=message,
                )
            )

        chat_completion = ChatCompletion(
            choices=choices,
            created=int(time.time()),
            id="1",
            model=model_name,
            object="chat.completion",
            service_tier=None,
            system_fingerprint=None,
            usage=None,
        )

        return chat_completion.model_dump(mode="json")


def build_app(cli_args: Dict[str, str]) -> serve.Application:
    """Builds the Serve app based on CLI arguments."""
    return ModelDeployment.options().bind(
        cli_args.get("model_name"),
    )


# uv run serve run serve:build_app model_name="HuggingFaceTB/SmolVLM-Instruct"
# uv run serve run serve:build_app model_name="unsloth/SmolLM2-135M-Instruct-bnb-4bit"
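#
# Example client usage (a minimal sketch; assumes the Serve app is reachable at
# the default HTTP address http://localhost:8000 and that "train.jsonl" is a
# placeholder path to a ShareGPT-style training file visible to the server):
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
#
#   # Chat completion (the model must match the deployed model_name)
#   response = client.chat.completions.create(
#       model="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
#       messages=[{"role": "user", "content": "Hello!"}],
#       max_tokens=64,
#   )
#   print(response.choices[0].message.content)
#
#   # Create a fine-tuning job and poll its status
#   job = client.fine_tuning.jobs.create(
#       model="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
#       training_file="train.jsonl",
#   )
#   print(client.fine_tuning.jobs.retrieve(job.id).status)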