Andy Lee committed
Commit 4d37e51 · 1 Parent(s): 23ee129

feat: more models, including qwen

Files changed (4)
  1. app.py +104 -25
  2. config.py +26 -0
  3. geo_bot.py +32 -12
  4. hf_chat.py +142 -0
app.py CHANGED
@@ -136,63 +136,138 @@ if start_button:
     # --- Inner agent exploration loop ---
     history = []
     final_guess = None
+    step_history_container = st.container()
 
     for step in range(steps_per_sample):
         step_num = step + 1
         reasoning_placeholder.info(
-            f"Thinking... (Step {step_num}/{steps_per_sample})"
+            f"🤔 Thinking... (Step {step_num}/{steps_per_sample})"
         )
         action_placeholder.empty()
 
         # Observe and label arrows
         bot.controller.label_arrows_on_screen()
         screenshot_bytes = bot.controller.take_street_view_screenshot()
+
+        # Current view
         image_placeholder.image(
-            screenshot_bytes, caption=f"Step {step_num} View", use_column_width=True
+            screenshot_bytes,
+            caption=f"🔍 Step {step_num} - What AI Sees Now",
+            use_column_width=True,
         )
 
         # Update history
-        history.append(
-            {
-                "image_b64": bot.pil_to_base64(
-                    Image.open(BytesIO(screenshot_bytes))
-                ),
-                "action": "N/A",
-            }
-        )
+        current_step_data = {
+            "image_b64": bot.pil_to_base64(Image.open(BytesIO(screenshot_bytes))),
+            "action": "N/A",
+            "screenshot_bytes": screenshot_bytes,
+            "step_num": step_num,
+        }
+        history.append(current_step_data)
 
         # Think
+        available_actions = bot.controller.get_available_actions()
+        history_text = "\n".join(
+            [f"Step {j + 1}: {h['action']}" for j, h in enumerate(history[:-1])]
+        )
+        if not history_text:
+            history_text = "No history yet. This is the first step."
+
         prompt = AGENT_PROMPT_TEMPLATE.format(
             remaining_steps=steps_per_sample - step,
-            history_text="\n".join(
-                [f"Step {j + 1}: {h['action']}" for j, h in enumerate(history)]
-            ),
-            available_actions=json.dumps(bot.controller.get_available_actions()),
+            history_text=history_text,
+            available_actions=json.dumps(available_actions),
         )
+
+        # Show what AI is considering
+        with reasoning_placeholder:
+            st.info("🧠 **AI is analyzing the situation...**")
+            with st.expander("🔍 Available Actions", expanded=False):
+                st.json(available_actions)
+            with st.expander("📝 Context Being Considered", expanded=False):
+                st.text_area(
+                    "History Context:", history_text, height=100, disabled=True
+                )
+
         message = bot._create_message_with_history(
             prompt, [h["image_b64"] for h in history]
         )
+
+        # Get AI response
         response = bot.model.invoke(message)
         decision = bot._parse_agent_response(response)
 
         if not decision:  # Fallback
             decision = {
                 "action_details": {"action": "PAN_RIGHT"},
-                "reasoning": "Default recovery.",
+                "reasoning": "⚠️ Response parsing failed. Using default recovery action.",
             }
 
         action = decision.get("action_details", {}).get("action")
         history[-1]["action"] = action
-
-        reasoning_placeholder.info(
-            f"**AI Reasoning:**\n\n{decision.get('reasoning', 'N/A')}"
+        history[-1]["reasoning"] = decision.get("reasoning", "N/A")
+        history[-1]["raw_response"] = (
+            response.content[:500] + "..."
+            if len(response.content) > 500
+            else response.content
         )
-        action_placeholder.success(f"**AI Action:** `{action}`")
+
+        # Display AI's decision process
+        reasoning_placeholder.success("✅ **AI Decision Made!**")
+
+        with action_placeholder:
+            st.success(f"🎯 **AI Action:** `{action}`")
+
+        # Detailed reasoning display
+        with st.expander("🧠 AI's Detailed Thinking Process", expanded=True):
+            col_reason, col_raw = st.columns([2, 1])
+
+            with col_reason:
+                st.markdown("**🤔 AI's Reasoning:**")
+                st.info(decision.get("reasoning", "N/A"))
+
+                if action == "GUESS":
+                    lat = decision.get("action_details", {}).get("lat")
+                    lon = decision.get("action_details", {}).get("lon")
+                    if lat and lon:
+                        st.success(f"📍 **Final Guess:** {lat:.4f}, {lon:.4f}")
+
+            with col_raw:
+                st.markdown("**🔤 Raw AI Response:**")
+                st.text_area(
+                    "Full Response:",
+                    history[-1]["raw_response"],
+                    height=200,
+                    disabled=True,
+                    key=f"raw_response_{step_num}",
+                )
+
+        # Store step in history display
+        with step_history_container:
+            with st.expander(f"📚 Step {step_num} History", expanded=False):
+                hist_col1, hist_col2 = st.columns([1, 2])
+                with hist_col1:
+                    st.image(
+                        screenshot_bytes, caption=f"Step {step_num} View", width=200
+                    )
+                with hist_col2:
+                    st.write(f"**Action:** {action}")
+                    st.write(
+                        f"**Reasoning:** {decision.get('reasoning', 'N/A')[:150]}..."
+                    )
 
         # Force a GUESS on the last step
         if step_num == steps_per_sample and action != "GUESS":
-            st.warning("Max steps reached. Forcing a GUESS action.")
+            st.warning("Max steps reached. Forcing a GUESS action.")
             action = "GUESS"
+            # Force coordinates if missing
+            if not decision.get("action_details", {}).get("lat"):
+                st.error("❌ AI didn't provide coordinates. Using fallback guess.")
+                decision["action_details"] = {
+                    "action": "GUESS",
+                    "lat": 0.0,
+                    "lon": 0.0,
+                }
 
         # Act
         if action == "GUESS":
@@ -204,18 +279,22 @@ if start_button:
                 final_guess = (lat, lon)
             else:
                 st.error(
-                    "GUESS action was missing coordinates. Guess failed for this sample."
+                    "GUESS action was missing coordinates. Guess failed for this sample."
                 )
             break  # End exploration for the current sample
 
         elif action == "MOVE_FORWARD":
-            bot.controller.move("forward")
+            with st.spinner("🚶 Moving forward..."):
+                bot.controller.move("forward")
        elif action == "MOVE_BACKWARD":
-            bot.controller.move("backward")
+            with st.spinner("🔄 Moving backward..."):
+                bot.controller.move("backward")
         elif action == "PAN_LEFT":
-            bot.controller.pan_view("left")
+            with st.spinner("⬅️ Panning left..."):
+                bot.controller.pan_view("left")
         elif action == "PAN_RIGHT":
-            bot.controller.pan_view("right")
+            with st.spinner("➡️ Panning right..."):
+                bot.controller.pan_view("right")
 
         time.sleep(1)  # A brief pause between steps for better visualization
 
 
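Note: the exploration loop above branches on a parsed decision dict: `action_details.action` selects the branch, and `lat`/`lon` are read only for `GUESS`. A minimal sketch of that payload, with field names taken from the diff and purely illustrative values:

```python
# Illustrative only: the shape _parse_agent_response() is expected to return.
# Field names come from the diff above; the concrete values are invented.
example_decision = {
    "reasoning": "Portuguese signage and right-hand traffic suggest coastal Brazil.",
    "action_details": {
        "action": "GUESS",  # or MOVE_FORWARD / MOVE_BACKWARD / PAN_LEFT / PAN_RIGHT
        "lat": -22.9068,    # present only for GUESS
        "lon": -43.1729,
    },
}
```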
config.py CHANGED
@@ -31,18 +31,44 @@ MODELS_CONFIG = {
     "gpt-4o": {
         "class": "ChatOpenAI",
         "model_name": "gpt-4o",
+        "api_key_env": "OPENAI_API_KEY",
+        "description": "OpenAI GPT-4o",
+    },
+    "gpt-4o-mini": {
+        "class": "ChatOpenAI",
+        "model_name": "gpt-4o-mini",
+        "api_key_env": "OPENAI_API_KEY",
+        "description": "OpenAI GPT-4o Mini (cheaper)",
     },
     "claude-3.5-sonnet": {
         "class": "ChatAnthropic",
         "model_name": "claude-3-5-sonnet-20240620",
+        "api_key_env": "ANTHROPIC_API_KEY",
+        "description": "Anthropic Claude 3.5 Sonnet",
     },
     "gemini-1.5-pro": {
         "class": "ChatGoogleGenerativeAI",
         "model_name": "gemini-1.5-pro-latest",
+        "api_key_env": "GOOGLE_API_KEY",
+        "description": "Google Gemini 1.5 Pro",
     },
     "gemini-2.5-pro": {
         "class": "ChatGoogleGenerativeAI",
         "model_name": "gemini-2.5-pro-preview-06-05",
+        "api_key_env": "GOOGLE_API_KEY",
+        "description": "Google Gemini 2.5 Pro",
+    },
+    "qwen2-vl-72b": {
+        "class": "HuggingFaceChat",
+        "model_name": "Qwen/Qwen2-VL-72B-Instruct",
+        "api_key_env": "HUGGINGFACE_API_KEY",
+        "description": "Qwen2-VL 72B (via HF Inference API)",
+    },
+    "qwen2-vl-7b": {
+        "class": "HuggingFaceChat",
+        "model_name": "Qwen/Qwen2-VL-7B-Instruct",
+        "api_key_env": "HUGGINGFACE_API_KEY",
+        "description": "Qwen2-VL 7B (via HF Inference API)",
+    },
 }
 
 
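Note: the new `class`, `api_key_env`, and `description` keys imply the app resolves a config entry to a model class at runtime. A hypothetical sketch of such a factory follows; the `build_model` helper and `CLASS_MAP` are not part of this commit, while the key names and class names match the config above:

```python
# Hypothetical sketch (not part of the commit): resolving a MODELS_CONFIG entry.
import os

from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI

from hf_chat import HuggingFaceChat
from config import MODELS_CONFIG

CLASS_MAP = {
    "ChatOpenAI": ChatOpenAI,
    "ChatAnthropic": ChatAnthropic,
    "ChatGoogleGenerativeAI": ChatGoogleGenerativeAI,
    "HuggingFaceChat": HuggingFaceChat,
}


def build_model(config_key: str, temperature: float = 0.0):
    cfg = MODELS_CONFIG[config_key]
    if not os.getenv(cfg["api_key_env"]):
        # Fail early, using the human-readable description from the config.
        raise ValueError(f"{cfg['api_key_env']} is not set ({cfg['description']})")
    return CLASS_MAP[cfg["class"]](model=cfg["model_name"], temperature=temperature)


# e.g. build_model("qwen2-vl-7b") -> HuggingFaceChat(model="Qwen/Qwen2-VL-7B-Instruct")
```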
geo_bot.py CHANGED
@@ -11,10 +11,11 @@ from langchain_openai import ChatOpenAI
 from langchain_anthropic import ChatAnthropic
 from langchain_google_genai import ChatGoogleGenerativeAI
 
+from hf_chat import HuggingFaceChat
+
 from mapcrunch_controller import MapCrunchController
 
 # The "Golden" Prompt (v6): Combines clear mechanics with robust strategic principles.
-
 AGENT_PROMPT_TEMPLATE = """
 **Mission:** You are an expert geo-location agent. Your goal is to find clues to determine your location within a limited number of steps.
 
@@ -68,11 +69,20 @@ class GeoBot:
     ):
         # Initialize model with temperature parameter
         model_kwargs = {
-            "model": model_name,
             "temperature": temperature,
         }
 
-        self.model = model(**model_kwargs)
+        # Handle different model types
+        if model == HuggingFaceChat and HuggingFaceChat is not None:
+            model_kwargs["model"] = model_name
+        else:
+            model_kwargs["model"] = model_name
+
+        try:
+            self.model = model(**model_kwargs)
+        except Exception as e:
+            raise ValueError(f"Failed to initialize model {model_name}: {e}")
+
         self.model_name = model_name
         self.temperature = temperature
         self.use_selenium = use_selenium
@@ -90,6 +100,7 @@
     ) -> List[HumanMessage]:
         """Creates a message for the LLM that includes text and a sequence of images."""
         content = [{"type": "text", "text": prompt}]
+
         # Add the JSON format instructions right after the main prompt text
         content.append(
             {
@@ -145,7 +156,6 @@
             print(f"\n--- Step {max_steps - step + 1}/{max_steps} ---")
 
             self.controller.setup_clean_environment()
-
             self.controller.label_arrows_on_screen()
 
             screenshot_bytes = self.controller.take_street_view_screenshot()
@@ -178,17 +188,22 @@
                 available_actions=json.dumps(available_actions),
             )
 
-            message = self._create_message_with_history(prompt, image_b64_for_prompt)
-            response = self.model.invoke(message)
-
-            decision = self._parse_agent_response(response)
+            try:
+                message = self._create_message_with_history(
+                    prompt, image_b64_for_prompt
+                )
+                response = self.model.invoke(message)
+                decision = self._parse_agent_response(response)
+            except Exception as e:
+                print(f"Error during model invocation: {e}")
+                decision = None
 
             if not decision:
                 print(
-                    "Response parsing failed. Using default recovery action: PAN_RIGHT."
+                    "Response parsing failed or model error. Using default recovery action: PAN_RIGHT."
                 )
                 decision = {
-                    "reasoning": "Recovery due to parsing failure.",
+                    "reasoning": "Recovery due to parsing failure or model error.",
                     "action_details": {"action": "PAN_RIGHT"},
                 }
 
@@ -219,8 +234,13 @@
     def analyze_image(self, image: Image.Image) -> Optional[Tuple[float, float]]:
         image_b64 = self.pil_to_base64(image)
         message = self._create_llm_message(BENCHMARK_PROMPT, image_b64)
-        response = self.model.invoke(message)
-        print(f"\nLLM Response:\n{response.content}")
+
+        try:
+            response = self.model.invoke(message)
+            print(f"\nLLM Response:\n{response.content}")
+        except Exception as e:
+            print(f"Error during image analysis: {e}")
+            return None
 
         content = response.content.strip()
         last_line = ""
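Note: `_create_message_with_history` (referenced above) builds the multimodal `HumanMessage` that `HuggingFaceChat._format_message_for_hf` later unpacks, so the two must agree on the content-part shape. A minimal sketch of that shape, assuming PNG screenshots encoded as base64 data URLs; the helper name is illustrative:

```python
# Illustrative sketch: the message shape hf_chat.py expects to receive.
from typing import List

from langchain_core.messages import HumanMessage


def make_multimodal_message(prompt: str, images_b64: List[str]) -> HumanMessage:
    content = [{"type": "text", "text": prompt}]
    for b64 in images_b64:
        # hf_chat.py splits on "," and checks for the "data:image" prefix.
        content.append(
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}}
        )
    return HumanMessage(content=content)
```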
hf_chat.py ADDED
@@ -0,0 +1,142 @@
+"""
+HuggingFace Chat Model Wrapper for vision models like Qwen2-VL
+"""
+
+import os
+import base64
+import requests
+from typing import List, Dict, Any, Optional
+from langchain_core.messages import BaseMessage, HumanMessage
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.outputs import ChatResult, ChatGeneration
+from pydantic import Field
+
+
+class HuggingFaceChat(BaseChatModel):
+    """Chat model wrapper for HuggingFace Inference API"""
+
+    model: str = Field(description="HuggingFace model name")
+    temperature: float = Field(default=0.0, description="Temperature for sampling")
+    max_tokens: int = Field(default=1000, description="Max tokens to generate")
+    api_token: Optional[str] = Field(default=None, description="HF API token")
+
+    def __init__(self, model: str, temperature: float = 0.0, **kwargs):
+        api_token = kwargs.get("api_token") or os.getenv("HUGGINGFACE_API_KEY")
+        if not api_token:
+            raise ValueError("HUGGINGFACE_API_KEY environment variable is required")
+
+        super().__init__(
+            model=model, temperature=temperature, api_token=api_token, **kwargs
+        )
+
+    @property
+    def _llm_type(self) -> str:
+        return "huggingface_chat"
+
+    def _format_message_for_hf(self, message: HumanMessage) -> Dict[str, Any]:
+        """Convert LangChain message to HuggingFace format"""
+        if isinstance(message.content, str):
+            return {"role": "user", "content": message.content}
+
+        # Handle multi-modal content (text + images)
+        formatted_content = []
+        for item in message.content:
+            if item["type"] == "text":
+                formatted_content.append({"type": "text", "text": item["text"]})
+            elif item["type"] == "image_url":
+                # Extract base64 data from data URL
+                image_url = item["image_url"]["url"]
+                if image_url.startswith("data:image"):
+                    # Extract base64 data
+                    base64_data = image_url.split(",")[1]
+                    formatted_content.append({"type": "image", "image": base64_data})
+
+        return {"role": "user", "content": formatted_content}
+
+    def _generate(self, messages: List[BaseMessage], **kwargs) -> ChatResult:
+        """Generate response using HuggingFace Inference API"""
+
+        # Format messages for HF API
+        formatted_messages = []
+        for msg in messages:
+            if isinstance(msg, HumanMessage):
+                formatted_messages.append(self._format_message_for_hf(msg))
+
+        # Prepare API request
+        api_url = f"https://api-inference.huggingface.co/models/{self.model}/v1/chat/completions"
+        headers = {
+            "Authorization": f"Bearer {self.api_token}",
+            "Content-Type": "application/json",
+        }
+
+        payload = {
+            "model": self.model,
+            "messages": formatted_messages,
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "stream": False,
+        }
+
+        try:
+            response = requests.post(api_url, headers=headers, json=payload, timeout=60)
+            response.raise_for_status()
+
+            result = response.json()
+            content = result["choices"][0]["message"]["content"]
+
+            return ChatResult(
+                generations=[ChatGeneration(message=HumanMessage(content=content))]
+            )
+
+        except requests.exceptions.RequestException as e:
+            # Fallback to simple text-only API if chat completions fail
+            return self._fallback_generate(messages, **kwargs)
+
+    def _fallback_generate(self, messages: List[BaseMessage], **kwargs) -> ChatResult:
+        """Fallback to simple HF Inference API"""
+        try:
+            # Use simple inference API as fallback
+            api_url = f"https://api-inference.huggingface.co/models/{self.model}"
+            headers = {
+                "Authorization": f"Bearer {self.api_token}",
+                "Content-Type": "application/json",
+            }
+
+            # Extract text content only for fallback
+            text_content = ""
+            for msg in messages:
+                if isinstance(msg, HumanMessage):
+                    if isinstance(msg.content, str):
+                        text_content += msg.content
+                    else:
+                        for item in msg.content:
+                            if item["type"] == "text":
+                                text_content += item["text"] + "\n"
+
+            payload = {
+                "inputs": text_content,
+                "parameters": {
+                    "temperature": self.temperature,
+                    "max_new_tokens": self.max_tokens,
+                },
+            }
+
+            response = requests.post(api_url, headers=headers, json=payload, timeout=60)
+            response.raise_for_status()
+
+            result = response.json()
+            if isinstance(result, list) and len(result) > 0:
+                content = result[0].get("generated_text", "No response generated")
+            else:
+                content = "Error: Invalid response format"
+
+            return ChatResult(
+                generations=[ChatGeneration(message=HumanMessage(content=content))]
+            )
+
+        except Exception as e:
+            # Last resort fallback
+            error_msg = f"HuggingFace API Error: {str(e)}. Please check your API key and model availability."
+            return ChatResult(
+                generations=[ChatGeneration(message=HumanMessage(content=error_msg))]
+            )
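Note: a usage sketch for the new wrapper follows. It requires `HUGGINGFACE_API_KEY` in the environment; the model name comes from config.py above, and the screenshot filename is a placeholder:

```python
# Usage sketch for HuggingFaceChat; invoke() is inherited from BaseChatModel.
import base64

from langchain_core.messages import HumanMessage
from hf_chat import HuggingFaceChat

chat = HuggingFaceChat(model="Qwen/Qwen2-VL-7B-Instruct", temperature=0.0)

with open("street_view.png", "rb") as f:  # placeholder screenshot file
    img_b64 = base64.b64encode(f.read()).decode()

message = HumanMessage(
    content=[
        {"type": "text", "text": "Which country is this street likely in?"},
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{img_b64}"},
        },
    ]
)
print(chat.invoke([message]).content)
```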