Spaces: Running on Zero
Commit 3085585 · Parent(s): 5fed34d
Added model compatibility for OpenAI and Azure endpoints. Added some Bedrock models, now compatible with thinking models
Browse files
- app.py +17 -40
- pyproject.toml +1 -2
- requirements.txt +1 -2
- requirements_cpu.txt +1 -2
- requirements_gpu.txt +1 -2
- requirements_no_local.txt +1 -2
- tools/config.py +78 -15
- tools/dedup_summaries.py +22 -20
- tools/helper_functions.py +12 -4
- tools/llm_api_call.py +21 -16
- tools/llm_funcs.py +86 -65
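
The tools/llm_funcs.py changes (+86 -65) are not shown in this view, but the wiring below calls construct_azure_client(in_api_key, azure_endpoint) and the new openai==2.2.0 dependency suggests the client is built with the OpenAI SDK. A minimal sketch of what such a constructor could look like follows; the function body, parameter defaults, and the AzureOpenAI/OpenAI split are assumptions, not the repository's actual implementation.

# Hypothetical sketch of construct_azure_client; the real version lives in tools/llm_funcs.py (not shown in this diff).
import os
from openai import AzureOpenAI, OpenAI

def construct_azure_client(in_api_key: str = "", endpoint: str = ""):
    """Return a (client, config) pair for an Azure OpenAI deployment or an OpenAI-compatible endpoint."""
    api_key = in_api_key or os.environ.get("AZURE_OPENAI_API_KEY", "")
    if "azure.com" in endpoint:
        # Azure-hosted deployment: the model name passed later must match the deployment name
        client = AzureOpenAI(api_key=api_key, azure_endpoint=endpoint, api_version="2024-06-01")
    else:
        # Plain OpenAI or any OpenAI-compatible endpoint
        client = OpenAI(api_key=api_key, base_url=endpoint or None)
    config = {"temperature": None, "max_tokens": None}  # filled in per request by the caller
    return client, config

Either branch yields an object on which a caller can invoke client.chat.completions.create(model=..., messages=...), which is what makes a single downstream request path plausible for both providers.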
app.py
CHANGED
@@ -3,7 +3,7 @@ import os
import gradio as gr
import pandas as pd
from datetime import datetime
-from tools.helper_functions import put_columns_in_df, get_connection_params, view_table, empty_output_vars_extract_topics, empty_output_vars_summarise, load_in_previous_reference_file, join_cols_onto_reference_df, load_in_previous_data_files, load_in_data_file, load_in_default_cost_codes, reset_base_dataframe, update_cost_code_dataframe_from_dropdown_select, df_select_callback_cost, enforce_cost_codes, _get_env_list, move_overall_summary_output_files_to_front_page
+from tools.helper_functions import put_columns_in_df, get_connection_params, view_table, empty_output_vars_extract_topics, empty_output_vars_summarise, load_in_previous_reference_file, join_cols_onto_reference_df, load_in_previous_data_files, load_in_data_file, load_in_default_cost_codes, reset_base_dataframe, update_cost_code_dataframe_from_dropdown_select, df_select_callback_cost, enforce_cost_codes, _get_env_list, move_overall_summary_output_files_to_front_page, update_model_choice
from tools.aws_functions import upload_file_to_s3, download_file_from_s3
from tools.llm_api_call import modify_existing_output_tables, wrapper_extract_topics_per_column_value, all_in_one_pipeline
from tools.dedup_summaries import sample_reference_table_summaries, summarise_output_topics, deduplicate_topics, deduplicate_topics_llm, overall_summary

@@ -13,18 +13,7 @@ from tools.auth import authenticate_user
from tools.example_table_outputs import dummy_consultation_table, case_notes_table, dummy_consultation_table_zero_shot, case_notes_table_grouped, case_notes_table_structured_summary
from tools.prompts import initial_table_prompt, system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, two_para_summary_format_prompt, single_para_summary_format_prompt
# from tools.verify_titles import verify_titles
-from tools.config import RUN_AWS_FUNCTIONS, HOST_NAME, ACCESS_LOGS_FOLDER, FEEDBACK_LOGS_FOLDER, USAGE_LOGS_FOLDER,
-
-def ensure_folder_exists(output_folder:str):
-    """Checks if the specified folder exists, creates it if not."""
-
-    if not os.path.exists(output_folder):
-        # Create the folder if it doesn't exist
-        os.makedirs(output_folder, exist_ok=True)
-        print(f"Created the {output_folder} folder.")
-    else:
-        pass
-        #print(f"The {output_folder} folder already exists.")
+from tools.config import RUN_AWS_FUNCTIONS, HOST_NAME, ACCESS_LOGS_FOLDER, FEEDBACK_LOGS_FOLDER, USAGE_LOGS_FOLDER, FILE_INPUT_HEIGHT, GEMINI_API_KEY, BATCH_SIZE_DEFAULT, LLM_SEED, COGNITO_AUTH, MAX_QUEUE_SIZE, MAX_FILE_SIZE, GRADIO_SERVER_PORT, ROOT_PATH, INPUT_FOLDER, OUTPUT_FOLDER, S3_LOG_BUCKET, CONFIG_FOLDER, GRADIO_TEMP_DIR, MPLCONFIGDIR, GET_COST_CODES, ENFORCE_COST_CODES, DEFAULT_COST_CODE, COST_CODES_PATH, S3_COST_CODES_PATH, OUTPUT_COST_CODES_PATH, SHOW_COSTS, SAVE_LOGS_TO_CSV, SAVE_LOGS_TO_DYNAMODB, ACCESS_LOG_DYNAMODB_TABLE_NAME, USAGE_LOG_DYNAMODB_TABLE_NAME, FEEDBACK_LOG_DYNAMODB_TABLE_NAME, LOG_FILE_NAME, FEEDBACK_LOG_FILE_NAME, USAGE_LOG_FILE_NAME, CSV_ACCESS_LOG_HEADERS, CSV_FEEDBACK_LOG_HEADERS, CSV_USAGE_LOG_HEADERS, DYNAMODB_ACCESS_LOG_HEADERS, DYNAMODB_FEEDBACK_LOG_HEADERS, DYNAMODB_USAGE_LOG_HEADERS, S3_ACCESS_LOGS_FOLDER, S3_FEEDBACK_LOGS_FOLDER, S3_USAGE_LOGS_FOLDER, AWS_ACCESS_KEY, AWS_SECRET_KEY, SHOW_EXAMPLES, HF_TOKEN, AZURE_OPENAI_API_KEY, AZURE_OPENAI_INFERENCE_ENDPOINT, LLM_TEMPERATURE, model_name_map, default_model_choice, default_source_models, default_model_source, model_sources, ensure_folder_exists

ensure_folder_exists(CONFIG_FOLDER)
ensure_folder_exists(OUTPUT_FOLDER)

@@ -35,26 +24,8 @@ ensure_folder_exists(FEEDBACK_LOGS_FOLDER)
ensure_folder_exists(ACCESS_LOGS_FOLDER)
ensure_folder_exists(USAGE_LOGS_FOLDER)

-# Convert string environment variables to string or list
-if SAVE_LOGS_TO_CSV == "True": SAVE_LOGS_TO_CSV = True
-else: SAVE_LOGS_TO_CSV = False
-if SAVE_LOGS_TO_DYNAMODB == "True": SAVE_LOGS_TO_DYNAMODB = True
-else: SAVE_LOGS_TO_DYNAMODB = False
-
-if CSV_ACCESS_LOG_HEADERS: CSV_ACCESS_LOG_HEADERS = _get_env_list(CSV_ACCESS_LOG_HEADERS)
-if CSV_FEEDBACK_LOG_HEADERS: CSV_FEEDBACK_LOG_HEADERS = _get_env_list(CSV_FEEDBACK_LOG_HEADERS)
-if CSV_USAGE_LOG_HEADERS: CSV_USAGE_LOG_HEADERS = _get_env_list(CSV_USAGE_LOG_HEADERS)
-
-if DYNAMODB_ACCESS_LOG_HEADERS: DYNAMODB_ACCESS_LOG_HEADERS = _get_env_list(DYNAMODB_ACCESS_LOG_HEADERS)
-if DYNAMODB_FEEDBACK_LOG_HEADERS: DYNAMODB_FEEDBACK_LOG_HEADERS = _get_env_list(DYNAMODB_FEEDBACK_LOG_HEADERS)
-if DYNAMODB_USAGE_LOG_HEADERS: DYNAMODB_USAGE_LOG_HEADERS = _get_env_list(DYNAMODB_USAGE_LOG_HEADERS)
-
today_rev = datetime.now().strftime("%Y%m%d")

-if RUN_LOCAL_MODEL == "1": default_model_choice = CHOSEN_LOCAL_MODEL_TYPE
-elif RUN_AWS_FUNCTIONS == "1": default_model_choice = "anthropic.claude-3-haiku-20240307-v1:0"
-else: default_model_choice = "gemini-2.5-flash"
-
# Placeholders for example variables
in_data_files = gr.File(height=FILE_INPUT_HEIGHT, label="Choose Excel or csv files", file_count= "multiple", file_types=['.xlsx', '.xls', '.csv', '.parquet'])
in_colnames = gr.Dropdown(choices=[""], multiselect = False, label="Select the open text column of interest. In an Excel file, this shows columns across all sheets.", allow_custom_value=True, interactive=True)

@@ -162,7 +133,7 @@ with app:

gr.Markdown("""# Large language model topic modelling

-Extract topics and summarise outputs using Large Language Models (LLMs, Gemma 3 4b/GPT-OSS 20b if local (see tools/config.py to modify), Gemini, Azure, or AWS Bedrock models (e.g. Claude, Nova models). The app will query the LLM with batches of responses to produce summary tables, which are then compared iteratively to output a table with the general topics, subtopics, topic sentiment, and a topic summary. Instructions on use can be found in the README.md file. You can try out examples by clicking on one of the example datasets below. API keys for AWS, Azure, and Gemini services can be entered on the settings page (note that Gemini has a free public API).
+Extract topics and summarise outputs using Large Language Models (LLMs, Gemma 3 4b/GPT-OSS 20b if local (see tools/config.py to modify), Gemini, Azure/OpenAI, or AWS Bedrock models (e.g. Claude, Nova models). The app will query the LLM with batches of responses to produce summary tables, which are then compared iteratively to output a table with the general topics, subtopics, topic sentiment, and a topic summary. Instructions on use can be found in the README.md file. You can try out examples by clicking on one of the example datasets below. API keys for AWS, Azure/OpenAI, and Gemini services can be entered on the settings page (note that Gemini has a free public API).

NOTE: Large language models are not 100% accurate and may produce biased or harmful outputs. All outputs from this app **absolutely need to be checked by a human** to check for harmful outputs, hallucinations, and accuracy.""")

@@ -198,7 +169,10 @@ with app:

with gr.Tab(label="All in one topic extraction and summarisation"):
    with gr.Row():
-
+        model_source = gr.Dropdown(value = default_model_source, choices = model_sources, label="Large language model family", multiselect=False)
+        model_choice = gr.Dropdown(value = default_model_choice, choices = default_source_models, label="Large language model for topic extraction and summarisation", multiselect=False)
+
+        model_source.change(fn=update_model_choice, inputs=[model_source], outputs=[model_choice])

    with gr.Accordion("Upload xlsx, csv, or parquet file", open = True):
        in_data_files.render()

@@ -339,8 +313,9 @@ with app:
with gr.Accordion("Gemini API keys", open = False):
    google_api_key_textbox = gr.Textbox(value = GEMINI_API_KEY, label="Enter Gemini API key (only if using Google API models)", lines=1, type="password")

-with gr.Accordion("Azure
-    azure_api_key_textbox = gr.Textbox(value =
+with gr.Accordion("Azure/OpenAI Inference", open = False):
+    azure_api_key_textbox = gr.Textbox(value = AZURE_OPENAI_API_KEY, label="Enter Azure/OpenAI Inference API key (only if using Azure/OpenAI models)", lines=1, type="password")
+    azure_endpoint_textbox = gr.Textbox(value = AZURE_OPENAI_INFERENCE_ENDPOINT, label="Enter Azure/OpenAI Inference endpoint URL (only if using Azure/OpenAI models)", lines=1)

with gr.Accordion("Hugging Face token for downloading gated models", open = False):
    hf_api_key_textbox = gr.Textbox(value = HF_TOKEN, label="Enter Hugging Face API key (only for gated models that need a token to download)", lines=1, type="password")

@@ -435,6 +410,7 @@ with app:
aws_secret_key_textbox,
hf_api_key_textbox,
azure_api_key_textbox,
+azure_endpoint_textbox,
output_folder_state,
logged_content_df,
add_existing_topics_summary_format_textbox],

@@ -478,12 +454,12 @@ with app:
success(deduplicate_topics, inputs=[master_reference_df_state, master_unique_topics_df_state, working_data_file_name_textbox, unique_topics_table_file_name_textbox, in_excel_sheets, merge_sentiment_drop, merge_general_topics_drop, deduplicate_score_threshold, in_data_files, in_colnames, output_folder_state], outputs=[master_reference_df_state, master_unique_topics_df_state, summarisation_input_files, log_files_output, summarised_output_markdown], scroll_to_output=True, api_name="deduplicate_topics")

# When LLM deduplication button pressed, deduplicate data using LLM
-def deduplicate_topics_llm_wrapper(reference_df, topic_summary_df, reference_table_file_name, unique_topics_table_file_name, model_choice, in_api_key, temperature, in_excel_sheets, merge_sentiment, merge_general_topics, in_data_files, chosen_cols, output_folder, candidate_topics=None):
+def deduplicate_topics_llm_wrapper(reference_df, topic_summary_df, reference_table_file_name, unique_topics_table_file_name, model_choice, in_api_key, temperature, in_excel_sheets, merge_sentiment, merge_general_topics, in_data_files, chosen_cols, output_folder, candidate_topics=None, azure_endpoint=""):
    model_source = model_name_map[model_choice]["source"]
-    return deduplicate_topics_llm(reference_df, topic_summary_df, reference_table_file_name, unique_topics_table_file_name, model_choice, in_api_key, temperature, model_source, None, None, None, None, in_excel_sheets, merge_sentiment, merge_general_topics, in_data_files, chosen_cols, output_folder, candidate_topics)
+    return deduplicate_topics_llm(reference_df, topic_summary_df, reference_table_file_name, unique_topics_table_file_name, model_choice, in_api_key, temperature, model_source, None, None, None, None, in_excel_sheets, merge_sentiment, merge_general_topics, in_data_files, chosen_cols, output_folder, candidate_topics, azure_endpoint)

deduplicate_llm_previous_data_btn.click(load_in_previous_data_files, inputs=[deduplication_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
-success(deduplicate_topics_llm_wrapper, inputs=[master_reference_df_state, master_unique_topics_df_state, working_data_file_name_textbox, unique_topics_table_file_name_textbox, model_choice, google_api_key_textbox, temperature_slide, in_excel_sheets, merge_sentiment_drop, merge_general_topics_drop, in_data_files, in_colnames, output_folder_state, candidate_topics], outputs=[master_reference_df_state, master_unique_topics_df_state, summarisation_input_files, log_files_output, summarised_output_markdown, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number], scroll_to_output=True, api_name="deduplicate_topics_llm").\
+success(deduplicate_topics_llm_wrapper, inputs=[master_reference_df_state, master_unique_topics_df_state, working_data_file_name_textbox, unique_topics_table_file_name_textbox, model_choice, google_api_key_textbox, temperature_slide, in_excel_sheets, merge_sentiment_drop, merge_general_topics_drop, in_data_files, in_colnames, output_folder_state, candidate_topics, azure_endpoint_textbox], outputs=[master_reference_df_state, master_unique_topics_df_state, summarisation_input_files, log_files_output, summarised_output_markdown, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number], scroll_to_output=True, api_name="deduplicate_topics_llm").\
success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False, api_name="usage_logs_llm_dedup")

# When button pressed, summarise previous data

@@ -491,14 +467,14 @@ with app:
success(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
success(load_in_previous_data_files, inputs=[summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
success(sample_reference_table_summaries, inputs=[master_reference_df_state, random_seed], outputs=[summary_reference_table_sample_state, summarised_references_markdown], api_name="sample_summaries").\
-success(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state, context_textbox, aws_access_key_textbox, aws_secret_key_textbox, model_name_map_state, hf_api_key_textbox, logged_content_df], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output, overall_summarisation_input_files, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, output_messages_textbox, logged_content_df], api_name="summarise_topics", show_progress_on=[output_messages_textbox, summary_output_files]).\
+success(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state, context_textbox, aws_access_key_textbox, aws_secret_key_textbox, model_name_map_state, hf_api_key_textbox, azure_endpoint_textbox, logged_content_df], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output, overall_summarisation_input_files, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, output_messages_textbox, logged_content_df], api_name="summarise_topics", show_progress_on=[output_messages_textbox, summary_output_files]).\
success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False).\
then(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_revised_summaries_state, master_unique_topics_df_revised_summaries_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state, produce_structured_summary_radio], outputs=[summary_output_files_xlsx, summary_xlsx_output_files_list])

# SUMMARISE WHOLE TABLE PAGE
overall_summarise_previous_data_btn.click(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
success(load_in_previous_data_files, inputs=[overall_summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
-success(overall_summary, inputs=[master_unique_topics_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, output_folder_state, in_colnames, context_textbox, aws_access_key_textbox, aws_secret_key_textbox, model_name_map_state, hf_api_key_textbox, logged_content_df], outputs=[overall_summary_output_files, overall_summarised_output_markdown, summarised_output_df, conversation_metadata_textbox, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, output_messages_textbox, logged_content_df], scroll_to_output=True, api_name="overall_summary", show_progress_on=[output_messages_textbox, overall_summary_output_files]).\
+success(overall_summary, inputs=[master_unique_topics_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, output_folder_state, in_colnames, context_textbox, aws_access_key_textbox, aws_secret_key_textbox, model_name_map_state, hf_api_key_textbox, azure_endpoint_textbox, logged_content_df], outputs=[overall_summary_output_files, overall_summarised_output_markdown, summarised_output_df, conversation_metadata_textbox, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, output_messages_textbox, logged_content_df], scroll_to_output=True, api_name="overall_summary", show_progress_on=[output_messages_textbox, overall_summary_output_files]).\
success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False).\
then(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_state, master_unique_topics_df_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state, produce_structured_summary_radio], outputs=[overall_summary_output_files_xlsx, summary_xlsx_output_files_list])

@@ -545,6 +521,7 @@ with app:
aws_secret_key_textbox,
hf_api_key_textbox,
azure_api_key_textbox,
+azure_endpoint_textbox,
output_folder_state,
merge_sentiment_drop,
merge_general_topics_drop,
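
The new model-family dropdown above drives the model dropdown through update_model_choice (defined in tools/helper_functions.py further down this diff). A self-contained sketch of that dependent-dropdown pattern is shown below; the three-entry model map is invented for illustration and is not the repository's full model registry.

# Stand-alone illustration of the pattern used in app.py above (not repository code).
import gradio as gr

MODEL_NAME_MAP = {
    "gemini-2.5-flash": {"source": "Gemini"},
    "gpt-4o-mini": {"source": "Azure/OpenAI"},
    "anthropic.claude-3-haiku-20240307-v1:0": {"source": "AWS"},
}

def update_model_choice(model_source):
    # Keep only models whose registered source matches the selected family
    matching = [name for name, info in MODEL_NAME_MAP.items() if info["source"] == model_source]
    return gr.Dropdown(value=matching[0] if matching else None, choices=matching)

with gr.Blocks() as demo:
    source = gr.Dropdown(choices=["Gemini", "Azure/OpenAI", "AWS"], value="Gemini", label="Model family")
    model = gr.Dropdown(choices=["gemini-2.5-flash"], value="gemini-2.5-flash", label="Model")
    source.change(fn=update_model_choice, inputs=[source], outputs=[model])

if __name__ == "__main__":
    demo.launch()

Returning a fresh gr.Dropdown from the event handler is how Gradio updates the existing component in place, which keeps the model list consistent with whichever provider family is selected.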
pyproject.toml
CHANGED
@@ -17,8 +17,7 @@ dependencies = [
    "tabulate==0.9.0",
    "lxml==5.3.0",
    "google-genai==1.33.0",
-    "
-    "azure-core==1.35.0",
+    "openai==2.2.0",
    "html5lib==1.1",
    "beautifulsoup4==4.12.3",
    "rapidfuzz==3.13.0",
requirements.txt
CHANGED
@@ -10,8 +10,7 @@ markdown==3.7
tabulate==0.9.0
lxml==5.3.0
google-genai==1.33.0
-
-azure-core==1.35.0
+openai==2.2.0
html5lib==1.1
beautifulsoup4==4.12.3
rapidfuzz==3.13.0
requirements_cpu.txt
CHANGED
@@ -9,8 +9,7 @@ markdown==3.7
tabulate==0.9.0
lxml==5.3.0
google-genai==1.33.0
-
-azure-core==1.35.0
+openai==2.2.0
html5lib==1.1
beautifulsoup4==4.12.3
rapidfuzz==3.13.0
requirements_gpu.txt
CHANGED
@@ -9,8 +9,7 @@ markdown==3.7
tabulate==0.9.0
lxml==5.3.0
google-genai==1.33.0
-
-azure-core==1.35.0
+openai==2.2.0
html5lib==1.1
beautifulsoup4==4.12.3
rapidfuzz==3.13.0
requirements_no_local.txt
CHANGED
@@ -10,8 +10,7 @@ markdown==3.7
tabulate==0.9.0
lxml==5.3.0
google-genai==1.33.0
-
-azure-core==1.35.0
+openai==2.2.0
html5lib==1.1
beautifulsoup4==4.12.3
rapidfuzz==3.13.0
tools/config.py
CHANGED
@@ -2,6 +2,8 @@ import os
import tempfile
import socket
import logging
+import codecs
+from typing import List
from datetime import datetime
from dotenv import load_dotenv

@@ -217,10 +219,10 @@ RUN_AWS_BEDROCK_MODELS = get_or_create_env_var("RUN_AWS_BEDROCK_MODELS", "1")
RUN_GEMINI_MODELS = get_or_create_env_var("RUN_GEMINI_MODELS", "1")
GEMINI_API_KEY = get_or_create_env_var('GEMINI_API_KEY', '')

-# Azure AI Inference settings
-RUN_AZURE_MODELS = get_or_create_env_var("RUN_AZURE_MODELS", "
-
-
+# Azure/OpenAI AI Inference settings
+RUN_AZURE_MODELS = get_or_create_env_var("RUN_AZURE_MODELS", "1")
+AZURE_OPENAI_API_KEY = get_or_create_env_var('AZURE_OPENAI_API_KEY', '')
+AZURE_OPENAI_INFERENCE_ENDPOINT = get_or_create_env_var('AZURE_OPENAI_INFERENCE_ENDPOINT', '')

# Build up options for models

@@ -236,27 +238,48 @@ if RUN_LOCAL_MODEL == "1" and CHOSEN_LOCAL_MODEL_TYPE:
    model_source.append("Local")

if RUN_AWS_BEDROCK_MODELS == "1":
-
-
-
+    amazon_models = ["anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-7-sonnet-20250219-v1:0", "anthropic.claude-sonnet-4-5-20250929-v1:0", "amazon.nova-micro-v1:0", "amazon.nova-lite-v1:0", "amazon.nova-pro-v1:0", "deepseek.v3-v1:0", "openai.gpt-oss-20b-1:0", "openai.gpt-oss-120b-1:0"]
+    model_full_names.extend(amazon_models)
+    model_short_names.extend(["haiku", "sonnet_3_7", "sonnet_4_5", "nova_micro", "nova_lite", "nova_pro", "deepseek_v3", "gpt_oss_20b_aws", "gpt_oss_120b_aws"])
+    model_source.extend(["AWS"] * len(amazon_models))

if RUN_GEMINI_MODELS == "1":
-
+    gemini_models = ["gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-2.5-pro"]
+    model_full_names.extend(gemini_models)
    model_short_names.extend(["gemini_flash_lite_2.5", "gemini_flash_2.5", "gemini_pro"])
-    model_source.extend(["Gemini"
+    model_source.extend(["Gemini"] * len(gemini_models))

-# Register Azure AI models (model names must match your Azure deployments)
+# Register Azure/OpenAI AI models (model names must match your Azure/OpenAI deployments)
if RUN_AZURE_MODELS == "1":
-    # Example deployments; adjust to the deployments you actually create in Azure
-
-
-
+    # Example deployments; adjust to the deployments you actually create in Azure/OpenAI
+    azure_models = ["gpt-5-mini", "gpt-4o-mini"]
+    model_full_names.extend(azure_models)
+    model_short_names.extend(["gpt-5-mini", "gpt-4o-mini"])
+    model_source.extend(["Azure/OpenAI"] * len(azure_models))

model_name_map = {
    full: {"short_name": short, "source": source}
    for full, short, source in zip(model_full_names, model_short_names, model_source)
}

+if RUN_LOCAL_MODEL == "1": default_model_choice = CHOSEN_LOCAL_MODEL_TYPE
+elif RUN_AWS_FUNCTIONS == "1": default_model_choice = amazon_models[0]
+else: default_model_choice = gemini_models[0]
+
+default_model_source = model_name_map[default_model_choice]["source"]
+model_sources = list(set([model_name_map[model]["source"] for model in model_full_names]))
+
+def update_model_choice_config(default_model_source, model_name_map):
+    # Filter models by source and return the first matching model name
+    matching_models = [model_name for model_name, model_info in model_name_map.items()
+                       if model_info["source"] == default_model_source]
+
+    output_model = matching_models[0] if matching_models else model_full_names[0]
+
+    return output_model, matching_models
+
+default_model_choice, default_source_models = update_model_choice_config(default_model_source, model_name_map)
+
#print("model_name_map:", model_name_map)

# HF token may or may not be needed for downloading models from Hugging Face

@@ -453,4 +476,44 @@ else: OUTPUT_COST_CODES_PATH = 'config/cost_codes.csv'

ENFORCE_COST_CODES = get_or_create_env_var('ENFORCE_COST_CODES', 'False') # If you have cost codes listed, is it compulsory to choose one before redacting?

-if ENFORCE_COST_CODES == 'True': GET_COST_CODES = 'True'
+if ENFORCE_COST_CODES == 'True': GET_COST_CODES = 'True'
+
+###
+# VALIDATE FOLDERS AND CONFIG OPTIONS
+###
+
+def ensure_folder_exists(output_folder:str):
+    """Checks if the specified folder exists, creates it if not."""
+
+    if not os.path.exists(output_folder):
+        # Create the folder if it doesn't exist
+        os.makedirs(output_folder, exist_ok=True)
+        print(f"Created the {output_folder} folder.")
+    else:
+        pass
+        #print(f"The {output_folder} folder already exists.")
+
+def _get_env_list(env_var_name: str, strip_strings:bool=True) -> List[str]:
+    """Parses a comma-separated environment variable into a list of strings."""
+    value = env_var_name[1:-1].strip().replace('\"', '').replace("\'","")
+    if not value:
+        return []
+    # Split by comma and filter out any empty strings that might result from extra commas
+    if strip_strings:
+        return [s.strip() for s in value.split(',') if s.strip()]
+    else:
+        return [codecs.decode(s, 'unicode_escape') for s in value.split(',') if s]
+
+# Convert string environment variables to string or list
+if SAVE_LOGS_TO_CSV == "True": SAVE_LOGS_TO_CSV = True
+else: SAVE_LOGS_TO_CSV = False
+if SAVE_LOGS_TO_DYNAMODB == "True": SAVE_LOGS_TO_DYNAMODB = True
+else: SAVE_LOGS_TO_DYNAMODB = False
+
+if CSV_ACCESS_LOG_HEADERS: CSV_ACCESS_LOG_HEADERS = _get_env_list(CSV_ACCESS_LOG_HEADERS)
+if CSV_FEEDBACK_LOG_HEADERS: CSV_FEEDBACK_LOG_HEADERS = _get_env_list(CSV_FEEDBACK_LOG_HEADERS)
+if CSV_USAGE_LOG_HEADERS: CSV_USAGE_LOG_HEADERS = _get_env_list(CSV_USAGE_LOG_HEADERS)
+
+if DYNAMODB_ACCESS_LOG_HEADERS: DYNAMODB_ACCESS_LOG_HEADERS = _get_env_list(DYNAMODB_ACCESS_LOG_HEADERS)
+if DYNAMODB_FEEDBACK_LOG_HEADERS: DYNAMODB_FEEDBACK_LOG_HEADERS = _get_env_list(DYNAMODB_FEEDBACK_LOG_HEADERS)
+if DYNAMODB_USAGE_LOG_HEADERS: DYNAMODB_USAGE_LOG_HEADERS = _get_env_list(DYNAMODB_USAGE_LOG_HEADERS)
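
The _get_env_list helper moved into config.py expects the whole bracketed value of an environment variable (it strips the first and last characters before cleaning quotes and splitting), so a short worked example of the parsing it performs, with illustrative values, is:

# Example of the parsing performed by _get_env_list (values are illustrative only)
raw = "['session_hash', 'file_name', 'model_choice', 'input_tokens']"
value = raw[1:-1].strip().replace('\"', '').replace("\'", "")
print([s.strip() for s in value.split(',') if s.strip()])
# -> ['session_hash', 'file_name', 'model_choice', 'input_tokens']

# model_name_map built above maps full model IDs to a short name and a source family, e.g.:
# {"gemini-2.5-flash": {"short_name": "gemini_flash_2.5", "source": "Gemini"},
#  "gpt-4o-mini":      {"short_name": "gpt-4o-mini",      "source": "Azure/OpenAI"}, ...}
# With the Azure/OpenAI models registered, update_model_choice_config("Azure/OpenAI", model_name_map)
# would return ("gpt-5-mini", ["gpt-5-mini", "gpt-4o-mini"]).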
tools/dedup_summaries.py
CHANGED
@@ -11,10 +11,10 @@ from tqdm import tqdm
import os
from tools.llm_api_call import generate_zero_shot_topics_df
from tools.prompts import summarise_topic_descriptions_prompt, summarise_topic_descriptions_system_prompt, system_prompt, summarise_everything_prompt, comprehensive_summary_format_prompt, summarise_everything_system_prompt, comprehensive_summary_format_prompt_by_group, summary_assistant_prefill, llm_deduplication_system_prompt, llm_deduplication_prompt, llm_deduplication_prompt_with_candidates
-from tools.llm_funcs import construct_gemini_generative_model, process_requests,
+from tools.llm_funcs import construct_gemini_generative_model, process_requests, calculate_tokens_from_metadata, construct_azure_client, get_model, get_tokenizer, get_assistant_model, construct_gemini_generative_model, construct_azure_client, call_llm_with_markdown_table_checks
from tools.helper_functions import create_topic_summary_df_from_reference_table, load_in_data_file, get_basic_response_data, convert_reference_table_to_pivot_table, wrap_text, clean_column_name, get_file_name_no_ext, create_batch_file_path_details, read_file
from tools.aws_functions import connect_to_bedrock_runtime
-from tools.config import OUTPUT_FOLDER, RUN_LOCAL_MODEL, MAX_COMMENT_CHARS, LLM_MAX_NEW_TOKENS, LLM_SEED, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, REASONING_SUFFIX,
+from tools.config import OUTPUT_FOLDER, RUN_LOCAL_MODEL, MAX_COMMENT_CHARS, LLM_MAX_NEW_TOKENS, LLM_SEED, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, REASONING_SUFFIX, AZURE_OPENAI_INFERENCE_ENDPOINT, MAX_SPACES_GPU_RUN_TIME, OUTPUT_DEBUG_FILES

max_tokens = LLM_MAX_NEW_TOKENS
timeout_wait = TIMEOUT_WAIT

@@ -393,7 +393,8 @@ def deduplicate_topics_llm(reference_df:pd.DataFrame,
    in_data_files:List[str]=list(),
    chosen_cols:List[str]="",
    output_folder:str=OUTPUT_FOLDER,
-    candidate_topics=None
+    candidate_topics=None,
+    azure_endpoint:str=""
    ):
    '''
    Deduplicate topics using LLM semantic understanding to identify and merge similar topics.

@@ -501,7 +502,7 @@ def deduplicate_topics_llm(reference_df:pd.DataFrame,

    # Set up model clients based on model source
    if "Gemini" in model_source:
-
+        client, config = construct_gemini_generative_model(
            in_api_key, temperature, model_choice, llm_deduplication_system_prompt,
            max_tokens, LLM_SEED
        )

@@ -509,13 +510,13 @@ def deduplicate_topics_llm(reference_df:pd.DataFrame,
    elif "AWS" in model_source:
        if not bedrock_runtime:
            bedrock_runtime = boto3.client('bedrock-runtime')
-
+        client = None
        config = None
-    elif "Azure" in model_source:
-
+    elif "Azure/OpenAI" in model_source:
+        client, config = construct_azure_client(in_api_key, azure_endpoint)
        bedrock_runtime = None
    elif "Local" in model_source:
-
+        client = None
        config = None
        bedrock_runtime = None
    else:

@@ -531,8 +532,8 @@ def deduplicate_topics_llm(reference_df:pd.DataFrame,
        conversation_history=conversation_history,
        whole_conversation=whole_conversation,
        whole_conversation_metadata=whole_conversation_metadata,
-
-
+        client=client,
+        client_config=config,
        model_choice=model_choice,
        temperature=temperature,
        reported_batch_no=1,

@@ -758,7 +759,7 @@ def sample_reference_table_summaries(reference_df:pd.DataFrame,

    return sampled_reference_table_df, summarised_references_markdown#, reference_df, topic_summary_df

-def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:float, formatted_summary_prompt:str, summarise_topic_descriptions_system_prompt:str, model_source:str, bedrock_runtime:boto3.Session.client, local_model=list(), tokenizer=list(), assistant_model=list()):
+def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:float, formatted_summary_prompt:str, summarise_topic_descriptions_system_prompt:str, model_source:str, bedrock_runtime:boto3.Session.client, local_model=list(), tokenizer=list(), assistant_model=list(), azure_endpoint:str=""):
    """
    Query an LLM to generate a summary of topics based on the provided prompts.

@@ -780,16 +781,15 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
    """
    conversation_history = list()
    whole_conversation_metadata = list()
-
-
+    client = list()
+    client_config = {}

    # Prepare Gemini models before query
    if "Gemini" in model_source:
        #print("Using Gemini model:", model_choice)
-
-    elif "Azure" in model_source:
-
-        google_client, config = construct_azure_client(in_api_key=os.environ.get("AZURE_INFERENCE_CREDENTIAL", ""), endpoint=AZURE_INFERENCE_ENDPOINT)
+        client, config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=system_prompt, max_tokens=max_tokens)
+    elif "Azure/OpenAI" in model_source:
+        client, config = construct_azure_client(in_api_key=os.environ.get("AZURE_INFERENCE_CREDENTIAL", ""), endpoint=azure_endpoint)
    elif "Local" in model_source:
        pass
        #print("Using local model: ", model_choice)

@@ -800,7 +800,7 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:
    whole_conversation = [summarise_topic_descriptions_system_prompt]

    # Process requests to large language model
-    responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(formatted_summary_prompt, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata,
+    responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(formatted_summary_prompt, system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, client, client_config, model_choice, temperature, bedrock_runtime=bedrock_runtime, model_source=model_source, local_model=local_model, tokenizer=tokenizer, assistant_model=assistant_model, assistant_prefill=summary_assistant_prefill)

    summarised_output = re.sub(r'\n{2,}', '\n', response_text) # Replace multiple line breaks with a single line break
    summarised_output = re.sub(r'^\n{1,}', '', summarised_output) # Remove one or more line breaks at the start

@@ -898,6 +898,7 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
    aws_secret_key_textbox:str='',
    model_name_map:dict=model_name_map,
    hf_api_key_textbox:str='',
+    azure_endpoint_textbox:str='',
    existing_logged_content:list=list(),
    output_debug_files:str=output_debug_files,
    reasoning_suffix:str=reasoning_suffix,

@@ -1041,7 +1042,7 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
    if "Local" in model_source and reasoning_suffix: formatted_summarise_topic_descriptions_system_prompt = formatted_summarise_topic_descriptions_system_prompt + "\n" + reasoning_suffix

    try:
-        response, conversation_history, metadata, response_text = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_topic_descriptions_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer, assistant_model=assistant_model)
+        response, conversation_history, metadata, response_text = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_topic_descriptions_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer, assistant_model=assistant_model, azure_endpoint=azure_endpoint_textbox)
        summarised_output = response_text
    except Exception as e:
        print("Creating summary failed:", e)

@@ -1183,6 +1184,7 @@ def overall_summary(topic_summary_df:pd.DataFrame,
    aws_secret_key_textbox:str='',
    model_name_map:dict=model_name_map,
    hf_api_key_textbox:str='',
+    azure_endpoint_textbox:str='',
    existing_logged_content:list=list(),
    output_debug_files:str=output_debug_files,
    log_output_files:list=list(),

@@ -1313,7 +1315,7 @@ def overall_summary(topic_summary_df:pd.DataFrame,
    if "Local" in model_source and reasoning_suffix: formatted_summarise_everything_system_prompt = formatted_summarise_everything_system_prompt + "\n" + reasoning_suffix

    try:
-        response, conversation_history, metadata, response_text = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_everything_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer, assistant_model=assistant_model)
+        response, conversation_history, metadata, response_text = summarise_output_topics_query(model_choice, in_api_key, temperature, formatted_summary_prompt, formatted_summarise_everything_system_prompt, model_source, bedrock_runtime, local_model, tokenizer=tokenizer, assistant_model=assistant_model, azure_endpoint=azure_endpoint_textbox)
        summarised_output_for_df = response_text
        summarised_output = response
    except Exception as e:
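
Every summarisation entry point above now accepts an azure_endpoint / azure_endpoint_textbox value and passes it to construct_azure_client, with AZURE_INFERENCE_CREDENTIAL as the API-key fallback. The sketch below shows, under stated assumptions, how such an endpoint could be turned into a single chat completion; it is illustrative only, since the real request path goes through construct_azure_client and process_requests in tools/llm_funcs.py, which this diff does not show.

# Minimal sketch (not repository code) of an Azure/OpenAI-style call using the new endpoint value.
import os
from openai import OpenAI

def sketch_azure_call(prompt: str, model_choice: str, azure_endpoint: str, temperature: float = 0.0) -> str:
    api_key = os.environ.get("AZURE_INFERENCE_CREDENTIAL", "")  # same fallback as summarise_output_topics_query
    client = OpenAI(api_key=api_key, base_url=azure_endpoint)   # assumption: an OpenAI-compatible endpoint
    response = client.chat.completions.create(
        model=model_choice,
        temperature=temperature,
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content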
tools/helper_functions.py
CHANGED

@@ -6,8 +6,9 @@ import pandas as pd
 import numpy as np
 from typing import List
 import math
+import codecs
 from botocore.exceptions import ClientError
-from tools.config import OUTPUT_FOLDER, INPUT_FOLDER, SESSION_OUTPUT_FOLDER, CUSTOM_HEADER, CUSTOM_HEADER_VALUE, AWS_USER_POOL_ID, MAXIMUM_ZERO_SHOT_TOPICS
+from tools.config import OUTPUT_FOLDER, INPUT_FOLDER, SESSION_OUTPUT_FOLDER, CUSTOM_HEADER, CUSTOM_HEADER_VALUE, AWS_USER_POOL_ID, MAXIMUM_ZERO_SHOT_TOPICS, model_name_map, model_full_names

 def empty_output_vars_extract_topics():
     # Empty output objects before processing a new file

@@ -745,8 +746,6 @@ def enforce_cost_codes(enforce_cost_code_textbox:str, cost_code_choice:str, cost
         raise Exception("Selected cost code not found in list. Please contact Finance if you cannot find the correct cost code from the given list of suggestions.")
     return

-import codecs
-
 def _get_env_list(env_var_name: str, strip_strings:bool=True) -> List[str]:
     """Parses a comma-separated environment variable into a list of strings."""
     value = env_var_name[1:-1].strip().replace('\"', '').replace("\'","")

@@ -898,4 +897,13 @@ def generate_zero_shot_topics_df(zero_shot_topics:pd.DataFrame,
         "Description": zero_shot_topics_description_list
     })

-    return zero_shot_topics_df
+    return zero_shot_topics_df
+
+def update_model_choice(model_source):
+    # Filter models by source and return the first matching model name
+    matching_models = [model_name for model_name, model_info in model_name_map.items()
+                       if model_info["source"] == model_source]
+
+    output_model = matching_models[0] if matching_models else model_full_names[0]
+
+    return gr.Dropdown(value = output_model, choices = matching_models, label="Large language model for topic extraction and summarisation", multiselect=False)
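The new update_model_choice helper repopulates the model dropdown whenever the provider changes. A sketch of how it might be wired up in the Gradio UI (the component names below are illustrative assumptions, not copied from app.py):

# model_source_radio: a gr.Radio of providers, e.g. "AWS", "Gemini", "Azure/OpenAI", "Local" (assumed name)
# model_choice_drop: the gr.Dropdown that update_model_choice repopulates (assumed name)
model_source_radio.change(
    fn=update_model_choice,      # returns a gr.Dropdown update filtered to the chosen provider
    inputs=model_source_radio,
    outputs=model_choice_drop,
)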
tools/llm_api_call.py
CHANGED

@@ -18,7 +18,7 @@ GradioFileData = gr.FileData
 from tools.prompts import initial_table_prompt, initial_table_system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, force_existing_topics_prompt, allow_new_topics_prompt, force_single_topic_prompt, add_existing_topics_assistant_prefill, initial_table_assistant_prefill, structured_summary_prompt, default_response_reference_format, negative_neutral_positive_sentiment_prompt, negative_or_positive_sentiment_prompt, default_sentiment_prompt
 from tools.helper_functions import read_file, put_columns_in_df, wrap_text, initial_clean, load_in_data_file, load_in_file, create_topic_summary_df_from_reference_table, convert_reference_table_to_pivot_table, get_basic_response_data, clean_column_name, load_in_previous_data_files, create_batch_file_path_details, move_overall_summary_output_files_to_front_page, generate_zero_shot_topics_df
 from tools.llm_funcs import ResponseObject, construct_gemini_generative_model, call_llm_with_markdown_table_checks, create_missing_references_df, calculate_tokens_from_metadata, construct_azure_client, get_model, get_tokenizer, get_assistant_model
-from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, LLM_MAX_NEW_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS, REASONING_SUFFIX,
+from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, LLM_MAX_NEW_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS, REASONING_SUFFIX, AZURE_OPENAI_INFERENCE_ENDPOINT, MAX_ROWS, MAXIMUM_ZERO_SHOT_TOPICS, MAX_SPACES_GPU_RUN_TIME, OUTPUT_DEBUG_FILES
 from tools.aws_functions import connect_to_bedrock_runtime
 from tools.dedup_summaries import sample_reference_table_summaries, summarise_output_topics, deduplicate_topics, overall_summary, process_debug_output_iteration
 from tools.combine_sheets_into_xlsx import collect_output_csvs_and_create_excel_output

@@ -620,6 +620,7 @@ def extract_topics(in_data_file: GradioFileData,
     aws_secret_key_textbox:str='',
     hf_api_key_textbox:str='',
     azure_api_key_textbox:str='',
+    azure_endpoint_textbox:str='',
     max_tokens:int=max_tokens,
     model_name_map:dict=model_name_map,
     existing_logged_content:list=list(),

@@ -635,7 +636,7 @@ def extract_topics(in_data_file: GradioFileData,
     progress=Progress(track_tqdm=False)):

     '''
-    Query an LLM (local, (Gemma/GPT-OSS if local, Gemini, AWS Bedrock or Azure AI Inference) with up to three prompts about a table of open text data. Up to 'batch_size' rows will be queried at a time.
+    Query an LLM (local, (Gemma/GPT-OSS if local, Gemini, AWS Bedrock or Azure/OpenAI AI Inference) with up to three prompts about a table of open text data. Up to 'batch_size' rows will be queried at a time.

     Parameters:
     - in_data_file (gr.File): Gradio file object containing input data

@@ -693,8 +694,8 @@ def extract_topics(in_data_file: GradioFileData,

     tic = time.perf_counter()

-
-
+    client = list()
+    client_config = {}
     final_time = 0.0
     whole_conversation_metadata = list()
     is_error = False

@@ -822,13 +823,13 @@ def extract_topics(in_data_file: GradioFileData,
         # Prepare clients before query
         if "Gemini" in model_source:
             #print("Using Gemini model:", model_choice)
-
-        elif "Azure" in model_source:
-            #print("Using Azure AI Inference model:", model_choice)
+            client, client_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_system_prompt, max_tokens=max_tokens)
+        elif "Azure/OpenAI" in model_source:
+            #print("Using Azure/OpenAI AI Inference model:", model_choice)
             # If provided, set env for downstream calls too
             if azure_api_key_textbox:
                 os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
-
+            client, client_config = construct_azure_client(in_api_key=azure_api_key_textbox, endpoint=azure_endpoint_textbox)
         elif "anthropic.claude" in model_choice:
             #print("Using AWS Bedrock model:", model_choice)
             pass

@@ -949,7 +950,7 @@ def extract_topics(in_data_file: GradioFileData,
             whole_conversation = list()

             # Process requests to large language model
-            responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata,
+            responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(summary_prompt_list, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, client, client_config, model_choice, temperature, reported_batch_no, local_model, tokenizer, bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=add_existing_topics_assistant_prefill, master = True)

             # Return output tables
             topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, new_topic_df, new_reference_df, new_topic_summary_df, master_batch_out_file_part, is_error = write_llm_output_and_logs(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, model_name_map, group_name, produce_structured_summary_radio, first_run=False, output_folder=output_folder)

@@ -1011,12 +1012,12 @@ def extract_topics(in_data_file: GradioFileData,
         # Prepare Gemini models before query
         if model_source == "Gemini":
             print("Using Gemini model:", model_choice)
-
-        elif model_source == "Azure":
-            print("Using Azure AI Inference model:", model_choice)
+            client, client_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_system_prompt, max_tokens=max_tokens)
+        elif model_source == "Azure/OpenAI":
+            print("Using Azure/OpenAI AI Inference model:", model_choice)
             if azure_api_key_textbox:
                 os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
-
+            client, client_config = construct_azure_client(in_api_key=azure_api_key_textbox, endpoint=azure_endpoint_textbox)
         elif model_choice == CHOSEN_LOCAL_MODEL_TYPE:
             pass
             #print("Using local model:", model_choice)

@@ -1038,7 +1039,7 @@ def extract_topics(in_data_file: GradioFileData,

         whole_conversation = list()

-        responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata,
+        responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = call_llm_with_markdown_table_checks(batch_prompts, formatted_system_prompt, conversation_history, whole_conversation, whole_conversation_metadata, client, client_config, model_choice, temperature, reported_batch_no, local_model, tokenizer,bedrock_runtime, model_source, MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)

         topic_table_out_path, reference_table_out_path, topic_summary_df_out_path, topic_table_df, reference_df, new_topic_summary_df, batch_file_path_details, is_error = write_llm_output_and_logs(response_text, whole_conversation, whole_conversation_metadata, file_name, latest_batch_completed, start_row, end_row, model_choice_clean, temperature, log_files_output_paths, existing_reference_df, existing_topic_summary_df, batch_size, chosen_cols, batch_basic_response_df, model_name_map, group_name, produce_structured_summary_radio, first_run=True, output_folder=output_folder)

@@ -1243,6 +1244,7 @@ def wrapper_extract_topics_per_column_value(
     aws_secret_key_textbox:str="",
     hf_api_key_textbox:str="",
     azure_api_key_textbox:str="",
+    azure_endpoint_textbox:str="",
     output_folder: str = OUTPUT_FOLDER,
     existing_logged_content:list=list(),
     additional_instructions_summary_format:str="",

@@ -1296,7 +1298,7 @@ def wrapper_extract_topics_per_column_value(
     :param aws_access_key_textbox: AWS access key for Bedrock.
     :param aws_secret_key_textbox: AWS secret key for Bedrock.
     :param hf_api_key_textbox: Hugging Face API key for local models.
-    :param azure_api_key_textbox: Azure API key for Azure AI Inference.
+    :param azure_api_key_textbox: Azure/OpenAI API key for Azure/OpenAI AI Inference.
     :param output_folder: The folder where output files will be saved.
     :param existing_logged_content: A list of existing logged content.
     :param force_single_topic_prompt: Prompt for forcing a single topic.

@@ -1475,6 +1477,7 @@ def wrapper_extract_topics_per_column_value(
         aws_secret_key_textbox=aws_secret_key_textbox,
         hf_api_key_textbox=hf_api_key_textbox,
         azure_api_key_textbox=azure_api_key_textbox,
+        azure_endpoint_textbox=azure_endpoint_textbox,
         max_tokens=max_tokens,
         model_name_map=model_name_map,
         max_time_for_loop=max_time_for_loop,

@@ -1736,6 +1739,7 @@ def all_in_one_pipeline(
     aws_secret_key_text: str,
     hf_api_key_text: str,
     azure_api_key_text: str,
+    azure_endpoint_text: str,
     output_folder: str = OUTPUT_FOLDER,
     merge_sentiment: str = "No",
     merge_general_topics: str = "Yes",

@@ -1790,7 +1794,7 @@ def all_in_one_pipeline(
         aws_access_key_text (str): AWS access key.
         aws_secret_key_text (str): AWS secret key.
         hf_api_key_text (str): Hugging Face API key.
-        azure_api_key_text (str): Azure API key.
+        azure_api_key_text (str): Azure/OpenAI API key.
         output_folder (str, optional): Folder to save output files. Defaults to OUTPUT_FOLDER.
         merge_sentiment (str, optional): Whether to merge sentiment. Defaults to "No".
         merge_general_topics (str, optional): Whether to merge general topics. Defaults to "Yes".

@@ -1884,6 +1888,7 @@ def all_in_one_pipeline(
         aws_secret_key_textbox=aws_secret_key_text,
         hf_api_key_textbox=hf_api_key_text,
         azure_api_key_textbox=azure_api_key_text,
+        azure_endpoint_textbox=azure_endpoint_text,
         output_folder=output_folder,
         existing_logged_content=existing_logged_content,
         model_name_map=model_name_map_state,
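The pattern these hunks introduce is: build the provider client (and its config) once per batch, then hand both into call_llm_with_markdown_table_checks alongside the Bedrock runtime and any local model. A condensed sketch of that flow, assuming the same surrounding variable names as extract_topics (they are mirrored here for illustration, not copied verbatim):

# Minimal sketch of the client-preparation pattern added in extract_topics
client, client_config = list(), {}

if "Gemini" in model_source:
    client, client_config = construct_gemini_generative_model(
        in_api_key=in_api_key, temperature=temperature, model_choice=model_choice,
        system_prompt=formatted_system_prompt, max_tokens=max_tokens)
elif "Azure/OpenAI" in model_source:
    if azure_api_key_textbox:
        os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
    client, client_config = construct_azure_client(
        in_api_key=azure_api_key_textbox, endpoint=azure_endpoint_textbox)
# AWS Bedrock and local models need no client here; bedrock_runtime / local_model are passed separately

responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = \
    call_llm_with_markdown_table_checks(
        batch_prompts, formatted_system_prompt, conversation_history, whole_conversation,
        whole_conversation_metadata, client, client_config, model_choice, temperature,
        reported_batch_no, local_model, tokenizer, bedrock_runtime, model_source,
        MAX_OUTPUT_VALIDATION_ATTEMPTS, assistant_prefill=initial_table_assistant_prefill)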
tools/llm_funcs.py
CHANGED

@@ -10,10 +10,7 @@ from typing import List, Tuple, TypeVar
 from google import genai as ai
 from google.genai import types
 from gradio import Progress
-
-from azure.ai.inference import ChatCompletionsClient
-from azure.core.credentials import AzureKeyCredential
-from azure.ai.inference.models import SystemMessage, UserMessage
+from openai import OpenAI

 model_type = None # global variable setup
 full_text = "" # Define dummy source text (full text) just to enable highlight function to load

@@ -674,27 +671,31 @@ def construct_gemini_generative_model(in_api_key: str, temperature: float, model

 def construct_azure_client(in_api_key: str, endpoint: str) -> Tuple[object, dict]:
     """
-    Constructs
+    Constructs an OpenAI client for Azure/OpenAI AI Inference.
     """
     try:
         key = None
         if in_api_key:
             key = in_api_key
-        elif os.environ.get("
-            key = os.environ["
-        elif os.environ.get("AZURE_API_KEY"):
-            key = os.environ["AZURE_API_KEY"]
+        elif os.environ.get("AZURE_OPENAI_API_KEY"):
+            key = os.environ["AZURE_OPENAI_API_KEY"]
         if not key:
-            raise Warning("No Azure API key found.")
+            raise Warning("No Azure/OpenAI API key found.")

         if not endpoint:
-            endpoint = os.environ.get("
+            endpoint = os.environ.get("AZURE_OPENAI_INFERENCE_ENDPOINT", "")
         if not endpoint:
-            raise Warning("No Azure inference endpoint found.")
-
-
+            raise Warning("No Azure/OpenAI inference endpoint found.")
+
+        # Use the provided endpoint instead of hardcoded value
+        client = OpenAI(
+            api_key=key,
+            base_url=f"{endpoint}",
+        )
+
+        return client, dict()
     except Exception as e:
-        print("Error constructing Azure
+        print("Error constructing Azure/OpenAI client:", e)
         raise

 def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice:str, bedrock_runtime:boto3.Session.client, assistant_prefill:str="") -> ResponseObject:

@@ -756,7 +757,15 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
     )

     output_message = api_response['output']['message']
-
+
+    if 'reasoningContent' in output_message['content'][0]:
+        # Extract the reasoning text
+        reasoning_text = output_message['content'][0]['reasoningContent']['reasoningText']['text']
+
+        # Extract the output text
+        text = assistant_prefill + output_message['content'][1]['text']
+    else:
+        text = assistant_prefill + output_message['content'][0]['text']

     # The usage statistics are neatly provided in the 'usage' key.
     usage = api_response['usage']

@@ -803,9 +812,6 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": prompt}
     ]
-    #print("Conversation:", conversation)
-    #import pprint
-    #pprint.pprint(conversation)

     # 2. Apply the chat template
     # This function formats the conversation into the exact string Gemma 3 expects.

@@ -820,9 +826,6 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP
         ).to("cuda")
     except Exception as e:
         print("Error applying chat template:", e)
-        print("Conversation type:", type(conversation))
-        for turn in conversation:
-            print("Turn type:", type(turn), "Content type:", type(turn.get("content")))
         raise

     # Map LlamaCPP parameters to transformers parameters

@@ -850,7 +853,7 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP

     # Use speculative decoding if assistant model is available
     if speculative_decoding and assistant_model is not None:
-        print("Using speculative decoding with assistant model")
+        #print("Using speculative decoding with assistant model")
         outputs = model.generate(
             input_ids,
             assistant_model=assistant_model,

@@ -858,7 +861,7 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP
             streamer = streamer
         )
     else:
-        print("Generating without speculative decoding")
+        #print("Generating without speculative decoding")
         outputs = model.generate(
             input_ids,
             **generation_kwargs,

@@ -868,11 +871,9 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP
     end_time = time.time()

     # --- Decode and Display Results ---
-    #generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    # To get only the model's reply, we can decode just the newly generated tokens
     new_tokens = outputs[0][input_ids.shape[-1]:]
     assistant_reply = tokenizer.decode(new_tokens, skip_special_tokens=True)
-
+

     num_input_tokens = input_ids.shape[-1] # This gets the sequence length (number of tokens)
     num_generated_tokens = len(new_tokens)

@@ -887,12 +888,32 @@ def call_transformers_model(prompt: str, system_prompt: str, gen_config: LlamaCP
     return assistant_reply, num_input_tokens, num_generated_tokens

 # Function to send a request and update history
-def send_request(prompt: str, conversation_history: List[dict],
-    """
-
-
-
-
+def send_request(prompt: str, conversation_history: List[dict], client: ai.Client | OpenAI, config: types.GenerateContentConfig, model_choice: str, system_prompt: str, temperature: float, bedrock_runtime:boto3.Session.client, model_source:str, local_model= list(), tokenizer=None, assistant_model=None, assistant_prefill = "", progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
+    """Sends a request to a language model and manages the conversation history.
+
+    This function constructs the full prompt by appending the new user prompt to the conversation history,
+    generates a response from the model, and updates the conversation history with the new prompt and response.
+    It handles different model sources (Gemini, AWS, Local) and includes retry logic for API calls.
+
+    Args:
+        prompt (str): The user's input prompt to be sent to the model.
+        conversation_history (List[dict]): A list of dictionaries representing the ongoing conversation.
+            Each dictionary should have 'role' and 'parts' keys.
+        client (ai.Client): The API client object for the chosen model (e.g., Gemini `ai.Client`, or Azure/OpenAI `OpenAI`).
+        config (types.GenerateContentConfig): Configuration settings for content generation (e.g., Gemini `types.GenerateContentConfig`).
+        model_choice (str): The specific model identifier to use (e.g., "gemini-pro", "claude-v2").
+        system_prompt (str): An optional system-level instruction or context for the model.
+        temperature (float): Controls the randomness of the model's output, with higher values leading to more diverse responses.
+        bedrock_runtime (boto3.Session.client): The boto3 Bedrock runtime client object for AWS models.
+        model_source (str): Indicates the source/provider of the model (e.g., "Gemini", "AWS", "Local").
+        local_model (list, optional): A list containing the local model and its tokenizer (if `model_source` is "Local"). Defaults to [].
+        tokenizer (object, optional): The tokenizer object for local models. Defaults to None.
+        assistant_model (object, optional): An optional assistant model used for speculative decoding with local models. Defaults to None.
+        assistant_prefill (str, optional): A string to pre-fill the assistant's response, useful for certain models like Claude. Defaults to "".
+        progress (Progress, optional): A progress object for tracking the operation, typically from `tqdm`. Defaults to Progress(track_tqdm=True).
+
+    Returns:
+        Tuple[str, List[dict]]: A tuple containing the model's response text and the updated conversation history.
     """
     # Constructing the full prompt from the conversation history
     full_prompt = "Conversation history:\n"

@@ -920,7 +941,7 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
         try:
             print("Calling Gemini model, attempt", i + 1)

-            response =
+            response = client.models.generate_content(model=model_choice, contents=full_prompt, config=config)

             #print("Successful call to Gemini model.")
             break

@@ -948,18 +969,29 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a

         if i == number_of_api_retry_attempts:
             return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens
-    elif "Azure" in model_source:
+    elif "Azure/OpenAI" in model_source:
         for i in progress_bar:
             try:
-                print("Calling Azure
-
-
-
-
-
-
+                print("Calling Azure/OpenAI inference model, attempt", i + 1)
+
+                messages=[
+                    {
+                        "role": "system",
+                        "content": system_prompt,
+                    },
+                    {
+                        "role": "user",
+                        "content": prompt,
+                    },
+                ]
+
+                response_raw = client.chat.completions.create(
+                    messages=messages,
+                    model=model_choice,
+                    temperature=temperature,
+                    max_completion_tokens=max_tokens
                 )
+
                 response_text = response_raw.choices[0].message.content
                 usage = getattr(response_raw, "usage", None)
                 input_tokens = 0

@@ -973,7 +1005,7 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
                 )
                 break
             except Exception as e:
-                print("Call to Azure model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
+                print("Call to Azure/OpenAI model failed:", e, " Waiting for ", str(timeout_wait), "seconds and trying again.")
                 time.sleep(timeout_wait)
                 if i == number_of_api_retry_attempts:
                     return ResponseObject(text="", usage_metadata={'RequestId':"FAILED"}), conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens

@@ -993,7 +1025,6 @@ def send_request(prompt: str, conversation_history: List[dict], google_client: a
                 response, num_transformer_input_tokens, num_transformer_generated_tokens = call_transformers_model(prompt, system_prompt, gen_config, model=local_model, tokenizer=tokenizer, assistant_model=assistant_model)
                 response_text = response

-                #print("Successful call to local model.")
                 break
             except Exception as e:
                 # If fails, try again after X seconds in case there is a throttle limit

@@ -1035,7 +1066,7 @@ system_prompt: str,
 conversation_history: List[dict],
 whole_conversation: List[str],
 whole_conversation_metadata: List[str],
-
+client: ai.Client | OpenAI,
 config: types.GenerateContentConfig,
 model_choice: str,
 temperature: float,

@@ -1056,7 +1087,7 @@ assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List
     conversation_history (List[dict]): The history of the conversation.
     whole_conversation (List[str]): The complete conversation including prompts and responses.
     whole_conversation_metadata (List[str]): Metadata about the whole conversation.
-
+    client (object): The client to use for processing the prompts, from either Gemini or OpenAI client.
     config (dict): Configuration for the model.
     model_choice (str): The choice of model to use.
     temperature (float): The temperature parameter for the model.

@@ -1077,22 +1108,15 @@ assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List

     for prompt in prompts:

-        response, conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens = send_request(prompt, conversation_history,
+        response, conversation_history, response_text, num_transformer_input_tokens, num_transformer_generated_tokens = send_request(prompt, conversation_history, client=client, config=config, model_choice=model_choice, system_prompt=system_prompt, temperature=temperature, local_model=local_model, tokenizer=tokenizer, assistant_model=assistant_model, assistant_prefill=assistant_prefill, bedrock_runtime=bedrock_runtime, model_source=model_source)

         responses.append(response)
         whole_conversation.append(system_prompt)
         whole_conversation.append(prompt)
         whole_conversation.append(response_text)

-        # Create conversation metadata
-        # if master == False:
-        #     whole_conversation_metadata.append(f"Batch {batch_no}:")
-        # else:
-        #     #whole_conversation_metadata.append(f"Query summary metadata:")
-
         whole_conversation_metadata.append(f"Batch {batch_no}:")

-        # if not isinstance(response, str):
         try:
             if "AWS" in model_source:
                 output_tokens = response.usage_metadata.get('outputTokens', 0)

@@ -1102,7 +1126,7 @@ assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List
                 output_tokens = response.usage_metadata.candidates_token_count
                 input_tokens = response.usage_metadata.prompt_token_count

-            elif "Azure" in model_source:
+            elif "Azure/OpenAI" in model_source:
                 input_tokens = response.usage_metadata.get('inputTokens', 0)
                 output_tokens = response.usage_metadata.get('outputTokens', 0)

@@ -1123,9 +1147,6 @@ assistant_prefill="") -> Tuple[List[ResponseObject], List[dict], List[str], List

         except KeyError as e:
             print(f"Key error: {e} - Check the structure of response.usage_metadata")
-        # else:
-        #     print("Response is a string object.")
-        #     whole_conversation_metadata.append("Length prompt: " + str(len(prompt)) + ". Length response: " + str(len(response)))

     return responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text

@@ -1134,8 +1155,8 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
     conversation_history: List[dict],
     whole_conversation: List[str],
     whole_conversation_metadata: List[str],
-
-
+    client: ai.Client | OpenAI,
+    client_config: types.GenerateContentConfig,
     model_choice: str,
     temperature: float,
     reported_batch_no: int,

@@ -1157,8 +1178,8 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
     - conversation_history (List[dict]): The history of the conversation.
     - whole_conversation (List[str]): The complete conversation including prompts and responses.
     - whole_conversation_metadata (List[str]): Metadata about the whole conversation.
-    -
-    -
+    - client (ai.Client | OpenAI): The client object for running Gemini or Azure/OpenAI API calls.
+    - client_config (types.GenerateContentConfig): Configuration for the model.
     - model_choice (str): The choice of model to use.
     - temperature (float): The temperature parameter for the model.
     - reported_batch_no (int): The reported batch number.

@@ -1179,13 +1200,13 @@ def call_llm_with_markdown_table_checks(batch_prompts: List[str],
     call_temperature = temperature # This is correct now with the fixed parameter name

     # Update Gemini config with the new temperature settings
-
+    client_config = types.GenerateContentConfig(temperature=call_temperature, max_output_tokens=max_tokens, seed=random_seed)

     for attempt in range(MAX_OUTPUT_VALIDATION_ATTEMPTS):
         # Process requests to large language model
         responses, conversation_history, whole_conversation, whole_conversation_metadata, response_text = process_requests(
             batch_prompts, system_prompt, conversation_history, whole_conversation,
-            whole_conversation_metadata,
+            whole_conversation_metadata, client, client_config, model_choice,
             call_temperature, bedrock_runtime, model_source, reported_batch_no, local_model, tokenizer=tokenizer, master=master, assistant_prefill=assistant_prefill
         )
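Taken together, the llm_funcs.py changes replace the azure-ai-inference SDK with the openai package and treat Azure as just another OpenAI-compatible endpoint. A minimal, self-contained sketch of the same pattern, mirroring construct_azure_client and the Azure/OpenAI branch of send_request (the endpoint URL and model name below are placeholders, not values from the repository):

import os
from openai import OpenAI

# Mirrors construct_azure_client: key and endpoint fall back to environment variables
key = os.environ.get("AZURE_OPENAI_API_KEY", "")
endpoint = os.environ.get("AZURE_OPENAI_INFERENCE_ENDPOINT",
                          "https://example-resource.openai.azure.com/openai/v1/")  # placeholder URL

client = OpenAI(api_key=key, base_url=endpoint)

# Mirrors the Azure/OpenAI branch of send_request: system + user messages, capped completion length
response_raw = client.chat.completions.create(
    model="gpt-4o-mini",                      # placeholder deployment/model name
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarise the key topics in this consultation response."},
    ],
    temperature=0.0,
    max_completion_tokens=512,
)

response_text = response_raw.choices[0].message.content
usage = getattr(response_raw, "usage", None)   # token counts, read the same way as in send_request
print(response_text)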