Andy Lee commited on
Commit
fc23f51
·
1 Parent(s): f22dc3b

feat: support qwen and openrouters

Browse files
Files changed (4) hide show
  1. app.py +2 -18
  2. benchmark.py +4 -21
  3. config.py +78 -11
  4. main.py +2 -2
app.py CHANGED
@@ -8,11 +8,8 @@ from pathlib import Path
8
 
9
  from geo_bot import GeoBot, AGENT_PROMPT_TEMPLATE
10
  from benchmark import MapGuesserBenchmark
11
- from config import MODELS_CONFIG, get_data_paths, SUCCESS_THRESHOLD_KM
12
- from langchain_openai import ChatOpenAI
13
- from langchain_anthropic import ChatAnthropic
14
- from langchain_google_genai import ChatGoogleGenerativeAI
15
- from hf_chat import HuggingFaceChat
16
 
17
  # Simple API key setup
18
  if "OPENAI_API_KEY" in st.secrets:
@@ -38,19 +35,6 @@ def get_available_datasets():
38
  return datasets if datasets else ["default"]
39
 
40
 
41
- def get_model_class(class_name):
42
- if class_name == "ChatOpenAI":
43
- return ChatOpenAI
44
- elif class_name == "ChatAnthropic":
45
- return ChatAnthropic
46
- elif class_name == "ChatGoogleGenerativeAI":
47
- return ChatGoogleGenerativeAI
48
- elif class_name == "HuggingFaceChat":
49
- return HuggingFaceChat
50
- else:
51
- raise ValueError(f"Unknown model class: {class_name}")
52
-
53
-
54
  # UI Setup
55
  st.set_page_config(page_title="🧠 Omniscient - AI Geographic Analysis", layout="wide")
56
  st.title("🧠 Omniscient")
 
8
 
9
  from geo_bot import GeoBot, AGENT_PROMPT_TEMPLATE
10
  from benchmark import MapGuesserBenchmark
11
+ from config import MODELS_CONFIG, get_data_paths, SUCCESS_THRESHOLD_KM, get_model_class
12
+
 
 
 
13
 
14
  # Simple API key setup
15
  if "OPENAI_API_KEY" in st.secrets:
 
35
  return datasets if datasets else ["default"]
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  # UI Setup
39
  st.set_page_config(page_title="🧠 Omniscient - AI Geographic Analysis", layout="wide")
40
  st.title("🧠 Omniscient")
benchmark.py CHANGED
@@ -9,7 +9,7 @@ from pathlib import Path
9
  import math
10
 
11
  from geo_bot import GeoBot
12
- from config import get_data_paths, MODELS_CONFIG, SUCCESS_THRESHOLD_KM
13
 
14
 
15
  class MapGuesserBenchmark:
@@ -29,25 +29,6 @@ class MapGuesserBenchmark:
29
  except Exception:
30
  return []
31
 
32
- def get_model_class(self, model_name: str):
33
- config = MODELS_CONFIG.get(model_name)
34
- if not config:
35
- raise ValueError(f"Unknown model: {model_name}")
36
- class_name, model_class_name = config["class"], config["model_name"]
37
- if class_name == "ChatOpenAI":
38
- from langchain_openai import ChatOpenAI
39
-
40
- return ChatOpenAI, model_class_name
41
- if class_name == "ChatAnthropic":
42
- from langchain_anthropic import ChatAnthropic
43
-
44
- return ChatAnthropic, model_class_name
45
- if class_name == "ChatGoogleGenerativeAI":
46
- from langchain_google_genai import ChatGoogleGenerativeAI
47
-
48
- return ChatGoogleGenerativeAI, model_class_name
49
- raise ValueError(f"Unknown model class: {class_name}")
50
-
51
  def calculate_distance(
52
  self, true_coords: Dict, predicted_coords: Optional[Tuple[float, float]]
53
  ) -> Optional[float]:
@@ -99,7 +80,9 @@ class MapGuesserBenchmark:
99
  all_results = []
100
  for model_name in models_to_test:
101
  print(f"\n🤖 Testing model: {model_name}")
102
- model_class, model_class_name = self.get_model_class(model_name)
 
 
103
 
104
  try:
