mjschock committed on
Commit d1da8fd · unverified · 1 Parent(s): 95d9fdc

Add serve_test.py for testing chat completion functionality with the OpenAI client. Update serve.py to use FastModel for improved performance and adjust input handling for optional image processing. Include debugging output for better error tracking.

Files changed (3)
  1. serve.py +16 -6
  2. serve_test.py +40 -0
  3. train.py +2 -0
serve.py CHANGED
@@ -9,6 +9,7 @@ from typing import Any, Dict, List
 # isort: off
 from unsloth import (
     FastLanguageModel,
+    FastModel,
     FastVisionModel,
     is_bfloat16_supported,
 )  # noqa: E402
@@ -92,17 +93,17 @@ class ModelDeployment:
     ):
         self.model_name = model_name
 
-        model, processor = FastVisionModel.from_pretrained(
+        model, processor = FastModel.from_pretrained(
             load_in_4bit=load_in_4bit,
             max_seq_length=max_seq_length,
             model_name=self.model_name,
         )
 
-        with open("chat_template.txt", "r") as f:
-            processor.chat_template = f.read()
-        processor.tokenizer.chat_template = processor.chat_template
+        # with open("chat_template.txt", "r") as f:
+        #     processor.chat_template = f.read()
+        # processor.tokenizer.chat_template = processor.chat_template
 
-        FastVisionModel.for_inference(model)  # Enable native 2x faster inference
+        FastModel.for_inference(model)  # Enable native 2x faster inference
 
         self.model = model
         self.processor = processor
@@ -166,12 +167,17 @@ class ModelDeployment:
             conversation=messages,
             # documents=documents,
             tools=tools,
+            tokenize=False,  # Return string instead of token IDs
         )
 
         print("prompt:")
         print(prompt)
 
-        inputs = self.processor(text=prompt, images=images, return_tensors="pt")
+        if images:
+            inputs = self.processor(text=prompt, images=images, return_tensors="pt")
+        else:
+            inputs = self.processor(text=prompt, return_tensors="pt")
+
         inputs = inputs.to(self.model.device)
         input_ids = inputs.input_ids
 
@@ -372,3 +378,7 @@ def build_app(cli_args: Dict[str, str]) -> serve.Application:
     return ModelDeployment.options().bind(
         cli_args.get("model_name"),
     )
+
+
+# uv run serve run serve:build_app model_name="HuggingFaceTB/SmolVLM-Instruct"
+# uv run serve run serve:build_app model_name="unsloth/SmolLM2-135M-Instruct-bnb-4bit"
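With FastModel in place of FastVisionModel, the same deployment can load both text-only and vision checkpoints, and the new `if images:` branch keeps the processor call valid for either. A minimal standalone sketch of that pattern outside Ray Serve; the `max_seq_length` value, the `add_generation_prompt` flag, and the `generate`/`decode` calls are illustrative assumptions, not taken from this diff:

# Sketch of the optional-image path introduced above; assumes a text-only
# checkpoint, so `processor` is effectively a tokenizer here.
from unsloth import FastModel

model, processor = FastModel.from_pretrained(
    model_name="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
    max_seq_length=2048,  # illustrative value; serve.py's setting is not shown
    load_in_4bit=True,
)
FastModel.for_inference(model)  # same inference toggle as in serve.py

messages = [{"role": "user", "content": "Hello"}]
prompt = processor.apply_chat_template(
    messages,
    tokenize=False,  # return a string, as in the diff
    add_generation_prompt=True,  # assumption: the full serve.py call isn't shown
)

images = []  # populated only for vision models such as SmolVLM-Instruct
if images:
    inputs = processor(text=prompt, images=images, return_tensors="pt")
else:
    inputs = processor(text=prompt, return_tensors="pt")

outputs = model.generate(**inputs.to(model.device), max_new_tokens=50)
print(processor.decode(outputs[0], skip_special_tokens=True))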
serve_test.py ADDED
@@ -0,0 +1,40 @@
+import json
+
+from openai import OpenAI
+
+# Initialize the OpenAI client with the local server
+client = OpenAI(
+    base_url="http://localhost:8000/v1",
+    api_key="not-needed",  # API key is not needed for local server
+)
+
+
+def test_chat_completion():
+    try:
+        print("Sending chat completion request...")
+        response = client.chat.completions.create(
+            model="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
+            messages=[{"role": "user", "content": "Hello"}],
+            temperature=0.7,
+            max_tokens=50,
+        )
+
+        # Print the response
+        print("\nResponse:")
+        print(response.choices[0].message.content)
+
+        # Print full response object for debugging
+        print("\nFull response object:")
+        print(json.dumps(response.model_dump(), indent=2))
+
+    except Exception as e:
+        print(f"Error occurred: {str(e)}")
+        import traceback
+
+        print("\nFull traceback:")
+        print(traceback.format_exc())
+
+
+if __name__ == "__main__":
+    print("Testing chat completions endpoint...")
+    test_chat_completion()
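With the server running (e.g. `uv run serve run serve:build_app model_name="unsloth/SmolLM2-135M-Instruct-bnb-4bit"`), this script exercises a single non-streaming completion. For comparison, a hypothetical streaming variant against the same endpoint is sketched below; it only works if the serve.py deployment implements the OpenAI streaming protocol, which this diff does not show:

# Hypothetical streaming check against the same local endpoint; valid only
# if the deployment supports stream=True in the OpenAI chat protocol.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

stream = client.chat.completions.create(
    model="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=50,
    stream=True,  # server must emit OpenAI-style SSE chunks for this to work
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
print()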
train.py CHANGED
@@ -412,3 +412,5 @@ Please format your response as a JSON object with two keys:
 
 if __name__ == "__main__":
     main()
+
+# uv run python train.py