sairampillai commited on
Commit
09eecef
·
unverified ·
1 Parent(s): 80f53c9

Update and format for review commits

Browse files

Signed-off-by: Sairam Pillai <[email protected]>

Files changed (2) hide show
  1. app.py +4 -12
  2. helpers/llm_helper.py +4 -2
app.py CHANGED
@@ -195,31 +195,23 @@ with st.sidebar:
195
  api_version: str = ''
196
  else:
197
  # The online LLMs
198
- selected_option = st.sidebar.selectbox(
199
  label='2: Select a suitable LLM to use:\n\n(Gemini and Mistral-Nemo are recommended)',
200
  options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
201
  index=GlobalConfig.DEFAULT_MODEL_INDEX,
202
  help=GlobalConfig.LLM_PROVIDER_HELP,
203
  on_change=reset_api_key
204
- )
205
 
206
- # Extract provider key more robustly using regex
207
- provider_match = GlobalConfig.PROVIDER_REGEX.match(selected_option)
208
- if provider_match:
209
- llm_provider_to_use = selected_option # Use full string for get_provider_model
210
- else:
211
- # Fallback: try to extract the key before the first space
212
- llm_provider_to_use = selected_option.split(' ')[0]
213
- logger.warning(f"Could not parse provider from selectbox option: {selected_option}")
214
-
215
  # --- Automatically fetch API key from .env if available ---
 
216
  provider_match = GlobalConfig.PROVIDER_REGEX.match(llm_provider_to_use)
217
  if provider_match:
218
  selected_provider = provider_match.group(1)
219
  else:
220
  # If regex doesn't match, try to extract provider from the beginning
221
  selected_provider = llm_provider_to_use.split(' ')[0] if ' ' in llm_provider_to_use else llm_provider_to_use
222
- logger.warning(f"Provider regex did not match for: {llm_provider_to_use}, using: {selected_provider}")
223
 
224
  # Validate that the selected provider is valid
225
  if selected_provider not in GlobalConfig.VALID_PROVIDERS:
 
195
  api_version: str = ''
196
  else:
197
  # The online LLMs
198
+ llm_provider_to_use = st.sidebar.selectbox(
199
  label='2: Select a suitable LLM to use:\n\n(Gemini and Mistral-Nemo are recommended)',
200
  options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
201
  index=GlobalConfig.DEFAULT_MODEL_INDEX,
202
  help=GlobalConfig.LLM_PROVIDER_HELP,
203
  on_change=reset_api_key
204
+ ).split(' ')[0]
205
 
 
 
 
 
 
 
 
 
 
206
  # --- Automatically fetch API key from .env if available ---
207
+ # Extract provider key using regex
208
  provider_match = GlobalConfig.PROVIDER_REGEX.match(llm_provider_to_use)
209
  if provider_match:
210
  selected_provider = provider_match.group(1)
211
  else:
212
  # If regex doesn't match, try to extract provider from the beginning
213
  selected_provider = llm_provider_to_use.split(' ')[0] if ' ' in llm_provider_to_use else llm_provider_to_use
214
+ logger.warning("Provider regex did not match for: %s, using: %s", llm_provider_to_use, selected_provider)
215
 
216
  # Validate that the selected provider is valid
217
  if selected_provider not in GlobalConfig.VALID_PROVIDERS:
helpers/llm_helper.py CHANGED
@@ -5,7 +5,7 @@ import logging
5
  import re
6
  import sys
7
  import urllib3
8
- from typing import Tuple, Union, Iterator
9
 
10
  import requests
11
  import os
@@ -121,7 +121,7 @@ def is_valid_llm_provider_model(
121
  return True
122
 
123
 
124
- def get_litellm_model_name(provider: str, model: str) -> str:
125
  """
126
  Convert provider and model to LiteLLM model name format.
127
 
@@ -184,6 +184,8 @@ def stream_litellm_completion(
184
  litellm_model = f'azure/{azure_deployment_name}'
185
  else:
186
  litellm_model = get_litellm_model_name(provider, model)
 
 
187
 
188
  # Prepare the request parameters
189
  request_params = {
 
5
  import re
6
  import sys
7
  import urllib3
8
+ from typing import Tuple, Union, Iterator, Optional
9
 
10
  import requests
11
  import os
 
121
  return True
122
 
123
 
124
+ def get_litellm_model_name(provider: str, model: str) -> Optional[str]:
125
  """
126
  Convert provider and model to LiteLLM model name format.
127
 
 
184
  litellm_model = f'azure/{azure_deployment_name}'
185
  else:
186
  litellm_model = get_litellm_model_name(provider, model)
187
+ if not litellm_model:
188
+ raise ValueError(f"Invalid model name: {model} for provider: {provider}")
189
 
190
  # Prepare the request parameters
191
  request_params = {