Spaces:
Running
Running
Update and format for review commits
Browse filesSigned-off-by: Sairam Pillai <[email protected]>
- app.py +4 -12
- helpers/llm_helper.py +4 -2
app.py
CHANGED
|
@@ -195,31 +195,23 @@ with st.sidebar:
|
|
| 195 |
api_version: str = ''
|
| 196 |
else:
|
| 197 |
# The online LLMs
|
| 198 |
-
|
| 199 |
label='2: Select a suitable LLM to use:\n\n(Gemini and Mistral-Nemo are recommended)',
|
| 200 |
options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
|
| 201 |
index=GlobalConfig.DEFAULT_MODEL_INDEX,
|
| 202 |
help=GlobalConfig.LLM_PROVIDER_HELP,
|
| 203 |
on_change=reset_api_key
|
| 204 |
-
)
|
| 205 |
|
| 206 |
-
# Extract provider key more robustly using regex
|
| 207 |
-
provider_match = GlobalConfig.PROVIDER_REGEX.match(selected_option)
|
| 208 |
-
if provider_match:
|
| 209 |
-
llm_provider_to_use = selected_option # Use full string for get_provider_model
|
| 210 |
-
else:
|
| 211 |
-
# Fallback: try to extract the key before the first space
|
| 212 |
-
llm_provider_to_use = selected_option.split(' ')[0]
|
| 213 |
-
logger.warning(f"Could not parse provider from selectbox option: {selected_option}")
|
| 214 |
-
|
| 215 |
# --- Automatically fetch API key from .env if available ---
|
|
|
|
| 216 |
provider_match = GlobalConfig.PROVIDER_REGEX.match(llm_provider_to_use)
|
| 217 |
if provider_match:
|
| 218 |
selected_provider = provider_match.group(1)
|
| 219 |
else:
|
| 220 |
# If regex doesn't match, try to extract provider from the beginning
|
| 221 |
selected_provider = llm_provider_to_use.split(' ')[0] if ' ' in llm_provider_to_use else llm_provider_to_use
|
| 222 |
-
logger.warning(
|
| 223 |
|
| 224 |
# Validate that the selected provider is valid
|
| 225 |
if selected_provider not in GlobalConfig.VALID_PROVIDERS:
|
|
|
|
| 195 |
api_version: str = ''
|
| 196 |
else:
|
| 197 |
# The online LLMs
|
| 198 |
+
llm_provider_to_use = st.sidebar.selectbox(
|
| 199 |
label='2: Select a suitable LLM to use:\n\n(Gemini and Mistral-Nemo are recommended)',
|
| 200 |
options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
|
| 201 |
index=GlobalConfig.DEFAULT_MODEL_INDEX,
|
| 202 |
help=GlobalConfig.LLM_PROVIDER_HELP,
|
| 203 |
on_change=reset_api_key
|
| 204 |
+
).split(' ')[0]
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
# --- Automatically fetch API key from .env if available ---
|
| 207 |
+
# Extract provider key using regex
|
| 208 |
provider_match = GlobalConfig.PROVIDER_REGEX.match(llm_provider_to_use)
|
| 209 |
if provider_match:
|
| 210 |
selected_provider = provider_match.group(1)
|
| 211 |
else:
|
| 212 |
# If regex doesn't match, try to extract provider from the beginning
|
| 213 |
selected_provider = llm_provider_to_use.split(' ')[0] if ' ' in llm_provider_to_use else llm_provider_to_use
|
| 214 |
+
logger.warning("Provider regex did not match for: %s, using: %s", llm_provider_to_use, selected_provider)
|
| 215 |
|
| 216 |
# Validate that the selected provider is valid
|
| 217 |
if selected_provider not in GlobalConfig.VALID_PROVIDERS:
|
helpers/llm_helper.py
CHANGED
|
@@ -5,7 +5,7 @@ import logging
|
|
| 5 |
import re
|
| 6 |
import sys
|
| 7 |
import urllib3
|
| 8 |
-
from typing import Tuple, Union, Iterator
|
| 9 |
|
| 10 |
import requests
|
| 11 |
import os
|
|
@@ -121,7 +121,7 @@ def is_valid_llm_provider_model(
|
|
| 121 |
return True
|
| 122 |
|
| 123 |
|
| 124 |
-
def get_litellm_model_name(provider: str, model: str) -> str:
|
| 125 |
"""
|
| 126 |
Convert provider and model to LiteLLM model name format.
|
| 127 |
|
|
@@ -184,6 +184,8 @@ def stream_litellm_completion(
|
|
| 184 |
litellm_model = f'azure/{azure_deployment_name}'
|
| 185 |
else:
|
| 186 |
litellm_model = get_litellm_model_name(provider, model)
|
|
|
|
|
|
|
| 187 |
|
| 188 |
# Prepare the request parameters
|
| 189 |
request_params = {
|
|
|
|
| 5 |
import re
|
| 6 |
import sys
|
| 7 |
import urllib3
|
| 8 |
+
from typing import Tuple, Union, Iterator, Optional
|
| 9 |
|
| 10 |
import requests
|
| 11 |
import os
|
|
|
|
| 121 |
return True
|
| 122 |
|
| 123 |
|
| 124 |
+
def get_litellm_model_name(provider: str, model: str) -> Optional[str]:
|
| 125 |
"""
|
| 126 |
Convert provider and model to LiteLLM model name format.
|
| 127 |
|
|
|
|
| 184 |
litellm_model = f'azure/{azure_deployment_name}'
|
| 185 |
else:
|
| 186 |
litellm_model = get_litellm_model_name(provider, model)
|
| 187 |
+
if not litellm_model:
|
| 188 |
+
raise ValueError(f"Invalid model name: {model} for provider: {provider}")
|
| 189 |
|
| 190 |
# Prepare the request parameters
|
| 191 |
request_params = {
|