105
  with GeoBot(
 
9
  import math
10
 
11
  from geo_bot import GeoBot
12
+ from config import get_data_paths, MODELS_CONFIG, SUCCESS_THRESHOLD_KM, get_model_class
13
 
14
 
15
  class MapGuesserBenchmark:
 
29
  except Exception:
30
  return []
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def calculate_distance(
33
  self, true_coords: Dict, predicted_coords: Optional[Tuple[float, float]]
34
  ) -> Optional[float]:
 
80
  all_results = []
81
  for model_name in models_to_test:
82
  print(f"\n🤖 Testing model: {model_name}")
83
+ model_config = MODELS_CONFIG[model_name]
84
+ model_class = get_model_class(model_config["class"])
85
+ model_class_name = model_config["model_name"]
86
 
87
  try:
88
  with GeoBot(
config.py CHANGED
@@ -1,5 +1,10 @@
1
  # Configuration file for MapCrunch benchmark
2
 
 
 
 
 
 
3
  SUCCESS_THRESHOLD_KM = 100
4
 
5
  # MapCrunch settings
@@ -38,10 +43,15 @@ MODELS_CONFIG = {
38
  "model_name": "gpt-4o-mini",
39
  "description": "OpenAI GPT-4o Mini",
40
  },
41
- "claude-3.5-sonnet": {
 
 
 
 
 
42
  "class": "ChatAnthropic",
43
- "model_name": "claude-3-5-sonnet-20240620",
44
- "description": "Anthropic Claude 3.5 Sonnet",
45
  },
46
  "gemini-1.5-pro": {
47
  "class": "ChatGoogleGenerativeAI",
@@ -58,19 +68,76 @@ MODELS_CONFIG = {
58
  "model_name": "gemini-2.5-pro-preview-06-05",
59
  "description": "Google Gemini 2.5 Pro",
60
  },
61
- "qwen2-vl-7b": {
62
- "class": "HuggingFaceChat",
63
- "model_name": "Qwen/Qwen2-VL-7B-Instruct",
64
- "description": "Qwen2-VL 7B (older but API supported)",
65
  },
66
- "qwen2-vl-2b": {
67
- "class": "HuggingFaceChat",
68
- "model_name": "Qwen/Qwen2-VL-2B-Instruct",
69
- "description": "Qwen2-VL 2B (faster, API supported)",
 
 
 
 
 
 
 
 
 
 
70
  },
71
  }
72
 
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  # Data paths - now supports named datasets
75
  def get_data_paths(dataset_name: str = "default"):
76
  """Get data paths for a specific dataset"""
 
1
  # Configuration file for MapCrunch benchmark
2
 
3
+ from pydantic import SecretStr, Field
4
+ from typing import Optional
5
+ import os
6
+
7
+
8
  SUCCESS_THRESHOLD_KM = 100
9
 
10
  # MapCrunch settings
 
43
  "model_name": "gpt-4o-mini",
44
  "description": "OpenAI GPT-4o Mini",
45
  },
46
+ "claude-3-7-sonnet": {
47
+ "class": "ChatAnthropic",
48
+ "model_name": "claude-3-7-sonnet-20250219",
49
+ "description": "Anthropic Claude 3.7 Sonnet",
50
+ },
51
+ "claude-4-sonnet": {
52
  "class": "ChatAnthropic",
53
+ "model_name": "claude-4-sonnet-20250514",
54
+ "description": "Anthropic Claude 4 Sonnet",
55
  },
56
  "gemini-1.5-pro": {
57
  "class": "ChatGoogleGenerativeAI",
 
68
  "model_name": "gemini-2.5-pro-preview-06-05",
69
  "description": "Google Gemini 2.5 Pro",
70
  },
71
+ "qwen-vl-max": {
72
+ "class": "OpenRouter",
73
+ "model_name": "qwen/qwen-vl-max",
74
+ "description": "Qwen VL Max - OpenRouter (Best Performance)",
75
  },
76
+ "qwen2.5-vl-32b-free": {
77
+ "class": "OpenRouter",
78
+ "model_name": "qwen/qwen2.5-vl-32b-instruct:free",
79
+ "description": "Qwen2.5 VL 32B - OpenRouter (FREE!)",
80
+ },
81
+ "qwen2.5-vl-7b": {
82
+ "class": "OpenRouter",
83
+ "model_name": "qwen/qwen2.5-vl-7b-instruct",
84
+ "description": "Qwen2.5 VL 7B - OpenRouter",
85
+ },
86
+ "qwen2.5-vl-3b": {
87
+ "class": "OpenRouter",
88
+ "model_name": "qwen/qwen2.5-vl-3b-instruct",
89
+ "description": "Qwen2.5 VL 3B - OpenRouter (Fastest)",
90
  },
91
  }
