Spaces:
Building
Building
Andy Lee
commited on
Commit
·
4d37e51
1
Parent(s):
23ee129
feat: more models, including qwen
Browse files- app.py +104 -25
- config.py +26 -0
- geo_bot.py +32 -12
- hf_chat.py +142 -0
app.py
CHANGED
@@ -136,63 +136,138 @@ if start_button:
|
|
136 |
# --- Inner agent exploration loop ---
|
137 |
history = []
|
138 |
final_guess = None
|
|
|
139 |
|
140 |
for step in range(steps_per_sample):
|
141 |
step_num = step + 1
|
142 |
reasoning_placeholder.info(
|
143 |
-
f"Thinking... (Step {step_num}/{steps_per_sample})"
|
144 |
)
|
145 |
action_placeholder.empty()
|
146 |
|
147 |
# Observe and label arrows
|
148 |
bot.controller.label_arrows_on_screen()
|
149 |
screenshot_bytes = bot.controller.take_street_view_screenshot()
|
|
|
|
|
150 |
image_placeholder.image(
|
151 |
-
screenshot_bytes,
|
|
|
|
|
152 |
)
|
153 |
|
154 |
# Update history
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
)
|
163 |
|
164 |
# Think
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
prompt = AGENT_PROMPT_TEMPLATE.format(
|
166 |
remaining_steps=steps_per_sample - step,
|
167 |
-
history_text=
|
168 |
-
|
169 |
-
),
|
170 |
-
available_actions=json.dumps(bot.controller.get_available_actions()),
|
171 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
message = bot._create_message_with_history(
|
173 |
prompt, [h["image_b64"] for h in history]
|
174 |
)
|
|
|
|
|
175 |
response = bot.model.invoke(message)
|
176 |
decision = bot._parse_agent_response(response)
|
177 |
|
178 |
if not decision: # Fallback
|
179 |
decision = {
|
180 |
"action_details": {"action": "PAN_RIGHT"},
|
181 |
-
"reasoning": "
|
182 |
}
|
183 |
|
184 |
action = decision.get("action_details", {}).get("action")
|
185 |
history[-1]["action"] = action
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
|
|
189 |
)
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
# Force a GUESS on the last step
|
193 |
if step_num == steps_per_sample and action != "GUESS":
|
194 |
-
st.warning("Max steps reached. Forcing a GUESS action.")
|
195 |
action = "GUESS"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
# Act
|
198 |
if action == "GUESS":
|
@@ -204,18 +279,22 @@ if start_button:
|
|
204 |
final_guess = (lat, lon)
|
205 |
else:
|
206 |
st.error(
|
207 |
-
"GUESS action was missing coordinates. Guess failed for this sample."
|
208 |
)
|
209 |
break # End exploration for the current sample
|
210 |
|
211 |
elif action == "MOVE_FORWARD":
|
212 |
-
|
|
|
213 |
elif action == "MOVE_BACKWARD":
|
214 |
-
|
|
|
215 |
elif action == "PAN_LEFT":
|
216 |
-
|
|
|
217 |
elif action == "PAN_RIGHT":
|
218 |
-
|
|
|
219 |
|
220 |
time.sleep(1) # A brief pause between steps for better visualization
|
221 |
|
|
|
136 |
# --- Inner agent exploration loop ---
|
137 |
history = []
|
138 |
final_guess = None
|
139 |
+
step_history_container = st.container()
|
140 |
|
141 |
for step in range(steps_per_sample):
|
142 |
step_num = step + 1
|
143 |
reasoning_placeholder.info(
|
144 |
+
f"🤔 Thinking... (Step {step_num}/{steps_per_sample})"
|
145 |
)
|
146 |
action_placeholder.empty()
|
147 |
|
148 |
# Observe and label arrows
|
149 |
bot.controller.label_arrows_on_screen()
|
150 |
screenshot_bytes = bot.controller.take_street_view_screenshot()
|
151 |
+
|
152 |
+
# Current view
|
153 |
image_placeholder.image(
|
154 |
+
screenshot_bytes,
|
155 |
+
caption=f"🔍 Step {step_num} - What AI Sees Now",
|
156 |
+
use_column_width=True,
|
157 |
)
|
158 |
|
159 |
# Update history
|
160 |
+
current_step_data = {
|
161 |
+
"image_b64": bot.pil_to_base64(Image.open(BytesIO(screenshot_bytes))),
|
162 |
+
"action": "N/A",
|
163 |
+
"screenshot_bytes": screenshot_bytes,
|
164 |
+
"step_num": step_num,
|
165 |
+
}
|
166 |
+
history.append(current_step_data)
|
|
|
167 |
|
168 |
# Think
|
169 |
+
available_actions = bot.controller.get_available_actions()
|
170 |
+
history_text = "\n".join(
|
171 |
+
[f"Step {j + 1}: {h['action']}" for j, h in enumerate(history[:-1])]
|
172 |
+
)
|
173 |
+
if not history_text:
|
174 |
+
history_text = "No history yet. This is the first step."
|
175 |
+
|
176 |
prompt = AGENT_PROMPT_TEMPLATE.format(
|
177 |
remaining_steps=steps_per_sample - step,
|
178 |
+
history_text=history_text,
|
179 |
+
available_actions=json.dumps(available_actions),
|
|
|
|
|
180 |
)
|
181 |
+
|
182 |
+
# Show what AI is considering
|
183 |
+
with reasoning_placeholder:
|
184 |
+
st.info("🧠 **AI is analyzing the situation...**")
|
185 |
+
with st.expander("🔍 Available Actions", expanded=False):
|
186 |
+
st.json(available_actions)
|
187 |
+
with st.expander("📝 Context Being Considered", expanded=False):
|
188 |
+
st.text_area(
|
189 |
+
"History Context:", history_text, height=100, disabled=True
|
190 |
+
)
|
191 |
+
|
192 |
message = bot._create_message_with_history(
|
193 |
prompt, [h["image_b64"] for h in history]
|
194 |
)
|
195 |
+
|
196 |
+
# Get AI response
|
197 |
response = bot.model.invoke(message)
|
198 |
decision = bot._parse_agent_response(response)
|
199 |
|
200 |
if not decision: # Fallback
|
201 |
decision = {
|
202 |
"action_details": {"action": "PAN_RIGHT"},
|
203 |
+
"reasoning": "⚠️ Response parsing failed. Using default recovery action.",
|
204 |
}
|
205 |
|
206 |
action = decision.get("action_details", {}).get("action")
|
207 |
history[-1]["action"] = action
|
208 |
+
history[-1]["reasoning"] = decision.get("reasoning", "N/A")
|
209 |
+
history[-1]["raw_response"] = (
|
210 |
+
response.content[:500] + "..."
|
211 |
+
if len(response.content) > 500
|
212 |
+
else response.content
|
213 |
)
|
214 |
+
|
215 |
+
# Display AI's decision process
|
216 |
+
reasoning_placeholder.success("✅ **AI Decision Made!**")
|
217 |
+
|
218 |
+
with action_placeholder:
|
219 |
+
st.success(f"🎯 **AI Action:** `{action}`")
|
220 |
+
|
221 |
+
# Detailed reasoning display
|
222 |
+
with st.expander("🧠 AI's Detailed Thinking Process", expanded=True):
|
223 |
+
col_reason, col_raw = st.columns([2, 1])
|
224 |
+
|
225 |
+
with col_reason:
|
226 |
+
st.markdown("**🤔 AI's Reasoning:**")
|
227 |
+
st.info(decision.get("reasoning", "N/A"))
|
228 |
+
|
229 |
+
if action == "GUESS":
|
230 |
+
lat = decision.get("action_details", {}).get("lat")
|
231 |
+
lon = decision.get("action_details", {}).get("lon")
|
232 |
+
if lat and lon:
|
233 |
+
st.success(f"📍 **Final Guess:** {lat:.4f}, {lon:.4f}")
|
234 |
+
|
235 |
+
with col_raw:
|
236 |
+
st.markdown("**🔤 Raw AI Response:**")
|
237 |
+
st.text_area(
|
238 |
+
"Full Response:",
|
239 |
+
history[-1]["raw_response"],
|
240 |
+
height=200,
|
241 |
+
disabled=True,
|
242 |
+
key=f"raw_response_{step_num}",
|
243 |
+
)
|
244 |
+
|
245 |
+
# Store step in history display
|
246 |
+
with step_history_container:
|
247 |
+
with st.expander(f"📚 Step {step_num} History", expanded=False):
|
248 |
+
hist_col1, hist_col2 = st.columns([1, 2])
|
249 |
+
with hist_col1:
|
250 |
+
st.image(
|
251 |
+
screenshot_bytes, caption=f"Step {step_num} View", width=200
|
252 |
+
)
|
253 |
+
with hist_col2:
|
254 |
+
st.write(f"**Action:** {action}")
|
255 |
+
st.write(
|
256 |
+
f"**Reasoning:** {decision.get('reasoning', 'N/A')[:150]}..."
|
257 |
+
)
|
258 |
|
259 |
# Force a GUESS on the last step
|
260 |
if step_num == steps_per_sample and action != "GUESS":
|
261 |
+
st.warning("⏰ Max steps reached. Forcing a GUESS action.")
|
262 |
action = "GUESS"
|
263 |
+
# Force coordinates if missing
|
264 |
+
if not decision.get("action_details", {}).get("lat"):
|
265 |
+
st.error("❌ AI didn't provide coordinates. Using fallback guess.")
|
266 |
+
decision["action_details"] = {
|
267 |
+
"action": "GUESS",
|
268 |
+
"lat": 0.0,
|
269 |
+
"lon": 0.0,
|
270 |
+
}
|
271 |
|
272 |
# Act
|
273 |
if action == "GUESS":
|
|
|
279 |
final_guess = (lat, lon)
|
280 |
else:
|
281 |
st.error(
|
282 |
+
"❌ GUESS action was missing coordinates. Guess failed for this sample."
|
283 |
)
|
284 |
break # End exploration for the current sample
|
285 |
|
286 |
elif action == "MOVE_FORWARD":
|
287 |
+
with st.spinner("🚶 Moving forward..."):
|
288 |
+
bot.controller.move("forward")
|
289 |
elif action == "MOVE_BACKWARD":
|
290 |
+
with st.spinner("🔄 Moving backward..."):
|
291 |
+
bot.controller.move("backward")
|
292 |
elif action == "PAN_LEFT":
|
293 |
+
with st.spinner("⬅️ Panning left..."):
|
294 |
+
bot.controller.pan_view("left")
|
295 |
elif action == "PAN_RIGHT":
|
296 |
+
with st.spinner("➡️ Panning right..."):
|
297 |
+
bot.controller.pan_view("right")
|
298 |
|
299 |
time.sleep(1) # A brief pause between steps for better visualization
|
300 |
|
config.py
CHANGED
@@ -31,18 +31,44 @@ MODELS_CONFIG = {
|
|
31 |
"gpt-4o": {
|
32 |
"class": "ChatOpenAI",
|
33 |
"model_name": "gpt-4o",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
},
|
35 |
"claude-3.5-sonnet": {
|
36 |
"class": "ChatAnthropic",
|
37 |
"model_name": "claude-3-5-sonnet-20240620",
|
|
|
|
|
38 |
},
|
39 |
"gemini-1.5-pro": {
|
40 |
"class": "ChatGoogleGenerativeAI",
|
41 |
"model_name": "gemini-1.5-pro-latest",
|
|
|
|
|
42 |
},
|
43 |
"gemini-2.5-pro": {
|
44 |
"class": "ChatGoogleGenerativeAI",
|
45 |
"model_name": "gemini-2.5-pro-preview-06-05",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
},
|
47 |
}
|
48 |
|
|
|
31 |
"gpt-4o": {
|
32 |
"class": "ChatOpenAI",
|
33 |
"model_name": "gpt-4o",
|
34 |
+
"api_key_env": "OPENAI_API_KEY",
|
35 |
+
"description": "OpenAI GPT-4o",
|
36 |
+
},
|
37 |
+
"gpt-4o-mini": {
|
38 |
+
"class": "ChatOpenAI",
|
39 |
+
"model_name": "gpt-4o-mini",
|
40 |
+
"api_key_env": "OPENAI_API_KEY",
|
41 |
+
"description": "OpenAI GPT-4o Mini (cheaper)",
|
42 |
},
|
43 |
"claude-3.5-sonnet": {
|
44 |
"class": "ChatAnthropic",
|
45 |
"model_name": "claude-3-5-sonnet-20240620",
|
46 |
+
"api_key_env": "ANTHROPIC_API_KEY",
|
47 |
+
"description": "Anthropic Claude 3.5 Sonnet",
|
48 |
},
|
49 |
"gemini-1.5-pro": {
|
50 |
"class": "ChatGoogleGenerativeAI",
|
51 |
"model_name": "gemini-1.5-pro-latest",
|
52 |
+
"api_key_env": "GOOGLE_API_KEY",
|
53 |
+
"description": "Google Gemini 1.5 Pro",
|
54 |
},
|
55 |
"gemini-2.5-pro": {
|
56 |
"class": "ChatGoogleGenerativeAI",
|
57 |
"model_name": "gemini-2.5-pro-preview-06-05",
|
58 |
+
"api_key_env": "GOOGLE_API_KEY",
|
59 |
+
"description": "Google Gemini 2.5 Pro",
|
60 |
+
},
|
61 |
+
"qwen2-vl-72b": {
|
62 |
+
"class": "HuggingFaceChat",
|
63 |
+
"model_name": "Qwen/Qwen2-VL-72B-Instruct",
|
64 |
+
"api_key_env": "HUGGINGFACE_API_KEY",
|
65 |
+
"description": "Qwen2-VL 72B (via HF Inference API)",
|
66 |
+
},
|
67 |
+
"qwen2-vl-7b": {
|
68 |
+
"class": "HuggingFaceChat",
|
69 |
+
"model_name": "Qwen/Qwen2-VL-7B-Instruct",
|
70 |
+
"api_key_env": "HUGGINGFACE_API_KEY",
|
71 |
+
"description": "Qwen2-VL 7B (via HF Inference API)",
|
72 |
},
|
73 |
}
|
74 |
|
geo_bot.py
CHANGED
@@ -11,10 +11,11 @@ from langchain_openai import ChatOpenAI
|
|
11 |
from langchain_anthropic import ChatAnthropic
|
12 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
13 |
|
|
|
|
|
14 |
from mapcrunch_controller import MapCrunchController
|
15 |
|
16 |
# The "Golden" Prompt (v6): Combines clear mechanics with robust strategic principles.
|
17 |
-
|
18 |
AGENT_PROMPT_TEMPLATE = """
|
19 |
**Mission:** You are an expert geo-location agent. Your goal is to find clues to determine your location within a limited number of steps.
|
20 |
|
@@ -68,11 +69,20 @@ class GeoBot:
|
|
68 |
):
|
69 |
# Initialize model with temperature parameter
|
70 |
model_kwargs = {
|
71 |
-
"model": model_name,
|
72 |
"temperature": temperature,
|
73 |
}
|
74 |
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
self.model_name = model_name
|
77 |
self.temperature = temperature
|
78 |
self.use_selenium = use_selenium
|
@@ -90,6 +100,7 @@ class GeoBot:
|
|
90 |
) -> List[HumanMessage]:
|
91 |
"""Creates a message for the LLM that includes text and a sequence of images."""
|
92 |
content = [{"type": "text", "text": prompt}]
|
|
|
93 |
# Add the JSON format instructions right after the main prompt text
|
94 |
content.append(
|
95 |
{
|
@@ -145,7 +156,6 @@ class GeoBot:
|
|
145 |
print(f"\n--- Step {max_steps - step + 1}/{max_steps} ---")
|
146 |
|
147 |
self.controller.setup_clean_environment()
|
148 |
-
|
149 |
self.controller.label_arrows_on_screen()
|
150 |
|
151 |
screenshot_bytes = self.controller.take_street_view_screenshot()
|
@@ -178,17 +188,22 @@ class GeoBot:
|
|
178 |
available_actions=json.dumps(available_actions),
|
179 |
)
|
180 |
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
185 |
|
186 |
if not decision:
|
187 |
print(
|
188 |
-
"Response parsing failed. Using default recovery action: PAN_RIGHT."
|
189 |
)
|
190 |
decision = {
|
191 |
-
"reasoning": "Recovery due to parsing failure.",
|
192 |
"action_details": {"action": "PAN_RIGHT"},
|
193 |
}
|
194 |
|
@@ -219,8 +234,13 @@ class GeoBot:
|
|
219 |
def analyze_image(self, image: Image.Image) -> Optional[Tuple[float, float]]:
|
220 |
image_b64 = self.pil_to_base64(image)
|
221 |
message = self._create_llm_message(BENCHMARK_PROMPT, image_b64)
|
222 |
-
|
223 |
-
|
|
|
|
|
|
|
|
|
|
|
224 |
|
225 |
content = response.content.strip()
|
226 |
last_line = ""
|
|
|
11 |
from langchain_anthropic import ChatAnthropic
|
12 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
13 |
|
14 |
+
from hf_chat import HuggingFaceChat
|
15 |
+
|
16 |
from mapcrunch_controller import MapCrunchController
|
17 |
|
18 |
# The "Golden" Prompt (v6): Combines clear mechanics with robust strategic principles.
|
|
|
19 |
AGENT_PROMPT_TEMPLATE = """
|
20 |
**Mission:** You are an expert geo-location agent. Your goal is to find clues to determine your location within a limited number of steps.
|
21 |
|
|
|
69 |
):
|
70 |
# Initialize model with temperature parameter
|
71 |
model_kwargs = {
|
|
|
72 |
"temperature": temperature,
|
73 |
}
|
74 |
|
75 |
+
# Handle different model types
|
76 |
+
if model == HuggingFaceChat and HuggingFaceChat is not None:
|
77 |
+
model_kwargs["model"] = model_name
|
78 |
+
else:
|
79 |
+
model_kwargs["model"] = model_name
|
80 |
+
|
81 |
+
try:
|
82 |
+
self.model = model(**model_kwargs)
|
83 |
+
except Exception as e:
|
84 |
+
raise ValueError(f"Failed to initialize model {model_name}: {e}")
|
85 |
+
|
86 |
self.model_name = model_name
|
87 |
self.temperature = temperature
|
88 |
self.use_selenium = use_selenium
|
|
|
100 |
) -> List[HumanMessage]:
|
101 |
"""Creates a message for the LLM that includes text and a sequence of images."""
|
102 |
content = [{"type": "text", "text": prompt}]
|
103 |
+
|
104 |
# Add the JSON format instructions right after the main prompt text
|
105 |
content.append(
|
106 |
{
|
|
|
156 |
print(f"\n--- Step {max_steps - step + 1}/{max_steps} ---")
|
157 |
|
158 |
self.controller.setup_clean_environment()
|
|
|
159 |
self.controller.label_arrows_on_screen()
|
160 |
|
161 |
screenshot_bytes = self.controller.take_street_view_screenshot()
|
|
|
188 |
available_actions=json.dumps(available_actions),
|
189 |
)
|
190 |
|
191 |
+
try:
|
192 |
+
message = self._create_message_with_history(
|
193 |
+
prompt, image_b64_for_prompt
|
194 |
+
)
|
195 |
+
response = self.model.invoke(message)
|
196 |
+
decision = self._parse_agent_response(response)
|
197 |
+
except Exception as e:
|
198 |
+
print(f"Error during model invocation: {e}")
|
199 |
+
decision = None
|
200 |
|
201 |
if not decision:
|
202 |
print(
|
203 |
+
"Response parsing failed or model error. Using default recovery action: PAN_RIGHT."
|
204 |
)
|
205 |
decision = {
|
206 |
+
"reasoning": "Recovery due to parsing failure or model error.",
|
207 |
"action_details": {"action": "PAN_RIGHT"},
|
208 |
}
|
209 |
|
|
|
234 |
def analyze_image(self, image: Image.Image) -> Optional[Tuple[float, float]]:
|
235 |
image_b64 = self.pil_to_base64(image)
|
236 |
message = self._create_llm_message(BENCHMARK_PROMPT, image_b64)
|
237 |
+
|
238 |
+
try:
|
239 |
+
response = self.model.invoke(message)
|
240 |
+
print(f"\nLLM Response:\n{response.content}")
|
241 |
+
except Exception as e:
|
242 |
+
print(f"Error during image analysis: {e}")
|
243 |
+
return None
|
244 |
|
245 |
content = response.content.strip()
|
246 |
last_line = ""
|
hf_chat.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
HuggingFace Chat Model Wrapper for vision models like Qwen2-VL
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import base64
|
7 |
+
import requests
|
8 |
+
from typing import List, Dict, Any, Optional
|
9 |
+
from langchain_core.messages import BaseMessage, HumanMessage
|
10 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
11 |
+
from langchain_core.outputs import ChatResult, ChatGeneration
|
12 |
+
from pydantic import Field
|
13 |
+
|
14 |
+
|
15 |
+
class HuggingFaceChat(BaseChatModel):
|
16 |
+
"""Chat model wrapper for HuggingFace Inference API"""
|
17 |
+
|
18 |
+
model: str = Field(description="HuggingFace model name")
|
19 |
+
temperature: float = Field(default=0.0, description="Temperature for sampling")
|
20 |
+
max_tokens: int = Field(default=1000, description="Max tokens to generate")
|
21 |
+
api_token: Optional[str] = Field(default=None, description="HF API token")
|
22 |
+
|
23 |
+
def __init__(self, model: str, temperature: float = 0.0, **kwargs):
|
24 |
+
api_token = kwargs.get("api_token") or os.getenv("HUGGINGFACE_API_KEY")
|
25 |
+
if not api_token:
|
26 |
+
raise ValueError("HUGGINGFACE_API_KEY environment variable is required")
|
27 |
+
|
28 |
+
super().__init__(
|
29 |
+
model=model, temperature=temperature, api_token=api_token, **kwargs
|
30 |
+
)
|
31 |
+
|
32 |
+
@property
|
33 |
+
def _llm_type(self) -> str:
|
34 |
+
return "huggingface_chat"
|
35 |
+
|
36 |
+
def _format_message_for_hf(self, message: HumanMessage) -> Dict[str, Any]:
|
37 |
+
"""Convert LangChain message to HuggingFace format"""
|
38 |
+
if isinstance(message.content, str):
|
39 |
+
return {"role": "user", "content": message.content}
|
40 |
+
|
41 |
+
# Handle multi-modal content (text + images)
|
42 |
+
formatted_content = []
|
43 |
+
for item in message.content:
|
44 |
+
if item["type"] == "text":
|
45 |
+
formatted_content.append({"type": "text", "text": item["text"]})
|
46 |
+
elif item["type"] == "image_url":
|
47 |
+
# Extract base64 data from data URL
|
48 |
+
image_url = item["image_url"]["url"]
|
49 |
+
if image_url.startswith("data:image"):
|
50 |
+
# Extract base64 data
|
51 |
+
base64_data = image_url.split(",")[1]
|
52 |
+
formatted_content.append({"type": "image", "image": base64_data})
|
53 |
+
|
54 |
+
return {"role": "user", "content": formatted_content}
|
55 |
+
|
56 |
+
def _generate(self, messages: List[BaseMessage], **kwargs) -> ChatResult:
|
57 |
+
"""Generate response using HuggingFace Inference API"""
|
58 |
+
|
59 |
+
# Format messages for HF API
|
60 |
+
formatted_messages = []
|
61 |
+
for msg in messages:
|
62 |
+
if isinstance(msg, HumanMessage):
|
63 |
+
formatted_messages.append(self._format_message_for_hf(msg))
|
64 |
+
|
65 |
+
# Prepare API request
|
66 |
+
api_url = f"https://api-inference.huggingface.co/models/{self.model}/v1/chat/completions"
|
67 |
+
headers = {
|
68 |
+
"Authorization": f"Bearer {self.api_token}",
|
69 |
+
"Content-Type": "application/json",
|
70 |
+
}
|
71 |
+
|
72 |
+
payload = {
|
73 |
+
"model": self.model,
|
74 |
+
"messages": formatted_messages,
|
75 |
+
"temperature": self.temperature,
|
76 |
+
"max_tokens": self.max_tokens,
|
77 |
+
"stream": False,
|
78 |
+
}
|
79 |
+
|
80 |
+
try:
|
81 |
+
response = requests.post(api_url, headers=headers, json=payload, timeout=60)
|
82 |
+
response.raise_for_status()
|
83 |
+
|
84 |
+
result = response.json()
|
85 |
+
content = result["choices"][0]["message"]["content"]
|
86 |
+
|
87 |
+
return ChatResult(
|
88 |
+
generations=[ChatGeneration(message=HumanMessage(content=content))]
|
89 |
+
)
|
90 |
+
|
91 |
+
except requests.exceptions.RequestException as e:
|
92 |
+
# Fallback to simple text-only API if chat completions fail
|
93 |
+
return self._fallback_generate(messages, **kwargs)
|
94 |
+
|
95 |
+
def _fallback_generate(self, messages: List[BaseMessage], **kwargs) -> ChatResult:
|
96 |
+
"""Fallback to simple HF Inference API"""
|
97 |
+
try:
|
98 |
+
# Use simple inference API as fallback
|
99 |
+
api_url = f"https://api-inference.huggingface.co/models/{self.model}"
|
100 |
+
headers = {
|
101 |
+
"Authorization": f"Bearer {self.api_token}",
|
102 |
+
"Content-Type": "application/json",
|
103 |
+
}
|
104 |
+
|
105 |
+
# Extract text content only for fallback
|
106 |
+
text_content = ""
|
107 |
+
for msg in messages:
|
108 |
+
if isinstance(msg, HumanMessage):
|
109 |
+
if isinstance(msg.content, str):
|
110 |
+
text_content += msg.content
|
111 |
+
else:
|
112 |
+
for item in msg.content:
|
113 |
+
if item["type"] == "text":
|
114 |
+
text_content += item["text"] + "\n"
|
115 |
+
|
116 |
+
payload = {
|
117 |
+
"inputs": text_content,
|
118 |
+
"parameters": {
|
119 |
+
"temperature": self.temperature,
|
120 |
+
"max_new_tokens": self.max_tokens,
|
121 |
+
},
|
122 |
+
}
|
123 |
+
|
124 |
+
response = requests.post(api_url, headers=headers, json=payload, timeout=60)
|
125 |
+
response.raise_for_status()
|
126 |
+
|
127 |
+
result = response.json()
|
128 |
+
if isinstance(result, list) and len(result) > 0:
|
129 |
+
content = result[0].get("generated_text", "No response generated")
|
130 |
+
else:
|
131 |
+
content = "Error: Invalid response format"
|
132 |
+
|
133 |
+
return ChatResult(
|
134 |
+
generations=[ChatGeneration(message=HumanMessage(content=content))]
|
135 |
+
)
|
136 |
+
|
137 |
+
except Exception as e:
|
138 |
+
# Last resort fallback
|
139 |
+
error_msg = f"HuggingFace API Error: {str(e)}. Please check your API key and model availability."
|
140 |
+
return ChatResult(
|
141 |
+
generations=[ChatGeneration(message=HumanMessage(content=error_msg))]
|
142 |
+
)
|