mjschock committed
Commit 4395ceb · unverified · 1 Parent(s): d1da8fd

Enhance serve.py to handle additional message content types by converting dict content with a "text" field to a plain string and joining list content with newlines. Update train.py to load the model with FastModel instead of FastLanguageModel, gate model loading, dataset preparation, and trainer creation behind the train flag, and initialize evaluation with LiteLLMModel pointed at the local server. Modify conf/config.yaml to lower max_samples for testing from 10 to 3 and add a provider field for model configuration.

Files changed (3):
1. conf/config.yaml +3 -1
2. serve.py +6 -0
3. train.py +18 -15
conf/config.yaml CHANGED
@@ -4,7 +4,9 @@ defaults:
 # Model configuration
 model:
   name: "unsloth/SmolLM2-135M-Instruct-bnb-4bit"
+  # name: "HuggingFaceTB/SmolLM2-135M-Instruct"
   max_seq_length: 2048 # Auto supports RoPE Scaling internally
+  provider: "openai"
   dtype: null # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
   load_in_4bit: true # Use 4bit quantization to reduce memory usage
 
@@ -77,5 +79,5 @@ test_dataset:
   name: "gaia-benchmark/GAIA"
   config: "2023_level1" # Use level 1 questions for testing
   split: "test" # Use test split for testing
-  max_samples: 10 # Number of samples to test on
+  max_samples: 3 # Number of samples to test on
   max_length: 2048 # Maximum sequence length for testing
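
The new provider field pairs with the model name to form the LiteLLM-style model id that train.py builds below. A minimal sketch of that composition, assuming the YAML above lives at conf/config.yaml and is loaded directly with OmegaConf (Hydra's defaults list is not resolved here); build_model_id is an illustrative helper, not part of the repo:

# Minimal sketch: compose a LiteLLM-style model id from the new config fields.
from omegaconf import OmegaConf

def build_model_id(cfg) -> str:
    # An "openai/<name>" prefix tells LiteLLM to treat the target as an
    # OpenAI-compatible endpoint, matching what train.py now constructs.
    return f"{cfg.model.provider}/{cfg.model.name}"

cfg = OmegaConf.load("conf/config.yaml")
print(build_model_id(cfg))  # openai/unsloth/SmolLM2-135M-Instruct-bnb-4bit
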
serve.py CHANGED
@@ -153,6 +153,12 @@ class ModelDeployment:
 
                 content["type"] = "image"
                 del content["image_url"]
+            elif isinstance(content, dict) and "text" in content:
+                # Convert content to string if it's a dict with text
+                message["content"] = content["text"]
+            elif isinstance(content, list):
+                # Join list items with newlines if content is a list
+                message["content"] = "\n".join(content)
 
         images = images if images else None
 
train.py CHANGED
@@ -23,7 +23,7 @@ import hydra
 from omegaconf import DictConfig, OmegaConf
 
 # isort: off
-from unsloth import FastLanguageModel, is_bfloat16_supported  # noqa: E402
+from unsloth import FastLanguageModel, FastModel, is_bfloat16_supported  # noqa: E402
 from unsloth.chat_templates import get_chat_template  # noqa: E402
 
 # isort: on
@@ -39,7 +39,7 @@ from datasets import (
     load_dataset,
 )
 from peft import PeftModel
-from smolagents import CodeAgent, Model, TransformersModel, VLLMModel
+from smolagents import CodeAgent, LiteLLMModel, Model, TransformersModel, VLLMModel
 from smolagents.monitoring import LogLevel
 from transformers import (
     AutoModelForCausalLM,
@@ -97,7 +97,7 @@ def load_model(cfg: DictConfig) -> tuple[FastLanguageModel, AutoTokenizer]:
     """Load and configure the model."""
     logger.info("Loading model and tokenizer...")
     try:
-        model, tokenizer = FastLanguageModel.from_pretrained(
+        model, tokenizer = FastModel.from_pretrained(
             model_name=cfg.model.name,
             max_seq_length=cfg.model.max_seq_length,
             dtype=cfg.model.dtype,
@@ -106,7 +106,7 @@ def load_model(cfg: DictConfig) -> tuple[FastLanguageModel, AutoTokenizer]:
         logger.info("Base model loaded successfully")
 
         # Configure LoRA
-        model = FastLanguageModel.get_peft_model(
+        model = FastModel.get_peft_model(
             model,
             r=cfg.peft.r,
             target_modules=cfg.peft.target_modules,
@@ -242,19 +242,19 @@ def main(cfg: DictConfig) -> None:
     logger.info(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")
 
     # Install dependencies
-    install_dependencies()
+    # install_dependencies()
 
-    # Load model and tokenizer
-    model, tokenizer = load_model(cfg)
+    # Train if requested
+    if cfg.train:
+        # Load model and tokenizer
+        model, tokenizer = load_model(cfg)
 
-    # Load and prepare dataset
-    dataset, tokenizer = load_and_format_dataset(tokenizer, cfg)
+        # Load and prepare dataset
+        dataset, tokenizer = load_and_format_dataset(tokenizer, cfg)
 
-    # Create trainer
-    trainer: Trainer = create_trainer(model, tokenizer, dataset, cfg)
+        # Create trainer
+        trainer: Trainer = create_trainer(model, tokenizer, dataset, cfg)
 
-    # Train if requested
-    if cfg.train:
         logger.info("Starting training...")
         trainer.train()
 
@@ -304,8 +304,11 @@ def main(cfg: DictConfig) -> None:
     torch.cuda.empty_cache()
 
     # Initialize model
-    model: Model = Model(
-        model_id=cfg.model.name,
+    model: Model = LiteLLMModel(
+        api_base="http://localhost:8000/v1",
+        api_key="not-needed",
+        model_id=f"{cfg.model.provider}/{cfg.model.name}",
+        # model_id=cfg.model.name,
         # model_id=cfg.output.dir,
     )
314