seanpedrickcase committed on
Commit 3085585 · 1 parent: 5fed34d

Added model compatibility for OpenAI and Azure endpoints. Added some Bedrock models; the app is now compatible with thinking models.

app.py CHANGED
@@ -3,7 +3,7 @@ import os
3
  import gradio as gr
4
  import pandas as pd
5
  from datetime import datetime
6
- from tools.helper_functions import put_columns_in_df, get_connection_params, view_table, empty_output_vars_extract_topics, empty_output_vars_summarise, load_in_previous_reference_file, join_cols_onto_reference_df, load_in_previous_data_files, load_in_data_file, load_in_default_cost_codes, reset_base_dataframe, update_cost_code_dataframe_from_dropdown_select, df_select_callback_cost, enforce_cost_codes, _get_env_list, move_overall_summary_output_files_to_front_page
7
  from tools.aws_functions import upload_file_to_s3, download_file_from_s3
8
  from tools.llm_api_call import modify_existing_output_tables, wrapper_extract_topics_per_column_value, all_in_one_pipeline
9
  from tools.dedup_summaries import sample_reference_table_summaries, summarise_output_topics, deduplicate_topics, deduplicate_topics_llm, overall_summary
@@ -13,18 +13,7 @@ from tools.auth import authenticate_user
13
  from tools.example_table_outputs import dummy_consultation_table, case_notes_table, dummy_consultation_table_zero_shot, case_notes_table_grouped, case_notes_table_structured_summary
14
  from tools.prompts import initial_table_prompt, system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, two_para_summary_format_prompt, single_para_summary_format_prompt
15
  # from tools.verify_titles import verify_titles
16
- from tools.config import RUN_AWS_FUNCTIONS, HOST_NAME, ACCESS_LOGS_FOLDER, FEEDBACK_LOGS_FOLDER, USAGE_LOGS_FOLDER, RUN_LOCAL_MODEL, FILE_INPUT_HEIGHT, GEMINI_API_KEY, model_full_names, BATCH_SIZE_DEFAULT, CHOSEN_LOCAL_MODEL_TYPE, LLM_SEED, COGNITO_AUTH, MAX_QUEUE_SIZE, MAX_FILE_SIZE, GRADIO_SERVER_PORT, ROOT_PATH, INPUT_FOLDER, OUTPUT_FOLDER, S3_LOG_BUCKET, CONFIG_FOLDER, GRADIO_TEMP_DIR, MPLCONFIGDIR, model_name_map, GET_COST_CODES, ENFORCE_COST_CODES, DEFAULT_COST_CODE, COST_CODES_PATH, S3_COST_CODES_PATH, OUTPUT_COST_CODES_PATH, SHOW_COSTS, SAVE_LOGS_TO_CSV, SAVE_LOGS_TO_DYNAMODB, ACCESS_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME, LOG_FILE_NAME, FEEDBACK_LOG_FILE_NAME, USAGE_LOG_FILE_NAME, CSV_ACCESS_LOG_HEADERS, CSV_FEEDBACK_LOG_HEADERS, CSV_USAGE_LOG_HEADERS, DYNAMODB_ACCESS_LOG_HEADERS, DYNAMODB_FEEDBACK_LOG_HEADERS, DYNAMODB_USAGE_LOG_HEADERS, S3_ACCESS_LOGS_FOLDER, S3_FEEDBACK_LOGS_FOLDER, S3_USAGE_LOGS_FOLDER, AWS_ACCESS_KEY, AWS_SECRET_KEY, SHOW_EXAMPLES, HF_TOKEN, AZURE_API_KEY, LLM_TEMPERATURE
17
-
18
- def ensure_folder_exists(output_folder:str):
19
- """Checks if the specified folder exists, creates it if not."""
20
-
21
- if not os.path.exists(output_folder):
22
- # Create the folder if it doesn't exist
23
- os.makedirs(output_folder, exist_ok=True)
24
- print(f"Created the {output_folder} folder.")
25
- else:
26
- pass
27
- #print(f"The {output_folder} folder already exists.")
28
 
29
  ensure_folder_exists(CONFIG_FOLDER)
30
  ensure_folder_exists(OUTPUT_FOLDER)
@@ -35,26 +24,8 @@ ensure_folder_exists(FEEDBACK_LOGS_FOLDER)
35
  ensure_folder_exists(ACCESS_LOGS_FOLDER)
36
  ensure_folder_exists(USAGE_LOGS_FOLDER)
37
 
38
- # Convert string environment variables to string or list
39
- if SAVE_LOGS_TO_CSV == "True": SAVE_LOGS_TO_CSV = True
40
- else: SAVE_LOGS_TO_CSV = False
41
- if SAVE_LOGS_TO_DYNAMODB == "True": SAVE_LOGS_TO_DYNAMODB = True
42
- else: SAVE_LOGS_TO_DYNAMODB = False
43
-
44
- if CSV_ACCESS_LOG_HEADERS: CSV_ACCESS_LOG_HEADERS = _get_env_list(CSV_ACCESS_LOG_HEADERS)
45
- if CSV_FEEDBACK_LOG_HEADERS: CSV_FEEDBACK_LOG_HEADERS = _get_env_list(CSV_FEEDBACK_LOG_HEADERS)
46
- if CSV_USAGE_LOG_HEADERS: CSV_USAGE_LOG_HEADERS = _get_env_list(CSV_USAGE_LOG_HEADERS)
47
-
48
- if DYNAMODB_ACCESS_LOG_HEADERS: DYNAMODB_ACCESS_LOG_HEADERS = _get_env_list(DYNAMODB_ACCESS_LOG_HEADERS)
49
- if DYNAMODB_FEEDBACK_LOG_HEADERS: DYNAMODB_FEEDBACK_LOG_HEADERS = _get_env_list(DYNAMODB_FEEDBACK_LOG_HEADERS)
50
- if DYNAMODB_USAGE_LOG_HEADERS: DYNAMODB_USAGE_LOG_HEADERS = _get_env_list(DYNAMODB_USAGE_LOG_HEADERS)
51
-
52
  today_rev = datetime.now().strftime("%Y%m%d")
53
 
54
- if RUN_LOCAL_MODEL == "1": default_model_choice = CHOSEN_LOCAL_MODEL_TYPE
55
- elif RUN_AWS_FUNCTIONS == "1": default_model_choice = "anthropic.claude-3-haiku-20240307-v1:0"
56
- else: default_model_choice = "gemini-2.5-flash"
57
-
58
  # Placeholders for example variables
59
  in_data_files = gr.File(height=FILE_INPUT_HEIGHT, label="Choose Excel or csv files", file_count= "multiple", file_types=['.xlsx', '.xls', '.csv', '.parquet'])
60
  in_colnames = gr.Dropdown(choices=[""], multiselect = False, label="Select the open text column of interest. In an Excel file, this shows columns across all sheets.", allow_custom_value=True, interactive=True)
@@ -162,7 +133,7 @@ with app:
162
 
163
  gr.Markdown("""# Large language model topic modelling
164
 
165
- Extract topics and summarise outputs using Large Language Models (LLMs, Gemma 3 4b/GPT-OSS 20b if local (see tools/config.py to modify), Gemini, Azure, or AWS Bedrock models (e.g. Claude, Nova models). The app will query the LLM with batches of responses to produce summary tables, which are then compared iteratively to output a table with the general topics, subtopics, topic sentiment, and a topic summary. Instructions on use can be found in the README.md file. You can try out examples by clicking on one of the example datasets below. API keys for AWS, Azure, and Gemini services can be entered on the settings page (note that Gemini has a free public API).
166
 
167
  NOTE: Large language models are not 100% accurate and may produce biased or harmful outputs. All outputs from this app **absolutely need to be checked by a human** for harmful content, hallucinations, and accuracy.""")
168
 
@@ -198,7 +169,10 @@ with app:
198
 
199
  with gr.Tab(label="All in one topic extraction and summarisation"):
200
  with gr.Row():
201
- model_choice = gr.Dropdown(value = default_model_choice, choices = model_full_names, label="Large language model for topic extraction and summarisation", multiselect=False)
 
 
 
202
 
203
  with gr.Accordion("Upload xlsx, csv, or parquet file", open = True):
204
  in_data_files.render()
@@ -339,8 +313,9 @@ with app:
339
  with gr.Accordion("Gemini API keys", open = False):
340
  google_api_key_textbox = gr.Textbox(value = GEMINI_API_KEY, label="Enter Gemini API key (only if using Google API models)", lines=1, type="password")
341
 
342
- with gr.Accordion("Azure AI Inference", open = False):
343
- azure_api_key_textbox = gr.Textbox(value = AZURE_API_KEY, label="Enter Azure AI Inference API key (only if using Azure models)", lines=1, type="password")
 
344
 
345
  with gr.Accordion("Hugging Face token for downloading gated models", open = False):
346
  hf_api_key_textbox = gr.Textbox(value = HF_TOKEN, label="Enter Hugging Face API key (only for gated models that need a token to download)", lines=1, type="password")
@@ -435,6 +410,7 @@ with app:
435
  aws_secret_key_textbox,
436
  hf_api_key_textbox,
437
  azure_api_key_textbox,
 
438
  output_folder_state,
439
  logged_content_df,
440
  add_existing_topics_summary_format_textbox],
@@ -478,12 +454,12 @@ with app:
478
  success(deduplicate_topics, inputs=[master_reference_df_state, master_unique_topics_df_state, working_data_file_name_textbox, unique_topics_table_file_name_textbox, in_excel_sheets, merge_sentiment_drop, merge_general_topics_drop, deduplicate_score_threshold, in_data_files, in_colnames, output_folder_state], outputs=[master_reference_df_state, master_unique_topics_df_state, summarisation_input_files, log_files_output, summarised_output_markdown], scroll_to_output=True, api_name="deduplicate_topics")
479
 
480
  # When LLM deduplication button pressed, deduplicate data using LLM
481
- def deduplicate_topics_llm_wrapper(reference_df, topic_summary_df, reference_table_file_name, unique_topics_table_file_name, model_choice, in_api_key, temperature, in_excel_sheets, merge_sentiment, merge_general_topics, in_data_files, chosen_cols, output_folder, candidate_topics=None):
482
  model_source = model_name_map[model_choice]["source"]
483
- return deduplicate_topics_llm(reference_df, topic_summary_df, reference_table_file_name, unique_topics_table_file_name, model_choice, in_api_key, temperature, model_source, None, None, None, None, in_excel_sheets, merge_sentiment, merge_general_topics, in_data_files, chosen_cols, output_folder, candidate_topics)
484
 
485
  deduplicate_llm_previous_data_btn.click(load_in_previous_data_files, inputs=[deduplication_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
486
- success(deduplicate_topics_llm_wrapper, inputs=[master_reference_df_state, master_unique_topics_df_state, working_data_file_name_textbox, unique_topics_table_file_name_textbox, model_choice, google_api_key_textbox, temperature_slide, in_excel_sheets, merge_sentiment_drop, merge_general_topics_drop, in_data_files, in_colnames, output_folder_state, candidate_topics], outputs=[master_reference_df_state, master_unique_topics_df_state, summarisation_input_files, log_files_output, summarised_output_markdown, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number], scroll_to_output=True, api_name="deduplicate_topics_llm").\
487
  success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False, api_name="usage_logs_llm_dedup")
488
 
489
  # When button pressed, summarise previous data
@@ -491,14 +467,14 @@ with app:
491
  success(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
492
  success(load_in_previous_data_files, inputs=[summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
493
  success(sample_reference_table_summaries, inputs=[master_reference_df_state, random_seed], outputs=[summary_reference_table_sample_state, summarised_references_markdown], api_name="sample_summaries").\
494
- success(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state, context_textbox, aws_access_key_textbox, aws_secret_key_textbox, model_name_map_state, hf_api_key_textbox, logged_content_df], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output, overall_summarisation_input_files, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, output_messages_textbox, logged_content_df], api_name="summarise_topics", show_progress_on=[output_messages_textbox, summary_output_files]).\
495
  success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False).\
496
  then(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_revised_summaries_state, master_unique_topics_df_revised_summaries_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state, produce_structured_summary_radio], outputs=[summary_output_files_xlsx, summary_xlsx_output_files_list])
497
 
498
  # SUMMARISE WHOLE TABLE PAGE
499
  overall_summarise_previous_data_btn.click(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
500
  success(load_in_previous_data_files, inputs=[overall_summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
501
- success(overall_summary, inputs=[master_unique_topics_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, output_folder_state, in_colnames, context_textbox, aws_access_key_textbox, aws_secret_key_textbox, model_name_map_state, hf_api_key_textbox, logged_content_df], outputs=[overall_summary_output_files, overall_summarised_output_markdown, summarised_output_df, conversation_metadata_textbox, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, output_messages_textbox, logged_content_df], scroll_to_output=True, api_name="overall_summary", show_progress_on=[output_messages_textbox, overall_summary_output_files]).\
502
  success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False).\
503
  then(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_state, master_unique_topics_df_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state, produce_structured_summary_radio], outputs=[overall_summary_output_files_xlsx, summary_xlsx_output_files_list])
504
 
@@ -545,6 +521,7 @@ with app:
545
  aws_secret_key_textbox,
546
  hf_api_key_textbox,
547
  azure_api_key_textbox,
 
548
  output_folder_state,
549
  merge_sentiment_drop,
550
  merge_general_topics_drop,
 
3
  import gradio as gr
4
  import pandas as pd
5
  from datetime import datetime
6
+ from tools.helper_functions import put_columns_in_df, get_connection_params, view_table, empty_output_vars_extract_topics, empty_output_vars_summarise, load_in_previous_reference_file, join_cols_onto_reference_df, load_in_previous_data_files, load_in_data_file, load_in_default_cost_codes, reset_base_dataframe, update_cost_code_dataframe_from_dropdown_select, df_select_callback_cost, enforce_cost_codes, _get_env_list, move_overall_summary_output_files_to_front_page, update_model_choice
7
  from tools.aws_functions import upload_file_to_s3, download_file_from_s3
8
  from tools.llm_api_call import modify_existing_output_tables, wrapper_extract_topics_per_column_value, all_in_one_pipeline
9
  from tools.dedup_summaries import sample_reference_table_summaries, summarise_output_topics, deduplicate_topics, deduplicate_topics_llm, overall_summary
 
13
  from tools.example_table_outputs import dummy_consultation_table, case_notes_table, dummy_consultation_table_zero_shot, case_notes_table_grouped, case_notes_table_structured_summary
14
  from tools.prompts import initial_table_prompt, system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, two_para_summary_format_prompt, single_para_summary_format_prompt
15
  # from tools.verify_titles import verify_titles
16
+ from tools.config import RUN_AWS_FUNCTIONS, HOST_NAME, ACCESS_LOGS_FOLDER, FEEDBACK_LOGS_FOLDER, USAGE_LOGS_FOLDER, FILE_INPUT_HEIGHT, GEMINI_API_KEY, BATCH_SIZE_DEFAULT, LLM_SEED, COGNITO_AUTH, MAX_QUEUE_SIZE, MAX_FILE_SIZE, GRADIO_SERVER_PORT, ROOT_PATH, INPUT_FOLDER, OUTPUT_FOLDER, S3_LOG_BUCKET, CONFIG_FOLDER, GRADIO_TEMP_DIR, MPLCONFIGDIR, GET_COST_CODES, ENFORCE_COST_CODES, DEFAULT_COST_CODE, COST_CODES_PATH, S3_COST_CODES_PATH, OUTPUT_COST_CODES_PATH, SHOW_COSTS, SAVE_LOGS_TO_CSV, SAVE_LOGS_TO_DYNAMODB, ACCESS_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME, LOG_FILE_NAME, FEEDBACK_LOG_FILE_NAME, USAGE_LOG_FILE_NAME, CSV_ACCESS_LOG_HEADERS, CSV_FEEDBACK_LOG_HEADERS, CSV_USAGE_LOG_HEADERS, DYNAMODB_ACCESS_LOG_HEADERS, DYNAMODB_FEEDBACK_LOG_HEADERS, DYNAMODB_USAGE_LOG_HEADERS, S3_ACCESS_LOGS_FOLDER, S3_FEEDBACK_LOGS_FOLDER, S3_USAGE_LOGS_FOLDER, AWS_ACCESS_KEY, AWS_SECRET_KEY, SHOW_EXAMPLES, HF_TOKEN, AZURE_OPENAI_API_KEY, AZURE_OPENAI_INFERENCE_ENDPOINT, LLM_TEMPERATURE, model_name_map, default_model_choice, default_source_models, default_model_source, model_sources, ensure_folder_exists
 
17
 
18
  ensure_folder_exists(CONFIG_FOLDER)
19
  ensure_folder_exists(OUTPUT_FOLDER)
 
24
  ensure_folder_exists(ACCESS_LOGS_FOLDER)
25
  ensure_folder_exists(USAGE_LOGS_FOLDER)
26
 
27
  today_rev = datetime.now().strftime("%Y%m%d")
28
 
29
  # Placeholders for example variables
30
  in_data_files = gr.File(height=FILE_INPUT_HEIGHT, label="Choose Excel or csv files", file_count= "multiple", file_types=['.xlsx', '.xls', '.csv', '.parquet'])
31
  in_colnames = gr.Dropdown(choices=[""], multiselect = False, label="Select the open text column of interest. In an Excel file, this shows columns across all sheets.", allow_custom_value=True, interactive=True)
 
133
 
134
  gr.Markdown("""# Large language model topic modelling
135
 
136
+ Extract topics and summarise outputs using Large Language Models (LLMs: Gemma 3 4b/GPT-OSS 20b if local (see tools/config.py to modify), Gemini, Azure/OpenAI, or AWS Bedrock models such as Claude and Nova). The app queries the LLM with batches of responses to produce summary tables, which are then compared iteratively to output a table with the general topics, subtopics, topic sentiment, and a topic summary. Instructions on use can be found in the README.md file. You can try out examples by clicking on one of the example datasets below. API keys for AWS, Azure/OpenAI, and Gemini services can be entered on the settings page (note that Gemini has a free public API).
137
 
138
  NOTE: Large language models are not 100% accurate and may produce biased or harmful outputs. All outputs from this app **absolutely need to be checked by a human** for harmful content, hallucinations, and accuracy.""")
139
 
 
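As a rough sketch of the batch-and-merge flow described in the text above: responses are sent to the LLM in batches, each batch yields a topic table, and the tables are merged and deduplicated into one output. The helper names below are illustrative only (the real entry points are wrapper_extract_topics_per_column_value, deduplicate_topics and summarise_output_topics imported at the top of app.py).

import pandas as pd

def extract_topics_in_batches(responses, batch_size=5):
    # Toy stand-in for the per-batch LLM call: one summary table per batch.
    tables = []
    for start in range(0, len(responses), batch_size):
        batch = responses[start:start + batch_size]
        tables.append(pd.DataFrame({"General topic": ["Example"] * len(batch),
                                    "Subtopic": batch,
                                    "Sentiment": ["Neutral"] * len(batch)}))
    return pd.concat(tables, ignore_index=True)

def merge_topic_tables(table):
    # Toy stand-in for the iterative comparison/deduplication step.
    return table.drop_duplicates(subset=["General topic", "Subtopic", "Sentiment"])

responses = ["The park needs more benches", "More benches please", "Bins overflow at weekends"]
print(merge_topic_tables(extract_topics_in_batches(responses, batch_size=2)))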
169
 
170
  with gr.Tab(label="All in one topic extraction and summarisation"):
171
  with gr.Row():
172
+ model_source = gr.Dropdown(value = default_model_source, choices = model_sources, label="Large language model family", multiselect=False)
173
+ model_choice = gr.Dropdown(value = default_model_choice, choices = default_source_models, label="Large language model for topic extraction and summarisation", multiselect=False)
174
+
175
+ model_source.change(fn=update_model_choice, inputs=[model_source], outputs=[model_choice])
176
 
177
  with gr.Accordion("Upload xlsx, csv, or parquet file", open = True):
178
  in_data_files.render()
 
313
  with gr.Accordion("Gemini API keys", open = False):
314
  google_api_key_textbox = gr.Textbox(value = GEMINI_API_KEY, label="Enter Gemini API key (only if using Google API models)", lines=1, type="password")
315
 
316
+ with gr.Accordion("Azure/OpenAI Inference", open = False):
317
+ azure_api_key_textbox = gr.Textbox(value = AZURE_OPENAI_API_KEY, label="Enter Azure/OpenAI Inference API key (only if using Azure/OpenAI models)", lines=1, type="password")
318
+ azure_endpoint_textbox = gr.Textbox(value = AZURE_OPENAI_INFERENCE_ENDPOINT, label="Enter Azure/OpenAI Inference endpoint URL (only if using Azure/OpenAI models)", lines=1)
319
 
320
  with gr.Accordion("Hugging Face token for downloading gated models", open = False):
321
  hf_api_key_textbox = gr.Textbox(value = HF_TOKEN, label="Enter Hugging Face API key (only for gated models that need a token to download)", lines=1, type="password")
 
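The two new fields default to values read by get_or_create_env_var in tools/config.py, so they can also be supplied through the environment (or a .env file) before the app starts. A minimal sketch, assuming the same variable names and a placeholder endpoint:

import os

# Must be set before tools.config is imported; the endpoint below is a placeholder.
os.environ["AZURE_OPENAI_API_KEY"] = "<your-key>"
os.environ["AZURE_OPENAI_INFERENCE_ENDPOINT"] = "https://<your-resource>.openai.azure.com/"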
410
  aws_secret_key_textbox,
411
  hf_api_key_textbox,
412
  azure_api_key_textbox,
413
+ azure_endpoint_textbox,
414
  output_folder_state,
415
  logged_content_df,
416
  add_existing_topics_summary_format_textbox],
 
454
  success(deduplicate_topics, inputs=[master_reference_df_state, master_unique_topics_df_state, working_data_file_name_textbox, unique_topics_table_file_name_textbox, in_excel_sheets, merge_sentiment_drop, merge_general_topics_drop, deduplicate_score_threshold, in_data_files, in_colnames, output_folder_state], outputs=[master_reference_df_state, master_unique_topics_df_state, summarisation_input_files, log_files_output, summarised_output_markdown], scroll_to_output=True, api_name="deduplicate_topics")
455
 
456
  # When LLM deduplication button pressed, deduplicate data using LLM
457
+ def deduplicate_topics_llm_wrapper(reference_df, topic_summary_df, reference_table_file_name, unique_topics_table_file_name, model_choice, in_api_key, temperature, in_excel_sheets, merge_sentiment, merge_general_topics, in_data_files, chosen_cols, output_folder, candidate_topics=None, azure_endpoint=""):
458
  model_source = model_name_map[model_choice]["source"]
459
+ return deduplicate_topics_llm(reference_df, topic_summary_df, reference_table_file_name, unique_topics_table_file_name, model_choice, in_api_key, temperature, model_source, None, None, None, None, in_excel_sheets, merge_sentiment, merge_general_topics, in_data_files, chosen_cols, output_folder, candidate_topics, azure_endpoint)
460
 
461
  deduplicate_llm_previous_data_btn.click(load_in_previous_data_files, inputs=[deduplication_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
462
+ success(deduplicate_topics_llm_wrapper, inputs=[master_reference_df_state, master_unique_topics_df_state, working_data_file_name_textbox, unique_topics_table_file_name_textbox, model_choice, google_api_key_textbox, temperature_slide, in_excel_sheets, merge_sentiment_drop, merge_general_topics_drop, in_data_files, in_colnames, output_folder_state, candidate_topics, azure_endpoint_textbox], outputs=[master_reference_df_state, master_unique_topics_df_state, summarisation_input_files, log_files_output, summarised_output_markdown, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number], scroll_to_output=True, api_name="deduplicate_topics_llm").\
463
  success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False, api_name="usage_logs_llm_dedup")
464
 
465
  # When button pressed, summarise previous data
 
467
  success(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
468
  success(load_in_previous_data_files, inputs=[summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
469
  success(sample_reference_table_summaries, inputs=[master_reference_df_state, random_seed], outputs=[summary_reference_table_sample_state, summarised_references_markdown], api_name="sample_summaries").\
470
+ success(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state, context_textbox, aws_access_key_textbox, aws_secret_key_textbox, model_name_map_state, hf_api_key_textbox, azure_endpoint_textbox, logged_content_df], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output, overall_summarisation_input_files, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, output_messages_textbox, logged_content_df], api_name="summarise_topics", show_progress_on=[output_messages_textbox, summary_output_files]).\
471
  success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False).\
472
  then(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_revised_summaries_state, master_unique_topics_df_revised_summaries_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state, produce_structured_summary_radio], outputs=[summary_output_files_xlsx, summary_xlsx_output_files_list])
473
 
474
  # SUMMARISE WHOLE TABLE PAGE
475
  overall_summarise_previous_data_btn.click(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
476
  success(load_in_previous_data_files, inputs=[overall_summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
477
+ success(overall_summary, inputs=[master_unique_topics_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, output_folder_state, in_colnames, context_textbox, aws_access_key_textbox, aws_secret_key_textbox, model_name_map_state, hf_api_key_textbox, azure_endpoint_textbox, logged_content_df], outputs=[overall_summary_output_files, overall_summarised_output_markdown, summarised_output_df, conversation_metadata_textbox, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, output_messages_textbox, logged_content_df], scroll_to_output=True, api_name="overall_summary", show_progress_on=[output_messages_textbox, overall_summary_output_files]).\
478
  success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False).\
479
  then(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_state, master_unique_topics_df_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state, produce_structured_summary_radio], outputs=[overall_summary_output_files_xlsx, summary_xlsx_output_files_list])
480
 
 
521
  aws_secret_key_textbox,
522
  hf_api_key_textbox,
523
  azure_api_key_textbox,
524
+ azure_endpoint_textbox,
525
  output_folder_state,
526
  merge_sentiment_drop,
527
  merge_general_topics_drop,
pyproject.toml CHANGED
@@ -17,8 +17,7 @@ dependencies = [
17
  "tabulate==0.9.0",
18
  "lxml==5.3.0",
19
  "google-genai==1.33.0",
20
- "azure-ai-inference==1.0.0b9",
21
- "azure-core==1.35.0",
22
  "html5lib==1.1",
23
  "beautifulsoup4==4.12.3",
24
  "rapidfuzz==3.13.0",
 
17
  "tabulate==0.9.0",
18
  "lxml==5.3.0",
19
  "google-genai==1.33.0",
20
+ "openai==2.2.0",
 
21
  "html5lib==1.1",
22
  "beautifulsoup4==4.12.3",
23
  "rapidfuzz==3.13.0",
requirements.txt CHANGED
@@ -10,8 +10,7 @@ markdown==3.7
10
  tabulate==0.9.0
11
  lxml==5.3.0
12
  google-genai==1.33.0
13
- azure-ai-inference==1.0.0b9
14
- azure-core==1.35.0
15
  html5lib==1.1
16
  beautifulsoup4==4.12.3
17
  rapidfuzz==3.13.0
 
10
  tabulate==0.9.0
11
  lxml==5.3.0
12
  google-genai==1.33.0
13
+ openai==2.2.0
 
14
  html5lib==1.1
15
  beautifulsoup4==4.12.3
16
  rapidfuzz==3.13.0
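The azure-ai-inference and azure-core pins are replaced by the openai package, which can talk both to Azure OpenAI deployments and to the standard OpenAI endpoint. A minimal sketch of what a construct_azure_client-style helper could look like with openai 2.x; the branching rule and api_version here are assumptions, not the app's actual implementation:

from openai import AzureOpenAI, OpenAI

def build_client(api_key, endpoint=""):
    # Use AzureOpenAI when an Azure endpoint is supplied, otherwise a plain OpenAI client.
    if "azure.com" in endpoint:
        return AzureOpenAI(api_key=api_key, azure_endpoint=endpoint, api_version="2024-10-21")
    return OpenAI(api_key=api_key, base_url=endpoint or None)

# Example call; the model name must match an Azure deployment or OpenAI model (e.g. "gpt-4o-mini"):
# client = build_client(api_key, endpoint)
# reply = client.chat.completions.create(model="gpt-4o-mini",
#                                        messages=[{"role": "user", "content": "Hello"}])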
requirements_cpu.txt CHANGED
@@ -9,8 +9,7 @@ markdown==3.7
9
  tabulate==0.9.0
10
  lxml==5.3.0
11
  google-genai==1.33.0
12
- azure-ai-inference==1.0.0b9
13
- azure-core==1.35.0
14
  html5lib==1.1
15
  beautifulsoup4==4.12.3
16
  rapidfuzz==3.13.0
 
9
  tabulate==0.9.0
10
  lxml==5.3.0
11
  google-genai==1.33.0
12
+ openai==2.2.0
 
13
  html5lib==1.1
14
  beautifulsoup4==4.12.3
15
  rapidfuzz==3.13.0
requirements_gpu.txt CHANGED
@@ -9,8 +9,7 @@ markdown==3.7
9
  tabulate==0.9.0
10
  lxml==5.3.0
11
  google-genai==1.33.0
12
- azure-ai-inference==1.0.0b9
13
- azure-core==1.35.0
14
  html5lib==1.1
15
  beautifulsoup4==4.12.3
16
  rapidfuzz==3.13.0
 
9
  tabulate==0.9.0
10
  lxml==5.3.0
11
  google-genai==1.33.0
12
+ openai==2.2.0
 
13
  html5lib==1.1
14
  beautifulsoup4==4.12.3
15
  rapidfuzz==3.13.0
requirements_no_local.txt CHANGED
@@ -10,8 +10,7 @@ markdown==3.7
10
  tabulate==0.9.0
11
  lxml==5.3.0
12
  google-genai==1.33.0
13
- azure-ai-inference==1.0.0b9
14
- azure-core==1.35.0
15
  html5lib==1.1
16
  beautifulsoup4==4.12.3
17
  rapidfuzz==3.13.0
 
10
  tabulate==0.9.0
11
  lxml==5.3.0
12
  google-genai==1.33.0
13
+ openai==2.2.0
 
14
  html5lib==1.1
15
  beautifulsoup4==4.12.3
16
  rapidfuzz==3.13.0
tools/config.py CHANGED
@@ -2,6 +2,8 @@ import os
2
  import tempfile
3
  import socket
4
  import logging
 
 
5
  from datetime import datetime
6
  from dotenv import load_dotenv
7
 
@@ -217,10 +219,10 @@ RUN_AWS_BEDROCK_MODELS = get_or_create_env_var("RUN_AWS_BEDROCK_MODELS", "1")
217
  RUN_GEMINI_MODELS = get_or_create_env_var("RUN_GEMINI_MODELS", "1")
218
  GEMINI_API_KEY = get_or_create_env_var('GEMINI_API_KEY', '')
219
 
220
- # Azure AI Inference settings
221
- RUN_AZURE_MODELS = get_or_create_env_var("RUN_AZURE_MODELS", "0")
222
- AZURE_API_KEY = get_or_create_env_var('AZURE_API_KEY', '')
223
- AZURE_INFERENCE_ENDPOINT = get_or_create_env_var('AZURE_INFERENCE_ENDPOINT', '')
224
 
225
  # Build up options for models
226
 
@@ -236,27 +238,48 @@ if RUN_LOCAL_MODEL == "1" and CHOSEN_LOCAL_MODEL_TYPE:
236
  model_source.append("Local")
237
 
238
  if RUN_AWS_BEDROCK_MODELS == "1":
239
- model_full_names.extend(["anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-7-sonnet-20250219-v1:0", "amazon.nova-micro-v1:0", "amazon.nova-lite-v1:0", "amazon.nova-pro-v1:0"])
240
- model_short_names.extend(["haiku", "sonnet", "nova_micro", "nova_lite", "nova_pro"])
241
- model_source.extend(["AWS", "AWS", "AWS", "AWS", "AWS"])
 
242
 
243
  if RUN_GEMINI_MODELS == "1":
244
- model_full_names.extend(["gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-2.5-pro"])
 
245
  model_short_names.extend(["gemini_flash_lite_2.5", "gemini_flash_2.5", "gemini_pro"])
246
- model_source.extend(["Gemini", "Gemini", "Gemini"])
247
 
248
- # Register Azure AI models (model names must match your Azure deployments)
249
  if RUN_AZURE_MODELS == "1":
250
- # Example deployments; adjust to the deployments you actually create in Azure
251
- model_full_names.extend(["gpt-5-mini"])
252
- model_short_names.extend(["gpt-5-mini"])
253
- model_source.extend(["Azure"])
 
254
 
255
  model_name_map = {
256
  full: {"short_name": short, "source": source}
257
  for full, short, source in zip(model_full_names, model_short_names, model_source)
258
  }
259
 
260
  #print("model_name_map:", model_name_map)
261
 
262
  # HF token may or may not be needed for downloading models from Hugging Face
@@ -453,4 +476,44 @@ else: OUTPUT_COST_CODES_PATH = 'config/cost_codes.csv'
453
 
454
  ENFORCE_COST_CODES = get_or_create_env_var('ENFORCE_COST_CODES', 'False') # If you have cost codes listed, is it compulsory to choose one before redacting?
455
 
456
- if ENFORCE_COST_CODES == 'True': GET_COST_CODES = 'True'
 
2
  import tempfile
3
  import socket
4
  import logging
5
+ import codecs
6
+ from typing import List
7
  from datetime import datetime
8
  from dotenv import load_dotenv
9
 
 
219
  RUN_GEMINI_MODELS = get_or_create_env_var("RUN_GEMINI_MODELS", "1")
220
  GEMINI_API_KEY = get_or_create_env_var('GEMINI_API_KEY', '')
221
 
222
+ # Azure/OpenAI inference settings
223
+ RUN_AZURE_MODELS = get_or_create_env_var("RUN_AZURE_MODELS", "1")
224
+ AZURE_OPENAI_API_KEY = get_or_create_env_var('AZURE_OPENAI_API_KEY', '')
225
+ AZURE_OPENAI_INFERENCE_ENDPOINT = get_or_create_env_var('AZURE_OPENAI_INFERENCE_ENDPOINT', '')
226
 
227
  # Build up options for models
228
 
 
238
  model_source.append("Local")
239
 
240
  if RUN_AWS_BEDROCK_MODELS == "1":
241
+ amazon_models = ["anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-7-sonnet-20250219-v1:0", "anthropic.claude-sonnet-4-5-20250929-v1:0", "amazon.nova-micro-v1:0", "amazon.nova-lite-v1:0", "amazon.nova-pro-v1:0", "deepseek.v3-v1:0", "openai.gpt-oss-20b-1:0", "openai.gpt-oss-120b-1:0"]
242
+ model_full_names.extend(amazon_models)
243
+ model_short_names.extend(["haiku", "sonnet_3_7", "sonnet_4_5", "nova_micro", "nova_lite", "nova_pro", "deepseek_v3", "gpt_oss_20b_aws", "gpt_oss_120b_aws"])
244
+ model_source.extend(["AWS"] * len(amazon_models))
245
 
246
  if RUN_GEMINI_MODELS == "1":
247
+ gemini_models = ["gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-2.5-pro"]
248
+ model_full_names.extend(gemini_models)
249
  model_short_names.extend(["gemini_flash_lite_2.5", "gemini_flash_2.5", "gemini_pro"])
250
+ model_source.extend(["Gemini"] * len(gemini_models))
251
 
252
+ # Register Azure/OpenAI models (model names must match your Azure/OpenAI deployments)
253
  if RUN_AZURE_MODELS == "1":
254
+ # Example deployments; adjust to the deployments you actually create in Azure/OpenAI
255
+ azure_models = ["gpt-5-mini", "gpt-4o-mini"]
256
+ model_full_names.extend(azure_models)
257
+ model_short_names.extend(["gpt-5-mini", "gpt-4o-mini"])
258
+ model_source.extend(["Azure/OpenAI"] * len(azure_models))
259
 
260
  model_name_map = {
261
  full: {"short_name": short, "source": source}
262
  for full, short, source in zip(model_full_names, model_short_names, model_source)
263
  }
264
 
265
+ if RUN_LOCAL_MODEL == "1": default_model_choice = CHOSEN_LOCAL_MODEL_TYPE
266
+ elif RUN_AWS_FUNCTIONS == "1": default_model_choice = amazon_models[0]
267
+ else: default_model_choice = gemini_models[0]
268
+
269
+ default_model_source = model_name_map[default_model_choice]["source"]
270
+ model_sources = list(set([model_name_map[model]["source"] for model in model_full_names]))
271
+
272
+ def update_model_choice_config(default_model_source, model_name_map):
273
+ # Filter models by source; return the default model for that source plus the full list of matching models
274
+ matching_models = [model_name for model_name, model_info in model_name_map.items()
275
+ if model_info["source"] == default_model_source]
276
+
277
+ output_model = matching_models[0] if matching_models else model_full_names[0]
278
+
279
+ return output_model, matching_models
280
+
281
+ default_model_choice, default_source_models = update_model_choice_config(default_model_source, model_name_map)
282
+
283
  #print("model_name_map:", model_name_map)
284
 
285
  # HF token may or may not be needed for downloading models from Hugging Face
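Given the registrations above, model_name_map keys each full model name to its short name and source, and the new default-model helpers are derived from it. With the default flags (RUN_GEMINI_MODELS and RUN_AZURE_MODELS both "1"), a quick check from inside the project looks like this:

from tools.config import model_name_map, update_model_choice_config

print(model_name_map["gemini-2.5-flash"])   # {'short_name': 'gemini_flash_2.5', 'source': 'Gemini'}
print(model_name_map["gpt-4o-mini"])        # {'short_name': 'gpt-4o-mini', 'source': 'Azure/OpenAI'}

# First model registered for a source, plus every model for that source:
print(update_model_choice_config("Azure/OpenAI", model_name_map))
# -> ('gpt-5-mini', ['gpt-5-mini', 'gpt-4o-mini'])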
 
476
 
477
  ENFORCE_COST_CODES = get_or_create_env_var('ENFORCE_COST_CODES', 'False') # If you have cost codes listed, is it compulsory to choose one before redacting?
478
 
479
+ if ENFORCE_COST_CODES == 'True': GET_COST_CODES = 'True'
480
+
481
+ ###
482
+ # VALIDATE FOLDERS AND CONFIG OPTIONS
483
+ ###
484
+
485
+ def ensure_folder_exists(output_folder:str):
486
+ """Checks if the specified folder exists, creates it if not."""
487
+
488
+ if not os.path.exists(output_folder):
489
+ # Create the folder if it doesn't exist
490
+ os.makedirs(output_folder, exist_ok=True)
491
+ print(f"Created the {output_folder} folder.")
492
+ else:
493
+ pass
494
+ #print(f"The {output_folder} folder already exists.")
495
+
496
+ def _get_env_list(env_var_name: str, strip_strings:bool=True) -> List[str]:
497
+ """Parses a comma-separated environment variable into a list of strings."""
498
+ value = env_var_name[1:-1].strip().replace('\"', '').replace("\'","")
499
+ if not value:
500
+ return []
501
+ # Split by comma and filter out any empty strings that might result from extra commas
502
+ if strip_strings:
503
+ return [s.strip() for s in value.split(',') if s.strip()]
504
+ else:
505
+ return [codecs.decode(s, 'unicode_escape') for s in value.split(',') if s]
506
+
507
+ # Convert string environment variables to string or list
508
+ if SAVE_LOGS_TO_CSV == "True": SAVE_LOGS_TO_CSV = True
509
+ else: SAVE_LOGS_TO_CSV = False
510
+ if SAVE_LOGS_TO_DYNAMODB == "True": SAVE_LOGS_TO_DYNAMODB = True
511
+ else: SAVE_LOGS_TO_DYNAMODB = False
512
+
513
+ if CSV_ACCESS_LOG_HEADERS: CSV_ACCESS_LOG_HEADERS = _get_env_list(CSV_ACCESS_LOG_HEADERS)
514
+ if CSV_FEEDBACK_LOG_HEADERS: CSV_FEEDBACK_LOG_HEADERS = _get_env_list(CSV_FEEDBACK_LOG_HEADERS)
515
+ if CSV_USAGE_LOG_HEADERS: CSV_USAGE_LOG_HEADERS = _get_env_list(CSV_USAGE_LOG_HEADERS)
516
+
517
+ if DYNAMODB_ACCESS_LOG_HEADERS: DYNAMODB_ACCESS_LOG_HEADERS = _get_env_list(DYNAMODB_ACCESS_LOG_HEADERS)
518
+ if DYNAMODB_FEEDBACK_LOG_HEADERS: DYNAMODB_FEEDBACK_LOG_HEADERS = _get_env_list(DYNAMODB_FEEDBACK_LOG_HEADERS)
519
+ if DYNAMODB_USAGE_LOG_HEADERS: DYNAMODB_USAGE_LOG_HEADERS = _get_env_list(DYNAMODB_USAGE_LOG_HEADERS)
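The header variables above are stored as bracketed, quoted lists, which _get_env_list unpacks by stripping the outer brackets and quotes and splitting on commas. A quick illustration with a made-up header list:

from tools.config import _get_env_list

raw = "['session_hash', 'file_name', 'model_choice', 'input_tokens']"  # hypothetical env value
print(_get_env_list(raw))
# -> ['session_hash', 'file_name', 'model_choice', 'input_tokens']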
tools/dedup_summaries.py CHANGED
@@ -11,10 +11,10 @@ from tqdm import tqdm
11
  import os
12
  from tools.llm_api_call import generate_zero_shot_topics_df
13
  from tools.prompts import summarise_topic_descriptions_prompt, summarise_topic_descriptions_system_prompt, system_prompt, summarise_everything_prompt, comprehensive_summary_format_prompt, summarise_everything_system_prompt, comprehensive_summary_format_prompt_by_group, summary_assistant_prefill, llm_deduplication_system_prompt, llm_deduplication_prompt, llm_deduplication_prompt_with_candidates
14
- from tools.llm_funcs import construct_gemini_generative_model, process_requests, ResponseObject, load_model, calculate_tokens_from_metadata, construct_azure_client, get_model, get_tokenizer, get_assistant_model, send_request, construct_gemini_generative_model, construct_azure_client, call_llm_with_markdown_table_checks
15
  from tools.helper_functions import create_topic_summary_df_from_reference_table, load_in_data_file, get_basic_response_data, convert_reference_table_to_pivot_table, wrap_text, clean_column_name, get_file_name_no_ext, create_batch_file_path_details, read_file
16
  from tools.aws_functions import connect_to_bedrock_runtime
17
- from tools.config import OUTPUT_FOLDER, RUN_LOCAL_MODEL, MAX_COMMENT_CHARS, LLM_MAX_NEW_TOKENS, LLM_SEED, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, REASONING_SUFFIX, AZURE_INFERENCE_ENDPOINT, MAX_SPACES_GPU_RUN_TIME, OUTPUT_DEBUG_FILES
18
 
19
  max_tokens = LLM_MAX_NEW_TOKENS
20
  timeout_wait = TIMEOUT_WAIT
@@ -393,7 +393,8 @@ def deduplicate_topics_llm(reference_df:pd.DataFrame,
393
  in_data_files:List[str]=list(),
394
  chosen_cols:List[str]="",
395
  output_folder:str=OUTPUT_FOLDER,
396
- candidate_topics=None
 
397
  ):
398
  '''
399
  Deduplicate topics using LLM semantic understanding to identify and merge similar topics.
@@ -501,7 +502,7 @@ def deduplicate_topics_llm(reference_df:pd.DataFrame,
501
 
502
  # Set up model clients based on model source
503
  if "Gemini" in model_source:
504
- google_client, config = construct_gemini_generative_model(
505
  in_api_key, temperature, model_choice, llm_deduplication_system_prompt,
506
  max_tokens, LLM_SEED
507
  )
@@ -509,13 +510,13 @@ def deduplicate_topics_llm(reference_df:pd.DataFrame,
509
  elif "AWS" in model_source:
510
  if not bedrock_runtime:
511
  bedrock_runtime = boto3.client('bedrock-runtime')
512
- google_client = None
513
  config = None
514
- elif "Azure" in model_source:
515
- google_client, config = construct_azure_client(in_api_key, "")
516
  bedrock_runtime = None
517
  elif "Local" in model_source:
518
- google_client = None
519
  config = None
520
  bedrock_runtime = None
521
  else:
@@ -531,8 +532,8 @@ def deduplicate_topics_llm(reference_df:pd.DataFrame,
531
  conversation_history=conversation_history,
532
  whole_conversation=whole_conversation,
533
  whole_conversation_metadata=whole_conversation_metadata,
534
- google_client=google_client,
535
- google_config=config,
536
  model_choice=model_choice,
537
  temperature=temperature,
538
  reported_batch_no=1,
@@ -758,7 +759,7 @@ def sample_reference_table_summaries(reference_df:pd.DataFrame,
758
 
759
  return sampled_reference_table_df, summarised_references_markdown#, reference_df, topic_summary_df
760
 
761
- def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:float, formatted_summary_prompt:str, summarise_topic_descriptions_system_prompt:str, model_source:str, bedrock_runtime:boto3.Session.client, local_model=list(), tokenizer=list(), assistant_model=list()):
762
  """
763
  Query an LLM to generate a summary of topics based on the provided prompts.
764
 
@@ -780,16 +781,15 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
780
  """
781
  conversation_history = list()
782
  whole_conversation_metadata = list()
783
- google_client = list()
784
- google_config = {}
785
 
786
  # Prepare Gemini models before query
787
  if "Gemini" in model_source:
788
  #print("Using Gemini model:", model_choice)
789
- google_client, config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=system_prompt, max_tokens=max_tokens)
790
- elif "Azure" in model_source:
791
- # Azure client (endpoint from env/config)
792
- google_client, config = construct_azure_client(in_api_key=os.environ.get("AZURE_INFERENCE_CREDENTIAL", ""), endpoint=AZURE_INFERENCE_ENDPOINT)
793
  elif "Local" in model_source:
794
  pass
795
  #print("Using local model: ", model_choice)
@@ -800,7 +800,7 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
800
  whole_conversation = [summarise_topic_descriptions_system_prompt]
801
 
802
  # Process requests to large language model
803
- responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(formatted_summary_prompt, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, bedrock_runtime=bedrock_runtime, model_source=model_source, local_model=local_model, tokenizer=tokenizer, assistant_model=assistant_model, assistant_prefill=summary_assistant_prefill)
804
 
805
  summarised_output = re.sub(r'\n{2,}', '\n', response_text) # Replace multiple line breaks with a single line break
806
  summarised_output = re.sub(r'^\n{1,}', '', summarised_output) # Remove one or more line breaks at the start
@@ -898,6 +898,7 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
898
  aws_secret_key_textbox:str='',
899
  model_name_map:dict=model_name_map,
900
  hf_api_key_textbox:str='',
 
901
  existing_logged_content:list=list(),
902
  output_debug_files:str=output_debug_files,
903
  reasoning_suffix:str=reasoning_suffix,
@@ -1041,7 +1042,7 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
1041
  if "Local" in model_source and reasoning_suffix: formatted_summarise_topic_descriptions_system_prompt = formatted_summarise_topic_descriptions_system_prompt + "\n" + reasoning_suffix
1042
 
1043
  try:
1044
- response, conversation_history, metadata, response_text = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_topic_descriptions_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer, assistant_model=assistant_model)
1045
  summarised_output = response_text
1046
  except Exception as e:
1047
  print("Creating summary failed:", e)
@@ -1183,6 +1184,7 @@ def overall_summary(topic_summary_df:pd.DataFrame,
1183
  aws_secret_key_textbox:str='',
1184
  model_name_map:dict=model_name_map,
1185
  hf_api_key_textbox:str='',
 
1186
  existing_logged_content:list=list(),
1187
  output_debug_files:str=output_debug_files,
1188
  log_output_files:list=list(),
@@ -1313,7 +1315,7 @@ def overall_summary(topic_summary_df:pd.DataFrame,
1313
  if "Local" in model_source and reasoning_suffix: formatted_summarise_everything_system_prompt = formatted_summarise_everything_system_prompt + "\n" + reasoning_suffix
1314
 
1315
  try:
1316
- response, conversation_history, metadata, response_text = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_everything_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer, assistant_model=assistant_model)
1317
  summarised_output_for_df = response_text
1318
  summarised_output = response
1319
  except Exception as e:
 
11
  import os
12
  from tools.llm_api_call import generate_zero_shot_topics_df
13
  from tools.prompts import summarise_topic_descriptions_prompt, summarise_topic_descriptions_system_prompt, system_prompt, summarise_everything_prompt, comprehensive_summary_format_prompt, summarise_everything_system_prompt, comprehensive_summary_format_prompt_by_group, summary_assistant_prefill, llm_deduplication_system_prompt, llm_deduplication_prompt, llm_deduplication_prompt_with_candidates
14
+ from tools.llm_funcs import construct_gemini_generative_model, process_requests, calculate_tokens_from_metadata, construct_azure_client, get_model, get_tokenizer, get_assistant_model, call_llm_with_markdown_table_checks
15
  from tools.helper_functions import create_topic_summary_df_from_reference_table, load_in_data_file, get_basic_response_data, convert_reference_table_to_pivot_table, wrap_text, clean_column_name, get_file_name_no_ext, create_batch_file_path_details, read_file
16
  from tools.aws_functions import connect_to_bedrock_runtime
17
+ from tools.config import OUTPUT_FOLDER, RUN_LOCAL_MODEL, MAX_COMMENT_CHARS, LLM_MAX_NEW_TOKENS, LLM_SEED, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, REASONING_SUFFIX, AZURE_OPENAI_INFERENCE_ENDPOINT, MAX_SPACES_GPU_RUN_TIME, OUTPUT_DEBUG_FILES
18
 
19
  max_tokens = LLM_MAX_NEW_TOKENS
20
  timeout_wait = TIMEOUT_WAIT
 
393
  in_data_files:List[str]=list(),
394
  chosen_cols:List[str]="",
395
  output_folder:str=OUTPUT_FOLDER,
396
+ candidate_topics=None,
397
+ azure_endpoint:str=""
398
  ):
399
  '''
400
  Deduplicate topics using LLM semantic understanding to identify and merge similar topics.
 
502
 
503
  # Set up model clients based on model source
504
  if "Gemini" in model_source:
505
+ client, config = construct_gemini_generative_model(
506
  in_api_key, temperature, model_choice, llm_deduplication_system_prompt,
507
  max_tokens, LLM_SEED
508
  )
 
510
  elif "AWS" in model_source:
511
  if not bedrock_runtime:
512
  bedrock_runtime = boto3.client('bedrock-runtime')
513
+ client = None
514
  config = None
515
+ elif "Azure/OpenAI" in model_source:
516
+ client, config = construct_azure_client(in_api_key, azure_endpoint)
517
  bedrock_runtime = None
518
  elif "Local" in model_source:
519
+ client = None
520
  config = None
521
  bedrock_runtime = None
522
  else:
 
532
  conversation_history=conversation_history,
533
  whole_conversation=whole_conversation,
534
  whole_conversation_metadata=whole_conversation_metadata,
535
+ client=client,
536
+ client_config=config,
537
  model_choice=model_choice,
538
  temperature=temperature,
539
  reported_batch_no=1,
 
759
 
760
  return sampled_reference_table_df, summarised_references_markdown#, reference_df, topic_summary_df
761
 
762
+ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:float, formatted_summary_prompt:str, summarise_topic_descriptions_system_prompt:str, model_source:str, bedrock_runtime:boto3.Session.client, local_model=list(), tokenizer=list(), assistant_model=list(), azure_endpoint:str=""):
763
  """
764
  Query an LLM to generate a summary of topics based on the provided prompts.
765
 
 
781
  """
782
  conversation_history = list()
783
  whole_conversation_metadata = list()
784
+ client = list()
785
+ client_config = {}
786
 
787
  # Prepare Gemini models before query
788
  if "Gemini" in model_source:
789
  #print("Using Gemini model:", model_choice)
790
+ client, config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=system_prompt, max_tokens=max_tokens)
791
+ elif "Azure/OpenAI" in model_source:
792
+ client, config = construct_azure_client(in_api_key=os.environ.get("AZURE_INFERENCE_CREDENTIAL", ""), endpoint=azure_endpoint)
 
793
  elif "Local" in model_source:
794
  pass
795
  #print("Using local model: ", model_choice)
 
800
  whole_conversation = [summarise_topic_descriptions_system_prompt]
801
 
802
  # Process requests to large language model
803
+ responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(formatted_summary_prompt, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, client, client_config, model_choice, temperature, bedrock_runtime=bedrock_runtime, model_source=model_source, local_model=local_model, tokenizer=tokenizer, assistant_model=assistant_model, assistant_prefill=summary_assistant_prefill)
804
 
805
  summarised_output = re.sub(r'\n{2,}', '\n', response_text) # Replace multiple line breaks with a single line break
806
  summarised_output = re.sub(r'^\n{1,}', '', summarised_output) # Remove one or more line breaks at the start
 
898
  aws_secret_key_textbox:str='',
899
  model_name_map:dict=model_name_map,
900
  hf_api_key_textbox:str='',
901
+ azure_endpoint_textbox:str='',
902
  existing_logged_content:list=list(),
903
  output_debug_files:str=output_debug_files,
904
  reasoning_suffix:str=reasoning_suffix,
 
1042
  if "Local" in model_source and reasoning_suffix: formatted_summarise_topic_descriptions_system_prompt = formatted_summarise_topic_descriptions_system_prompt + "\n" + reasoning_suffix
1043
 
1044
  try:
1045
+ response, conversation_history, metadata, response_text = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_topic_descriptions_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer, assistant_model=assistant_model, azure_endpoint=azure_endpoint_textbox)
1046
  summarised_output = response_text
1047
  except Exception as e:
1048
  print("Creating summary failed:", e)
 
1184
  aws_secret_key_textbox:str='',
1185
  model_name_map:dict=model_name_map,
1186
  hf_api_key_textbox:str='',
1187
+ azure_endpoint_textbox:str='',
1188
  existing_logged_content:list=list(),
1189
  output_debug_files:str=output_debug_files,
1190
  log_output_files:list=list(),
 
1315
  if "Local" in model_source and reasoning_suffix: formatted_summarise_everything_system_prompt = formatted_summarise_everything_system_prompt + "\n" + reasoning_suffix
1316
 
1317
  try:
1318
+ response, conversation_history, metadata, response_text = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_everything_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer, assistant_model=assistant_model, azure_endpoint=azure_endpoint_textbox)
1319
  summarised_output_for_df = response_text
1320
  summarised_output = response
1321
  except Exception as e:
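The commit message mentions compatibility with thinking models; on Bedrock, Anthropic's extended thinking is passed through additionalModelRequestFields on the Converse API. A minimal sketch of such a request against one of the newly registered models — the use of converse and the budget value are assumptions about how the bedrock_runtime client created above might be driven, not the app's exact call path:

import boto3

bedrock_runtime = boto3.client("bedrock-runtime")

response = bedrock_runtime.converse(
    modelId="anthropic.claude-sonnet-4-5-20250929-v1:0",
    messages=[{"role": "user", "content": [{"text": "Summarise the main topics in these responses."}]}],
    inferenceConfig={"maxTokens": 4096},
    additionalModelRequestFields={"thinking": {"type": "enabled", "budget_tokens": 1024}},
)
for block in response["output"]["message"]["content"]:
    if "text" in block:
        print(block["text"])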
tools/helper_functions.py CHANGED
@@ -6,8 +6,9 @@ import pandas as pd
6
  import numpy as np
7
  from typing import List
8
  import math
 
9
  from botocore.exceptions import ClientError
10
- from tools.config import OUTPUT_FOLDER, INPUT_FOLDER, SESSION_OUTPUT_FOLDER, CUSTOM_HEADER, CUSTOM_HEADER_VALUE, AWS_USER_POOL_ID, MAXIMUM_ZERO_SHOT_TOPICS
11
 
12
  def empty_output_vars_extract_topics():
13
  # Empty output objects before processing a new file
@@ -745,8 +746,6 @@ def enforce_cost_codes(enforce_cost_code_textbox:str, cost_code_choice:str, cost
745
  raise Exception("Selected cost code not found in list. Please contact Finance if you cannot find the correct cost code from the given list of suggestions.")
746
  return
747
 
748
- import codecs
749
-
750
  def _get_env_list(env_var_name: str, strip_strings:bool=True) -> List[str]:
751
  """Parses a comma-separated environment variable into a list of strings."""
752
  value = env_var_name[1:-1].strip().replace('\"', '').replace("\'","")
@@ -898,4 +897,13 @@ def generate_zero_shot_topics_df(zero_shot_topics:pd.DataFrame,
898
  "Description": zero_shot_topics_description_list
899
  })
900
 
901
- return zero_shot_topics_df
 
6
  import numpy as np
7
  from typing import List
8
  import math
9
+ import codecs
10
  from botocore.exceptions import ClientError
11
+ from tools.config import OUTPUT_FOLDER, INPUT_FOLDER, SESSION_OUTPUT_FOLDER, CUSTOM_HEADER, CUSTOM_HEADER_VALUE, AWS_USER_POOL_ID, MAXIMUM_ZERO_SHOT_TOPICS, model_name_map, model_full_names
12
 
13
  def empty_output_vars_extract_topics():
14
  # Empty output objects before processing a new file
 
746
  raise Exception("Selected cost code not found in list. Please contact Finance if you cannot find the correct cost code from the given list of suggestions.")
747
  return
748
 
 
 
749
  def _get_env_list(env_var_name: str, strip_strings:bool=True) -> List[str]:
750
  """Parses a comma-separated environment variable into a list of strings."""
751
  value = env_var_name[1:-1].strip().replace('\"', '').replace("\'","")
 
897
  "Description": zero_shot_topics_description_list
898
  })
899
 
900
+ return zero_shot_topics_df
901
+
902
+ def update_model_choice(model_source):
903
+ # Filter models by source and return an updated Dropdown with the first matching model selected
904
+ matching_models = [model_name for model_name, model_info in model_name_map.items()
905
+ if model_info["source"] == model_source]
906
+
907
+ output_model = matching_models[0] if matching_models else model_full_names[0]
908
+
909
+ return gr.Dropdown(value = output_model, choices = matching_models, label="Large language model for topic extraction and summarisation", multiselect=False)
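Note: a minimal sketch of how update_model_choice could be wired into the Gradio UI; the component names (model_source_radio, model_choice_dropdown) and the choices list are illustrative assumptions, not taken from app.py.

import gradio as gr
from tools.helper_functions import update_model_choice

with gr.Blocks() as demo:
    # Hypothetical components for illustration only
    model_source_radio = gr.Radio(choices=["Gemini", "AWS", "Azure/OpenAI", "Local"], value="Gemini", label="Model source")
    model_choice_dropdown = gr.Dropdown(label="Large language model for topic extraction and summarisation")
    # Repopulate the model dropdown whenever the source changes
    model_source_radio.change(update_model_choice, inputs=model_source_radio, outputs=model_choice_dropdown)
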
tools/llm_api_call.py CHANGED
@@ -18,7 +18,7 @@ GradioFileData = gr.FileData
18
  from tools.prompts import initial_table_prompt, initial_table_system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, force_existing_topics_prompt, allow_new_topics_prompt, force_single_topic_prompt, add_existing_topics_assistant_prefill, initial_table_assistant_prefill, structured_summary_prompt, default_response_reference_format, negative_neutral_positive_sentiment_prompt, negative_or_positive_sentiment_prompt, default_sentiment_prompt
19
  from tools.helper_functions import read_file, put_columns_in_df, wrap_text, initial_clean, load_in_data_file, load_in_file, create_topic_summary_df_from_reference_table, convert_reference_table_to_pivot_table, get_basic_response_data, clean_column_name, load_in_previous_data_files, create_batch_file_path_details, move_overall_summary_output_files_to_front_page, generate_zero_shot_topics_df
20
  from tools.llm_funcs import ResponseObject, construct_gemini_generative_model, call_llm_with_markdown_table_checks, create_missing_references_df, calculate_tokens_from_metadata, construct_azure_client, get_model, get_tokenizer, get_assistant_model
21
- from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, LLM_MAX_NEW_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS, REASONING_SUFFIX, AZURE_INFERENCE_ENDPOINT, MAX_ROWS, MAXIMUM_ZERO_SHOT_TOPICS, MAX_SPACES_GPU_RUN_TIME, OUTPUT_DEBUG_FILES
22
  from tools.aws_functions import connect_to_bedrock_runtime
23
  from tools.dedup_summaries import sample_reference_table_summaries, summarise_output_topics, deduplicate_topics, overall_summary, process_debug_output_iteration
24
  from tools.combine_sheets_into_xlsx import collect_output_csvs_and_create_excel_output
@@ -620,6 +620,7 @@ def extract_topics(in_data_file: GradioFileData,
620
  aws_secret_key_textbox:str='',
621
  hf_api_key_textbox:str='',
622
  azure_api_key_textbox:str='',
 
623
  max_tokens:int=max_tokens,
624
  model_name_map:dict=model_name_map,
625
  existing_logged_content:list=list(),
@@ -635,7 +636,7 @@ def extract_topics(in_data_file: GradioFileData,
635
  progress=Progress(track_tqdm=False)):
636
 
637
  '''
638
- Query an LLM (local, (Gemma/GPT-OSS if local, Gemini, AWS Bedrock or Azure AI Inference) with up to three prompts about a table of open text data. Up to 'batch_size' rows will be queried at a time.
639
 
640
  Parameters:
641
  - in_data_file (gr.File): Gradio file object containing input data
@@ -693,8 +694,8 @@ def extract_topics(in_data_file: GradioFileData,
693
 
694
  tic = time.perf_counter()
695
 
696
- google_client = list()
697
- google_config = {}
698
  final_time = 0.0
699
  whole_conversation_metadata = list()
700
  is_error = False
@@ -822,13 +823,13 @@ def extract_topics(in_data_file: GradioFileData,
822
  # Prepare clients before query
823
  if "Gemini" in model_source:
824
  #print("Using Gemini model:", model_choice)
825
- google_client, google_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_system_prompt, max_tokens=max_tokens)
826
- elif "Azure" in model_source:
827
- #print("Using Azure AI Inference model:", model_choice)
828
  # If provided, set env for downstream calls too
829
  if azure_api_key_textbox:
830
  os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
831
- google_client, google_config = construct_azure_client(in_api_key=azure_api_key_textbox, endpoint=AZURE_INFERENCE_ENDPOINT)
832
  elif "anthropic.claude" in model_choice:
833
  #print("Using AWS Bedrock model:", model_choice)
834
  pass
@@ -949,7 +950,7 @@ def extract_topics(in_data_file: GradioFileData,
949
  whole_conversation = list()
950
 
951
  # Process requests to large language model
952
- responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, tokenizer, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill, master = True)
953
 
954
  # Return output tables
955
  topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, new_topic_df, new_reference_df, new_topic_summary_df, master_batch_out_file_part, is_error = write_llm_output_and_logs(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, model_name_map, group_name, produce_structured_summary_radio, first_run=False, output_folder=output_folder)
@@ -1011,12 +1012,12 @@ def extract_topics(in_data_file: GradioFileData,
1011
  # Prepare Gemini models before query
1012
  if model_source == "Gemini":
1013
  print("Using Gemini model:", model_choice)
1014
- google_client, google_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_system_prompt, max_tokens=max_tokens)
1015
- elif model_source == "Azure":
1016
- print("Using Azure AI Inference model:", model_choice)
1017
  if azure_api_key_textbox:
1018
  os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
1019
- google_client, google_config = construct_azure_client(in_api_key=azure_api_key_textbox, endpoint=AZURE_INFERENCE_ENDPOINT)
1020
  elif model_choice == CHOSEN_LOCAL_MODEL_TYPE:
1021
  pass
1022
  #print("Using local model:", model_choice)
@@ -1038,7 +1039,7 @@ def extract_topics(in_data_file: GradioFileData,
1038
 
1039
  whole_conversation = list()
1040
 
1041
- responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, google_client, google_config, model_choice, temperature, reported_batch_no, local_model, tokenizer,bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)
1042
 
1043
  topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, topic_table_df, reference_df, new_topic_summary_df, batch_file_path_details, is_error = write_llm_output_and_logs(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, model_name_map, group_name, produce_structured_summary_radio, first_run=True, output_folder=output_folder)
1044
 
@@ -1243,6 +1244,7 @@ def wrapper_extract_topics_per_column_value(
1243
  aws_secret_key_textbox:str="",
1244
  hf_api_key_textbox:str="",
1245
  azure_api_key_textbox:str="",
 
1246
  output_folder: str = OUTPUT_FOLDER,
1247
  existing_logged_content:list=list(),
1248
  additional_instructions_summary_format:str="",
@@ -1296,7 +1298,7 @@ def wrapper_extract_topics_per_column_value(
1296
  :param aws_access_key_textbox: AWS access key for Bedrock.
1297
  :param aws_secret_key_textbox: AWS secret key for Bedrock.
1298
  :param hf_api_key_textbox: Hugging Face API key for local models.
1299
- :param azure_api_key_textbox: Azure API key for Azure AI Inference.
1300
  :param output_folder: The folder where output files will be saved.
1301
  :param existing_logged_content: A list of existing logged content.
1302
  :param force_single_topic_prompt: Prompt for forcing a single topic.
@@ -1475,6 +1477,7 @@ def wrapper_extract_topics_per_column_value(
1475
  aws_secret_key_textbox=aws_secret_key_textbox,
1476
  hf_api_key_textbox=hf_api_key_textbox,
1477
  azure_api_key_textbox=azure_api_key_textbox,
 
1478
  max_tokens=max_tokens,
1479
  model_name_map=model_name_map,
1480
  max_time_for_loop=max_time_for_loop,
@@ -1736,6 +1739,7 @@ def all_in_one_pipeline(
1736
  aws_secret_key_text: str,
1737
  hf_api_key_text: str,
1738
  azure_api_key_text: str,
 
1739
  output_folder: str = OUTPUT_FOLDER,
1740
  merge_sentiment: str = "No",
1741
  merge_general_topics: str = "Yes",
@@ -1790,7 +1794,7 @@ def all_in_one_pipeline(
1790
  aws_access_key_text (str): AWS access key.
1791
  aws_secret_key_text (str): AWS secret key.
1792
  hf_api_key_text (str): Hugging Face API key.
1793
- azure_api_key_text (str): Azure API key.
1794
  output_folder (str, optional): Folder to save output files. Defaults to OUTPUT_FOLDER.
1795
  merge_sentiment (str, optional): Whether to merge sentiment. Defaults to "No".
1796
  merge_general_topics (str, optional): Whether to merge general topics. Defaults to "Yes".
@@ -1884,6 +1888,7 @@ def all_in_one_pipeline(
1884
  aws_secret_key_textbox=aws_secret_key_text,
1885
  hf_api_key_textbox=hf_api_key_text,
1886
  azure_api_key_textbox=azure_api_key_text,
 
1887
  output_folder=output_folder,
1888
  existing_logged_content=existing_logged_content,
1889
  model_name_map=model_name_map_state,
 
18
  from tools.prompts import initial_table_prompt, initial_table_system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, force_existing_topics_prompt, allow_new_topics_prompt, force_single_topic_prompt, add_existing_topics_assistant_prefill, initial_table_assistant_prefill, structured_summary_prompt, default_response_reference_format, negative_neutral_positive_sentiment_prompt, negative_or_positive_sentiment_prompt, default_sentiment_prompt
19
  from tools.helper_functions import read_file, put_columns_in_df, wrap_text, initial_clean, load_in_data_file, load_in_file, create_topic_summary_df_from_reference_table, convert_reference_table_to_pivot_table, get_basic_response_data, clean_column_name, load_in_previous_data_files, create_batch_file_path_details, move_overall_summary_output_files_to_front_page, generate_zero_shot_topics_df
20
  from tools.llm_funcs import ResponseObject, construct_gemini_generative_model, call_llm_with_markdown_table_checks, create_missing_references_df, calculate_tokens_from_metadata, construct_azure_client, get_model, get_tokenizer, get_assistant_model
21
+ from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, LLM_MAX_NEW_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS, REASONING_SUFFIX, AZURE_OPENAI_INFERENCE_ENDPOINT, MAX_ROWS, MAXIMUM_ZERO_SHOT_TOPICS, MAX_SPACES_GPU_RUN_TIME, OUTPUT_DEBUG_FILES
22
  from tools.aws_functions import connect_to_bedrock_runtime
23
  from tools.dedup_summaries import sample_reference_table_summaries, summarise_output_topics, deduplicate_topics, overall_summary, process_debug_output_iteration
24
  from tools.combine_sheets_into_xlsx import collect_output_csvs_and_create_excel_output
 
620
  aws_secret_key_textbox:str='',
621
  hf_api_key_textbox:str='',
622
  azure_api_key_textbox:str='',
623
+ azure_endpoint_textbox:str='',
624
  max_tokens:int=max_tokens,
625
  model_name_map:dict=model_name_map,
626
  existing_logged_content:list=list(),
 
636
  progress=Progress(track_tqdm=False)):
637
 
638
  '''
639
+ Query an LLM (Gemma/GPT-OSS if local, Gemini, AWS Bedrock, or Azure/OpenAI AI Inference) with up to three prompts about a table of open text data. Up to 'batch_size' rows will be queried at a time.
640
 
641
  Parameters:
642
  - in_data_file (gr.File): Gradio file object containing input data
 
694
 
695
  tic = time.perf_counter()
696
 
697
+ client = list()
698
+ client_config = {}
699
  final_time = 0.0
700
  whole_conversation_metadata = list()
701
  is_error = False
 
823
  # Prepare clients before query
824
  if "Gemini" in model_source:
825
  #print("Using Gemini model:", model_choice)
826
+ client, client_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_system_prompt, max_tokens=max_tokens)
827
+ elif "Azure/OpenAI" in model_source:
828
+ #print("Using Azure/OpenAI AI Inference model:", model_choice)
829
  # If provided, set env for downstream calls too
830
  if azure_api_key_textbox:
831
  os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
832
+ client, client_config = construct_azure_client(in_api_key=azure_api_key_textbox, endpoint=azure_endpoint_textbox)
833
  elif "anthropic.claude" in model_choice:
834
  #print("Using AWS Bedrock model:", model_choice)
835
  pass
 
950
  whole_conversation = list()
951
 
952
  # Process requests to large language model
953
+ responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, client, client_config, model_choice, temperature, reported_batch_no, local_model, tokenizer, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill, master = True)
954
 
955
  # Return output tables
956
  topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, new_topic_df, new_reference_df, new_topic_summary_df, master_batch_out_file_part, is_error = write_llm_output_and_logs(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, model_name_map, group_name, produce_structured_summary_radio, first_run=False, output_folder=output_folder)
 
1012
  # Prepare Gemini models before query
1013
  if model_source == "Gemini":
1014
  print("Using Gemini model:", model_choice)
1015
+ client, client_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_system_prompt, max_tokens=max_tokens)
1016
+ elif model_source == "Azure/OpenAI":
1017
+ print("Using Azure/OpenAI AI Inference model:", model_choice)
1018
  if azure_api_key_textbox:
1019
  os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
1020
+ client, client_config = construct_azure_client(in_api_key=azure_api_key_textbox, endpoint=azure_endpoint_textbox)
1021
  elif model_choice == CHOSEN_LOCAL_MODEL_TYPE:
1022
  pass
1023
  #print("Using local model:", model_choice)
 
1039
 
1040
  whole_conversation = list()
1041
 
1042
+ responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, client, client_config, model_choice, temperature, reported_batch_no, local_model, tokenizer,bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)
1043
 
1044
  topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, topic_table_df, reference_df, new_topic_summary_df, batch_file_path_details, is_error = write_llm_output_and_logs(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, model_name_map, group_name, produce_structured_summary_radio, first_run=True, output_folder=output_folder)
1045
 
 
1244
  aws_secret_key_textbox:str="",
1245
  hf_api_key_textbox:str="",
1246
  azure_api_key_textbox:str="",
1247
+ azure_endpoint_textbox:str="",
1248
  output_folder: str = OUTPUT_FOLDER,
1249
  existing_logged_content:list=list(),
1250
  additional_instructions_summary_format:str="",
 
1298
  :param aws_access_key_textbox: AWS access key for Bedrock.
1299
  :param aws_secret_key_textbox: AWS secret key for Bedrock.
1300
  :param hf_api_key_textbox: Hugging Face API key for local models.
1301
+ :param azure_api_key_textbox: Azure/OpenAI API key for Azure/OpenAI AI Inference.
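+ :param azure_endpoint_textbox: Azure/OpenAI inference endpoint used when an Azure/OpenAI model is selected.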
1302
  :param output_folder: The folder where output files will be saved.
1303
  :param existing_logged_content: A list of existing logged content.
1304
  :param force_single_topic_prompt: Prompt for forcing a single topic.
 
1477
  aws_secret_key_textbox=aws_secret_key_textbox,
1478
  hf_api_key_textbox=hf_api_key_textbox,
1479
  azure_api_key_textbox=azure_api_key_textbox,
1480
+ azure_endpoint_textbox=azure_endpoint_textbox,
1481
  max_tokens=max_tokens,
1482
  model_name_map=model_name_map,
1483
  max_time_for_loop=max_time_for_loop,
 
1739
  aws_secret_key_text: str,
1740
  hf_api_key_text: str,
1741
  azure_api_key_text: str,
1742
+ azure_endpoint_text: str,
1743
  output_folder: str = OUTPUT_FOLDER,
1744
  merge_sentiment: str = "No",
1745
  merge_general_topics: str = "Yes",
 
1794
  aws_access_key_text (str): AWS access key.
1795
  aws_secret_key_text (str): AWS secret key.
1796
  hf_api_key_text (str): Hugging Face API key.
1797
+ azure_api_key_text (str): Azure/OpenAI API key.
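+ azure_endpoint_text (str): Azure/OpenAI inference endpoint.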
1798
  output_folder (str, optional): Folder to save output files. Defaults to OUTPUT_FOLDER.
1799
  merge_sentiment (str, optional): Whether to merge sentiment. Defaults to "No".
1800
  merge_general_topics (str, optional): Whether to merge general topics. Defaults to "Yes".
 
1888
  aws_secret_key_textbox=aws_secret_key_text,
1889
  hf_api_key_textbox=hf_api_key_text,
1890
  azure_api_key_textbox=azure_api_key_text,
1891
+ azure_endpoint_textbox=azure_endpoint_text,
1892
  output_folder=output_folder,
1893
  existing_logged_content=existing_logged_content,
1894
  model_name_map=model_name_map_state,
tools/llm_funcs.py CHANGED
@@ -10,10 +10,7 @@ from typing import List, Tuple, TypeVar
10
  from google import genai as ai
11
  from google.genai import types
12
  from gradio import Progress
13
-
14
- from azure.ai.inference import ChatCompletionsClient
15
- from azure.core.credentials import AzureKeyCredential
16
- from azure.ai.inference.models import SystemMessage, UserMessage
17
 
18
  model_type = None # global variable setup
19
  full_text = "" # Define dummy source text (full text) just to enable highlight function to load
@@ -674,27 +671,31 @@ def construct_gemini_generative_model(in_api_key: str, temperature: float, model
674
 
675
  def construct_azure_client(in_api_key: str, endpoint: str) -> Tuple[object, dict]:
676
  """
677
- Constructs a ChatCompletionsClient for Azure AI Inference.
678
  """
679
  try:
680
  key = None
681
  if in_api_key:
682
  key = in_api_key
683
- elif os.environ.get("AZURE_INFERENCE_CREDENTIAL"):
684
- key = os.environ["AZURE_INFERENCE_CREDENTIAL"]
685
- elif os.environ.get("AZURE_API_KEY"):
686
- key = os.environ["AZURE_API_KEY"]
687
  if not key:
688
- raise Warning("No Azure API key found.")
689
 
690
  if not endpoint:
691
- endpoint = os.environ.get("AZURE_INFERENCE_ENDPOINT", "")
692
  if not endpoint:
693
- raise Warning("No Azure inference endpoint found.")
694
- client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(key))
695
- return client, {}

696
  except Exception as e:
697
- print("Error constructing Azure ChatCompletions client:", e)
698
  raise
699
 
700
  def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice:str, bedrock_runtime:boto3.Session.client, assistant_prefill:str="") -> ResponseObject:
@@ -756,7 +757,15 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
756
  )
757
 
758
  output_message = api_response['output']['message']
759
- text = assistant_prefill + output_message['content'][0]['text']

760
 
761
  # The usage statistics are neatly provided in the 'usage' key.
762
  usage = api_response['usage']
@@ -803,9 +812,6 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP
803
  {"role": "system", "content": system_prompt},
804
  {"role": "user", "content": prompt}
805
  ]
806
- #print("Conversation:", conversation)
807
- #import pprint
808
- #pprint.pprint(conversation)
809
 
810
  # 2. Apply the chat template
811
  # This function formats the conversation into the exact string Gemma 3 expects.
@@ -820,9 +826,6 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP
820
  ).to("cuda")
821
  except Exception as e:
822
  print("Error applying chat template:", e)
823
- print("Conversation type:", type(conversation))
824
- for turn in conversation:
825
- print("Turn type:", type(turn), "Content type:", type(turn.get("content")))
826
  raise
827
 
828
  # Map LlamaCPP parameters to transformers parameters
@@ -850,7 +853,7 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP
850
 
851
  # Use speculative decoding if assistant model is available
852
  if speculative_decoding and assistant_model is not None:
853
- print("Using speculative decoding with assistant model")
854
  outputs = model.generate(
855
  input_ids,
856
  assistant_model=assistant_model,
@@ -858,7 +861,7 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP
858
  streamer = streamer
859
  )
860
  else:
861
- print("Generating without speculative decoding")
862
  outputs = model.generate(
863
  input_ids,
864
  **generation_kwargs,
@@ -868,11 +871,9 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP
868
  end_time = time.time()
869
 
870
  # --- Decode and Display Results ---
871
- #generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
872
- # To get only the model's reply, we can decode just the newly generated tokens
873
  new_tokens = outputs[0][input_ids.shape[-1]:]
874
  assistant_reply = tokenizer.decode(new_tokens, skip_special_tokens=True)
875
- #print("Assistant reply:", assistant_reply)
876
 
877
  num_input_tokens = input_ids.shape[-1] # This gets the sequence length (number of tokens)
878
  num_generated_tokens = len(new_tokens)
@@ -887,12 +888,32 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP
887
  return assistant_reply, num_input_tokens, num_generated_tokens
888
 
889
  # Function to send a request and update history
890
- def send_request(prompt: str, conversation_history: List[dict], google_client: ai.Client, config: types.GenerateContentConfig, model_choice: str, system_prompt: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, local_model= list(), tokenizer=None, assistant_model=None, assistant_prefill = "", progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
891
- """
892
- This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
893
- It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
894
- If the model choice is specific to AWS Claude, it calls the `call_aws_claude` function; otherwise, it uses the `client.models.generate_content` method.
895
- The function returns the response text and the updated conversation history.

896
  """
897
  # Constructing the full prompt from the conversation history
898
  full_prompt = "Conversation history:\n"
@@ -920,7 +941,7 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
920
  try:
921
  print("Calling Gemini model, attempt", i + 1)
922
 
923
- response = google_client.models.generate_content(model=model_choice, contents=full_prompt, config=config)
924
 
925
  #print("Successful call to Gemini model.")
926
  break
@@ -948,18 +969,29 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
948
 
949
  if i == number_of_api_retry_attempts:
950
  return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens
951
- elif "Azure" in model_source:
952
  for i in progress_bar:
953
  try:
954
- print("Calling Azure AI Inference model, attempt", i + 1)
955
- # Use structured messages for Azure
956
- response_raw = google_client.complete(
957
- messages=[
958
- SystemMessage(content=system_prompt),
959
- UserMessage(content=prompt),
960
- ],
961
- model=model_choice

962
  )
 
963
  response_text = response_raw.choices[0].message.content
964
  usage = getattr(response_raw, "usage", None)
965
  input_tokens = 0
@@ -973,7 +1005,7 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
973
  )
974
  break
975
  except Exception as e:
976
- print("Call to Azure model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
977
  time.sleep(timeout_wait)
978
  if i == number_of_api_retry_attempts:
979
  return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens
@@ -993,7 +1025,6 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
993
  response, num_transformer_input_tokens, num_transformer_generated_tokens = call_transformers_model(prompt, system_prompt, gen_config, model=local_model, tokenizer=tokenizer, assistant_model=assistant_model)
994
  response_text = response
995
 
996
- #print("Successful call to local model.")
997
  break
998
  except Exception as e:
999
  # If fails, try again after X seconds in case there is a throttle limit
@@ -1035,7 +1066,7 @@ system_prompt: str,
1035
  conversation_history: List[dict],
1036
  whole_conversation: List[str],
1037
  whole_conversation_metadata: List[str],
1038
- google_client: ai.Client,
1039
  config: types.GenerateContentConfig,
1040
  model_choice: str,
1041
  temperature: float,
@@ -1056,7 +1087,7 @@ assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List
1056
  conversation_history (List[dict]): The history of the conversation.
1057
  whole_conversation (List[str]): The complete conversation including prompts and responses.
1058
  whole_conversation_metadata (List[str]): Metadata about the whole conversation.
1059
- google_client (object): The google_client to use for processing the prompts.
1060
  config (dict): Configuration for the model.
1061
  model_choice (str): The choice of model to use.
1062
  temperature (float): The temperature parameter for the model.
@@ -1077,22 +1108,15 @@ assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List
1077
 
1078
  for prompt in prompts:
1079
 
1080
- response, conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens = send_request(prompt, conversation_history, google_client=google_client, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model, tokenizer=tokenizer, assistant_model=assistant_model, assistant_prefill=assistant_prefill, bedrock_runtime=bedrock_runtime, model_source=model_source)
1081
 
1082
  responses.append(response)
1083
  whole_conversation.append(system_prompt)
1084
  whole_conversation.append(prompt)
1085
  whole_conversation.append(response_text)
1086
 
1087
- # Create conversation metadata
1088
- # if master == False:
1089
- # whole_conversation_metadata.append(f"Batch {batch_no}:")
1090
- # else:
1091
- # #whole_conversation_metadata.append(f"Query summary metadata:")
1092
-
1093
  whole_conversation_metadata.append(f"Batch {batch_no}:")
1094
 
1095
- # if not isinstance(response, str):
1096
  try:
1097
  if "AWS" in model_source:
1098
  output_tokens = response.usage_metadata.get('outputTokens', 0)
@@ -1102,7 +1126,7 @@ assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List
1102
  output_tokens = response.usage_metadata.candidates_token_count
1103
  input_tokens = response.usage_metadata.prompt_token_count
1104
 
1105
- elif "Azure" in model_source:
1106
  input_tokens = response.usage_metadata.get('inputTokens', 0)
1107
  output_tokens = response.usage_metadata.get('outputTokens', 0)
1108
 
@@ -1123,9 +1147,6 @@ assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List
1123
 
1124
  except KeyError as e:
1125
  print(f"Key error: {e} - Check the structure of response.usage_metadata")
1126
- # else:
1127
- # print("Response is a string object.")
1128
- # whole_conversation_metadata.append("Length prompt: " + str(len(prompt)) + ". Length response: " + str(len(response)))
1129
 
1130
  return responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text
1131
 
@@ -1134,8 +1155,8 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
1134
  conversation_history: List[dict],
1135
  whole_conversation: List[str],
1136
  whole_conversation_metadata: List[str],
1137
- google_client: ai.Client,
1138
- google_config: types.GenerateContentConfig,
1139
  model_choice: str,
1140
  temperature: float,
1141
  reported_batch_no: int,
@@ -1157,8 +1178,8 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
1157
  - conversation_history (List[dict]): The history of the conversation.
1158
  - whole_conversation (List[str]): The complete conversation including prompts and responses.
1159
  - whole_conversation_metadata (List[str]): Metadata about the whole conversation.
1160
- - google_client (ai.Client): The Google client object for running Gemini API calls.
1161
- - google_config (types.GenerateContentConfig): Configuration for the model.
1162
  - model_choice (str): The choice of model to use.
1163
  - temperature (float): The temperature parameter for the model.
1164
  - reported_batch_no (int): The reported batch number.
@@ -1179,13 +1200,13 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
1179
  call_temperature = temperature # This is correct now with the fixed parameter name
1180
 
1181
  # Update Gemini config with the new temperature settings
1182
- google_config = types.GenerateContentConfig(temperature=call_temperature, max_output_tokens=max_tokens, seed=random_seed)
1183
 
1184
  for attempt in range(MAX_OUTPUT_VALIDATION_ATTEMPTS):
1185
  # Process requests to large language model
1186
  responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(
1187
  batch_prompts, system_prompt, conversation_history, whole_conversation,
1188
- whole_conversation_metadata, google_client, google_config, model_choice,
1189
  call_temperature, bedrock_runtime, model_source, reported_batch_no, local_model, tokenizer=tokenizer, master=master, assistant_prefill=assistant_prefill
1190
  )
1191
 
 
10
  from google import genai as ai
11
  from google.genai import types
12
  from gradio import Progress
13
+ from openai import OpenAI

14
 
15
  model_type = None # global variable setup
16
  full_text = "" # Define dummy source text (full text) just to enable highlight function to load
 
671
 
672
  def construct_azure_client(in_api_key: str, endpoint: str) -> Tuple[object, dict]:
673
  """
674
+ Constructs an OpenAI client for Azure/OpenAI AI Inference.
675
  """
676
  try:
677
  key = None
678
  if in_api_key:
679
  key = in_api_key
680
+ elif os.environ.get("AZURE_OPENAI_API_KEY"):
681
+ key = os.environ["AZURE_OPENAI_API_KEY"]
 
 
682
  if not key:
683
+ raise Warning("No Azure/OpenAI API key found.")
684
 
685
  if not endpoint:
686
+ endpoint = os.environ.get("AZURE_OPENAI_INFERENCE_ENDPOINT", "")
687
  if not endpoint:
688
+ raise Warning("No Azure/OpenAI inference endpoint found.")
689
+
690
+ # Use the provided endpoint instead of a hardcoded value
691
+ client = OpenAI(
692
+ api_key=key,
693
+ base_url=f"{endpoint}",
694
+ )
695
+
696
+ return client, dict()
697
  except Exception as e:
698
+ print("Error constructing Azure/OpenAI client:", e)
699
  raise
700
 
701
  def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice:str, bedrock_runtime:boto3.Session.client, assistant_prefill:str="") -> ResponseObject:
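Note: a minimal sketch of exercising the new client construction directly; the endpoint URL, key placeholder and deployment/model name below are illustrative assumptions, not values from the repository.

import os
from tools.llm_funcs import construct_azure_client

os.environ["AZURE_OPENAI_API_KEY"] = "<your-key>"  # or pass in_api_key directly
client, client_config = construct_azure_client(
    in_api_key="",
    endpoint="https://example-resource.openai.azure.com/openai/v1/",  # placeholder endpoint
)
response = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder deployment/model name
    messages=[
        {"role": "system", "content": "You are a topic extraction assistant."},
        {"role": "user", "content": "List the main topics in the following responses: ..."},
    ],
)
print(response.choices[0].message.content)
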
 
757
  )
758
 
759
  output_message = api_response['output']['message']
760
+
761
+ if 'reasoningContent' in output_message['content'][0]:
762
+ # Extract the reasoning text
763
+ reasoning_text = output_message['content'][0]['reasoningContent']['reasoningText']['text']
764
+
765
+ # Extract the output text
766
+ text = assistant_prefill + output_message['content'][1]['text']
767
+ else:
768
+ text = assistant_prefill + output_message['content'][0]['text']
769
 
770
  # The usage statistics are neatly provided in the 'usage' key.
771
  usage = api_response['usage']
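Note: the branch above follows the Bedrock Converse response shape for thinking models; a small self-contained sketch of that shape and the same parsing logic (the sample payload is illustrative, not a captured response).

sample_response = {
    "output": {"message": {"content": [
        {"reasoningContent": {"reasoningText": {"text": "Step-by-step reasoning..."}}},
        {"text": "| General topic | Subtopic | Sentiment |"},
    ]}},
    "usage": {"inputTokens": 1200, "outputTokens": 450},
}

content = sample_response["output"]["message"]["content"]
if "reasoningContent" in content[0]:
    reasoning_text = content[0]["reasoningContent"]["reasoningText"]["text"]
    text = content[1]["text"]  # the final answer follows the reasoning block
else:
    text = content[0]["text"]
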
 
812
  {"role": "system", "content": system_prompt},
813
  {"role": "user", "content": prompt}
814
  ]
 
 
 
815
 
816
  # 2. Apply the chat template
817
  # This function formats the conversation into the exact string Gemma 3 expects.
 
826
  ).to("cuda")
827
  except Exception as e:
828
  print("Error applying chat template:", e)
 
 
 
829
  raise
830
 
831
  # Map LlamaCPP parameters to transformers parameters
 
853
 
854
  # Use speculative decoding if assistant model is available
855
  if speculative_decoding and assistant_model is not None:
856
+ #print("Using speculative decoding with assistant model")
857
  outputs = model.generate(
858
  input_ids,
859
  assistant_model=assistant_model,
 
861
  streamer = streamer
862
  )
863
  else:
864
+ #print("Generating without speculative decoding")
865
  outputs = model.generate(
866
  input_ids,
867
  **generation_kwargs,
 
871
  end_time = time.time()
872
 
873
  # --- Decode and Display Results ---
 
 
874
  new_tokens = outputs[0][input_ids.shape[-1]:]
875
  assistant_reply = tokenizer.decode(new_tokens, skip_special_tokens=True)
876
+
877
 
878
  num_input_tokens = input_ids.shape[-1] # This gets the sequence length (number of tokens)
879
  num_generated_tokens = len(new_tokens)
 
888
  return assistant_reply, num_input_tokens, num_generated_tokens
889
 
890
  # Function to send a request and update history
891
+ def send_request(prompt: str, conversation_history: List[dict], client: ai.Client | OpenAI, config: types.GenerateContentConfig, model_choice: str, system_prompt: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, local_model= list(), tokenizer=None, assistant_model=None, assistant_prefill = "", progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
892
+ """Sends a request to a language model and manages the conversation history.
893
+
894
+ This function constructs the full prompt by appending the new user prompt to the conversation history,
895
+ generates a response from the model, and updates the conversation history with the new prompt and response.
896
+ It handles different model sources (Gemini, AWS, Local) and includes retry logic for API calls.
897
+
898
+ Args:
899
+ prompt (str): The user's input prompt to be sent to the model.
900
+ conversation_history (List[dict]): A list of dictionaries representing the ongoing conversation.
901
+ Each dictionary should have 'role' and 'parts' keys.
902
+ client (ai.Client | OpenAI): The API client object for the chosen model (e.g., a Gemini `ai.Client` or an Azure/OpenAI `OpenAI` client).
903
+ config (types.GenerateContentConfig): Configuration settings for content generation (e.g., Gemini `types.GenerateContentConfig`).
904
+ model_choice (str): The specific model identifier to use (e.g., "gemini-pro", "claude-v2").
905
+ system_prompt (str): An optional system-level instruction or context for the model.
906
+ temperature (float): Controls the randomness of the model's output, with higher values leading to more diverse responses.
907
+ bedrock_runtime (boto3.Session.client): The boto3 Bedrock runtime client object for AWS models.
908
+ model_source (str): Indicates the source/provider of the model (e.g., "Gemini", "AWS", "Azure/OpenAI", "Local").
909
+ local_model (list, optional): A list containing the local model and its tokenizer (if `model_source` is "Local"). Defaults to [].
910
+ tokenizer (object, optional): The tokenizer object for local models. Defaults to None.
911
+ assistant_model (object, optional): An optional assistant model used for speculative decoding with local models. Defaults to None.
912
+ assistant_prefill (str, optional): A string to pre-fill the assistant's response, useful for certain models like Claude. Defaults to "".
913
+ progress (Progress, optional): A progress object for tracking the operation, typically from `tqdm`. Defaults to Progress(track_tqdm=True).
914
+
915
+ Returns:
916
+ Tuple[str, List[dict]]: A tuple containing the model's response text and the updated conversation history.
917
  """
918
  # Constructing the full prompt from the conversation history
919
  full_prompt = "Conversation history:\n"
 
941
  try:
942
  print("Calling Gemini model, attempt", i + 1)
943
 
944
+ response = client.models.generate_content(model=model_choice, contents=full_prompt, config=config)
945
 
946
  #print("Successful call to Gemini model.")
947
  break
 
969
 
970
  if i == number_of_api_retry_attempts:
971
  return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens
972
+ elif "Azure/OpenAI" in model_source:
973
  for i in progress_bar:
974
  try:
975
+ print("Calling Azure/OpenAI inference model, attempt", i + 1)
976
+
977
+ messages=[
978
+ {
979
+ "role": "system",
980
+ "content": system_prompt,
981
+ },
982
+ {
983
+ "role": "user",
984
+ "content": prompt,
985
+ },
986
+ ]
987
+
988
+ response_raw = client.chat.completions.create(
989
+ messages=messages,
990
+ model=model_choice,
991
+ temperature=temperature,
992
+ max_completion_tokens=max_tokens
993
  )
994
+
995
  response_text = response_raw.choices[0].message.content
996
  usage = getattr(response_raw, "usage", None)
997
  input_tokens = 0
 
1005
  )
1006
  break
1007
  except Exception as e:
1008
+ print("Call to Azure/OpenAI model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
1009
  time.sleep(timeout_wait)
1010
  if i == number_of_api_retry_attempts:
1011
  return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens
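Note: for the openai>=1.x chat completions response used above, the token counts can be read from the usage object roughly as follows (a sketch, assuming response_raw is the object returned by client.chat.completions.create).

usage = getattr(response_raw, "usage", None)
input_tokens = usage.prompt_tokens if usage else 0
output_tokens = usage.completion_tokens if usage else 0
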
 
1025
  response, num_transformer_input_tokens, num_transformer_generated_tokens = call_transformers_model(prompt, system_prompt, gen_config, model=local_model, tokenizer=tokenizer, assistant_model=assistant_model)
1026
  response_text = response
1027
 
 
1028
  break
1029
  except Exception as e:
1030
  # If fails, try again after X seconds in case there is a throttle limit
 
1066
  conversation_history: List[dict],
1067
  whole_conversation: List[str],
1068
  whole_conversation_metadata: List[str],
1069
+ client: ai.Client | OpenAI,
1070
  config: types.GenerateContentConfig,
1071
  model_choice: str,
1072
  temperature: float,
 
1087
  conversation_history (List[dict]): The history of the conversation.
1088
  whole_conversation (List[str]): The complete conversation including prompts and responses.
1089
  whole_conversation_metadata (List[str]): Metadata about the whole conversation.
1090
+ client (ai.Client | OpenAI): The client used to process the prompts, either a Gemini client or an OpenAI client.
1091
  config (dict): Configuration for the model.
1092
  model_choice (str): The choice of model to use.
1093
  temperature (float): The temperature parameter for the model.
 
1108
 
1109
  for prompt in prompts:
1110
 
1111
+ response, conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens = send_request(prompt, conversation_history, client=client, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model, tokenizer=tokenizer, assistant_model=assistant_model, assistant_prefill=assistant_prefill, bedrock_runtime=bedrock_runtime, model_source=model_source)
1112
 
1113
  responses.append(response)
1114
  whole_conversation.append(system_prompt)
1115
  whole_conversation.append(prompt)
1116
  whole_conversation.append(response_text)
1117
 
1118
  whole_conversation_metadata.append(f"Batch {batch_no}:")
1119
 
 
1120
  try:
1121
  if "AWS" in model_source:
1122
  output_tokens = response.usage_metadata.get('outputTokens', 0)
 
1126
  output_tokens = response.usage_metadata.candidates_token_count
1127
  input_tokens = response.usage_metadata.prompt_token_count
1128
 
1129
+ elif "Azure/OpenAI" in model_source:
1130
  input_tokens = response.usage_metadata.get('inputTokens', 0)
1131
  output_tokens = response.usage_metadata.get('outputTokens', 0)
1132
 
 
1147
 
1148
  except KeyError as e:
1149
  print(f"Key error: {e} - Check the structure of response.usage_metadata")
 
 
 
1150
 
1151
  return responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text
1152
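Note: the per-source token accounting above could equally be expressed as one helper; this is only a sketch of the same branching, not code from the repository.

def extract_token_counts(response, model_source: str) -> tuple:
    """Return (input_tokens, output_tokens) for the supported model sources (sketch)."""
    if "AWS" in model_source or "Azure/OpenAI" in model_source:
        return (response.usage_metadata.get("inputTokens", 0),
                response.usage_metadata.get("outputTokens", 0))
    if "Gemini" in model_source:
        return (response.usage_metadata.prompt_token_count,
                response.usage_metadata.candidates_token_count)
    return 0, 0
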
 
 
1155
  conversation_history: List[dict],
1156
  whole_conversation: List[str],
1157
  whole_conversation_metadata: List[str],
1158
+ client: ai.Client | OpenAI,
1159
+ client_config: types.GenerateContentConfig,
1160
  model_choice: str,
1161
  temperature: float,
1162
  reported_batch_no: int,
 
1178
  - conversation_history (List[dict]): The history of the conversation.
1179
  - whole_conversation (List[str]): The complete conversation including prompts and responses.
1180
  - whole_conversation_metadata (List[str]): Metadata about the whole conversation.
1181
+ - client (ai.Client | OpenAI): The client object for running Gemini or Azure/OpenAI API calls.
1182
+ - client_config (types.GenerateContentConfig): Configuration for the model.
1183
  - model_choice (str): The choice of model to use.
1184
  - temperature (float): The temperature parameter for the model.
1185
  - reported_batch_no (int): The reported batch number.
 
1200
  call_temperature = temperature # This is correct now with the fixed parameter name
1201
 
1202
  # Update Gemini config with the new temperature settings
1203
+ client_config = types.GenerateContentConfig(temperature=call_temperature, max_output_tokens=max_tokens, seed=random_seed)
1204
 
1205
  for attempt in range(MAX_OUTPUT_VALIDATION_ATTEMPTS):
1206
  # Process requests to large language model
1207
  responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(
1208
  batch_prompts, system_prompt, conversation_history, whole_conversation,
1209
+ whole_conversation_metadata, client, client_config, model_choice,
1210
  call_temperature, bedrock_runtime, model_source, reported_batch_no, local_model, tokenizer=tokenizer, master=master, assistant_prefill=assistant_prefill
1211
  )
1212