Final_Assignment_Template / serve_test.py
mjschock's picture
Enhance serve.py with fine-tuning job management, including job creation, status tracking, and training process in a separate thread. Update serve_test.py to include a test for fine-tuning functionality. Modify .gitignore to exclude model files. This update improves model training capabilities and API integration.
145385b unverified
raw
history blame contribute delete
2.91 kB
import json
import os
import time
from openai import OpenAI
# Shared OpenAI client for every test in this script. It talks to the server
# listening on localhost:8000; the API key is a placeholder because the local
# server does not need one.
client = OpenAI(api_key="not-needed", base_url="http://localhost:8000/v1")
def test_chat_completion():
    """Send one "Hello" chat completion through ``client`` and print the result.

    Prints the text of the first choice, then the whole response serialized
    as indented JSON. Any exception is caught and reported on stdout together
    with its traceback; nothing is raised to the caller.
    """
    import traceback

    request = {
        "model": "unsloth/SmolLM2-135M-Instruct-bnb-4bit",
        "messages": [{"role": "user", "content": "Hello"}],
        "temperature": 0.7,
        "max_tokens": 50,
    }

    try:
        print("Sending chat completion request...")
        completion = client.chat.completions.create(**request)

        # Text of the first (and only requested) choice.
        print("\nResponse:")
        first_choice = completion.choices[0]
        print(first_choice.message.content)

        # Dump the complete response object for debugging.
        print("\nFull response object:")
        as_json = json.dumps(completion.model_dump(), indent=2)
        print(as_json)
    except Exception as exc:
        print(f"Error occurred: {exc!s}")
        print("\nFull traceback:")
        print(traceback.format_exc())
def test_fine_tuning():
    """Exercise the fine-tuning endpoints through ``client`` end to end.

    Writes a one-conversation training file into the working directory,
    creates a fine-tuning job for it, lists the jobs, and then polls the job
    every five seconds until its status is no longer "created" or "running".

    Any exception is caught and reported on stdout together with its
    traceback; nothing is raised to the caller. The training file is removed
    on the way out whether the run succeeded or failed.
    """
    import traceback

    # Defined before the try so the finally clause can always refer to it.
    training_file = "training_data.json"

    try:
        # Create a sample training file
        training_data = {
            "conversations": [
                {
                    "from": "human",
                    "value": "What is the capital of France?",
                },
                {
                    "from": "gpt",
                    "value": "The capital of France is Paris.",
                },
            ]
        }
        with open(training_file, "w") as f:
            json.dump(training_data, f)

        print("\nCreating fine-tuning job...")
        job = client.fine_tuning.jobs.create(
            training_file=training_file,
            model="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
        )
        print(f"Created job: {job.id}")

        # Wait for job to start
        print("\nWaiting for job to start...")
        time.sleep(2)

        # List jobs
        print("\nListing fine-tuning jobs...")
        jobs = client.fine_tuning.jobs.list()
        print(f"Found {len(jobs.data)} jobs")

        # Get job status
        print("\nGetting job status...")
        job = client.fine_tuning.jobs.retrieve(job.id)
        print(f"Job status: {job.status}")

        # Wait for job to complete or fail.
        # NOTE(review): there is no timeout here; the loop runs until the
        # server reports a status other than "created" or "running".
        print("\nWaiting for job to complete...")
        while job.status in ["created", "running"]:
            time.sleep(5)
            job = client.fine_tuning.jobs.retrieve(job.id)
            print(f"Job status: {job.status}")
    except Exception as e:
        print(f"Error occurred: {str(e)}")
        print("\nFull traceback:")
        print(traceback.format_exc())
    finally:
        # Clean up in a finally clause so a failed request or a failed job
        # does not leave training_data.json behind in the working directory.
        try:
            os.remove(training_file)
        except FileNotFoundError:
            # The file was never written (open() itself failed); nothing to do.
            pass
        except OSError as e:
            # Keep the "report, don't raise" contract for cleanup errors too.
            print(f"Error occurred: {str(e)}")
if __name__ == "__main__":
    # Run each check in order, announcing it first. The checks print their
    # own errors, so a failure in one does not stop the next.
    for banner, run_check in (
        ("Testing chat completions endpoint...", test_chat_completion),
        ("\nTesting fine-tuning endpoints...", test_fine_tuning),
    ):
        print(banner)
        run_check()