92
 
93
 
94
+ def get_model_class(class_name):
95
+ """Get actual model class from string name"""
96
+ if class_name == "ChatOpenAI":
97
+ from langchain_openai import ChatOpenAI
98
+
99
+ return ChatOpenAI
100
+ elif class_name == "ChatAnthropic":
101
+ from langchain_anthropic import ChatAnthropic
102
+
103
+ return ChatAnthropic
104
+ elif class_name == "ChatGoogleGenerativeAI":
105
+ from langchain_google_genai import ChatGoogleGenerativeAI
106
+
107
+ return ChatGoogleGenerativeAI
108
+ elif class_name == "HuggingFaceChat":
109
+ from hf_chat import HuggingFaceChat
110
+
111
+ return HuggingFaceChat
112
+ elif class_name == "OpenRouter":
113
+ from langchain_openai import ChatOpenAI
114
+ from langchain_core.utils.utils import secret_from_env
115
+
116
+ # LangChain does not support OpenRouter directly, so we need to create a custom class
117
+ # See https://github.com/langchain-ai/langchain/discussions/27964.
118
+ class ChatOpenRouter(ChatOpenAI):
119
+ openai_api_key: Optional[SecretStr] = Field(
120
+ alias="api_key",
121
+ default_factory=secret_from_env("OPENROUTER_API_KEY", default=None),
122
+ )
123
+
124
+ @property
125
+ def lc_secrets(self) -> dict[str, str]:
126
+ return {"openai_api_key": "OPENROUTER_API_KEY"}
127
+
128
+ def __init__(self, openai_api_key: Optional[str] = None, **kwargs):
129
+ openai_api_key = openai_api_key or os.environ.get("OPENROUTER_API_KEY")
130
+ super().__init__(
131
+ base_url="https://openrouter.ai/api/v1",
132
+ api_key=SecretStr(openai_api_key) if openai_api_key else None,
133
+ **kwargs,
134
+ )
135
+
136
+ return ChatOpenRouter
137
+ else:
138
+ raise ValueError(f"Unknown model class: {class_name}")
139
+
140
+
141
  # Data paths - now supports named datasets
142
  def get_data_paths(dataset_name: str = "default"):
143
  """Get data paths for a specific dataset"""
main.py CHANGED
@@ -10,7 +10,7 @@ from langchain_google_genai import ChatGoogleGenerativeAI
10
  from geo_bot import GeoBot
11
  from benchmark import MapGuesserBenchmark
12
  from data_collector import DataCollector
13
- from config import MODELS_CONFIG, get_data_paths, SUCCESS_THRESHOLD_KM
14
 
15
 
16
  def agent_mode(
@@ -48,7 +48,7 @@ def agent_mode(
48
  print(f"Will run on {len(test_samples)} samples from dataset '{dataset_name}'.")
49
 
50
  config = MODELS_CONFIG.get(model_name)
51
- model_class = globals()[config["class"]]
52
  model_instance_name = config["model_name"]
53
 
54
  benchmark_helper = MapGuesserBenchmark(dataset_name=dataset_name, headless=True)
 
10
  from geo_bot import GeoBot
11
  from benchmark import MapGuesserBenchmark
12
  from data_collector import DataCollector
13
+ from config import MODELS_CONFIG, get_data_paths, SUCCESS_THRESHOLD_KM, get_model_class
14
 
15
 
16
  def agent_mode(
 
48
  print(f"Will run on {len(test_samples)} samples from dataset '{dataset_name}'.")
49
 
50
  config = MODELS_CONFIG.get(model_name)
51
+ model_class = get_model_class(config["class"])
52
  model_instance_name = config["model_name"]
53
 
54
  benchmark_helper = MapGuesserBenchmark(dataset_name=dataset_name, headless=True)