seanpedrickcase committed on
Commit d6ff533 · 1 Parent(s): 1b35214

Added an 'all-in-one' pipeline function. Corrected local model loading when the model is not loaded initially. Added environment variables for the maximum number of data rows and topics.

app.py CHANGED

@@ -5,7 +5,7 @@ import pandas as pd
from datetime import datetime
from tools.helper_functions import put_columns_in_df, get_connection_params, get_or_create_env_var, reveal_feedback_buttons, wipe_logs, view_table, empty_output_vars_extract_topics, empty_output_vars_summarise, load_in_previous_reference_file, join_cols_onto_reference_df, load_in_previous_data_files, load_in_data_file, load_in_default_cost_codes, reset_base_dataframe, update_cost_code_dataframe_from_dropdown_select, df_select_callback_cost, enforce_cost_codes, _get_env_list, move_overall_summary_output_files_to_front_page
from tools.aws_functions import upload_file_to_s3, download_file_from_s3
- from tools.llm_api_call import modify_existing_output_tables, wrapper_extract_topics_per_column_value
+ from tools.llm_api_call import modify_existing_output_tables, wrapper_extract_topics_per_column_value, all_in_one_pipeline
from tools.dedup_summaries import sample_reference_table_summaries, summarise_output_topics, deduplicate_topics, overall_summary
from tools.combine_sheets_into_xlsx import collect_output_csvs_and_create_excel_output
from tools.custom_csvlogger import CSVLogger_custom

@@ -66,9 +66,9 @@ with app:
# STATE VARIABLES
###

- text_output_file_list_state = gr.Dropdown([], allow_custom_value=True, visible=False, label="text_output_file_list_state")
- text_output_modify_file_list_state = gr.Dropdown([], allow_custom_value=True, visible=False, label="text_output_modify_file_list_state")
- log_files_output_list_state = gr.Dropdown([], allow_custom_value=True, visible=False, label="log_files_output_list_state")
+ text_output_file_list_state = gr.Dropdown(list(), allow_custom_value=True, visible=False, label="text_output_file_list_state")
+ text_output_modify_file_list_state = gr.Dropdown(list(), allow_custom_value=True, visible=False, label="text_output_modify_file_list_state")
+ log_files_output_list_state = gr.Dropdown(list(), allow_custom_value=True, visible=False, label="log_files_output_list_state")
first_loop_state = gr.Checkbox(True, visible=False)
second_loop_state = gr.Checkbox(False, visible=False)
modified_unique_table_change_bool = gr.Checkbox(True, visible=False) # This boolean is used to flag whether a file upload should change just the modified unique table object on the second tab

@@ -321,8 +321,6 @@ with app:
number_of_prompts = gr.Number(value=1, label="Number of prompts to send to LLM in sequence", minimum=1, maximum=3, visible=False)
system_prompt_textbox = gr.Textbox(label="Initial system prompt", lines = 4, value = system_prompt)
initial_table_prompt_textbox = gr.Textbox(label = "Initial topics prompt", lines = 8, value = initial_table_prompt)
- prompt_2_textbox = gr.Textbox(label = "Prompt 2", lines = 8, value = prompt2, visible=False)
- prompt_3_textbox = gr.Textbox(label = "Prompt 3", lines = 8, value = prompt3, visible=False)
add_to_existing_topics_system_prompt_textbox = gr.Textbox(label="Additional topics system prompt", lines = 4, value = add_existing_topics_system_prompt)
add_to_existing_topics_prompt_textbox = gr.Textbox(label = "Additional topics prompt", lines = 8, value = add_existing_topics_prompt)
verify_titles_system_prompt_textbox = gr.Textbox(label="Verify descriptions system prompt", lines = 4, value = verify_titles_system_prompt, visible=False)

@@ -392,8 +390,6 @@ with app:
latest_batch_completed,
estimated_time_taken_number,
initial_table_prompt_textbox,
- prompt_2_textbox,
- prompt_3_textbox,
system_prompt_textbox,
add_to_existing_topics_system_prompt_textbox,
add_to_existing_topics_prompt_textbox,

@@ -433,7 +429,7 @@ with app:
output_messages_textbox],
api_name="extract_topics", show_progress_on=output_messages_textbox).\
success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False, api_name="usage_logs").\
- success(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_state, master_unique_topics_df_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state], outputs=[topic_extraction_output_files_xlsx, summary_xlsx_output_files_list])
+ then(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_state, master_unique_topics_df_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state], outputs=[topic_extraction_output_files_xlsx, summary_xlsx_output_files_list])

###
# DEDUPLICATION AND SUMMARISATION FUNCTIONS

@@ -455,14 +451,14 @@ with app:
success(sample_reference_table_summaries, inputs=[master_reference_df_state, random_seed], outputs=[summary_reference_table_sample_state, summarised_references_markdown], api_name="sample_summaries").\
success(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state, context_textbox, aws_access_key_textbox, aws_secret_key_textbox, model_name_map_state, hf_api_key_textbox], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, summarised_output_markdown, log_files_output, overall_summarisation_input_files, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, output_messages_textbox], api_name="summarise_topics", show_progress_on=[output_messages_textbox, summary_output_files]).\
success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False).\
- success(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_revised_summaries_state, master_unique_topics_df_revised_summaries_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state], outputs=[summary_output_files_xlsx, summary_xlsx_output_files_list])
+ then(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_revised_summaries_state, master_unique_topics_df_revised_summaries_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state], outputs=[summary_output_files_xlsx, summary_xlsx_output_files_list])

# SUMMARISE WHOLE TABLE PAGE
overall_summarise_previous_data_btn.click(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
success(load_in_previous_data_files, inputs=[overall_summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
success(overall_summary, inputs=[master_unique_topics_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, output_folder_state, in_colnames, context_textbox, aws_access_key_textbox, aws_secret_key_textbox, model_name_map_state, hf_api_key_textbox], outputs=[overall_summary_output_files, overall_summarised_output_markdown, summarised_output_df, conversation_metadata_textbox, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, output_messages_textbox], scroll_to_output=True, api_name="overall_summary", show_progress_on=[output_messages_textbox, overall_summary_output_files]).\
success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False).\
- success(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_state, master_unique_topics_df_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state], outputs=[overall_summary_output_files_xlsx, summary_xlsx_output_files_list])
+ then(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_state, master_unique_topics_df_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state], outputs=[overall_summary_output_files_xlsx, summary_xlsx_output_files_list])


# All in one button

@@ -471,8 +467,9 @@ with app:
success(fn= enforce_cost_codes, inputs=[enforce_cost_code_textbox, cost_code_choice_drop, cost_code_dataframe_base]).\
success(load_in_data_file,
inputs = [in_data_files, in_colnames, batch_size_number, in_excel_sheets], outputs = [file_data_state, working_data_file_name_textbox, total_number_of_batches], api_name="load_data").\
- success(fn=wrapper_extract_topics_per_column_value,
- inputs=[in_group_col,
+ success(fn=all_in_one_pipeline,
+ inputs=[
+ in_group_col,
in_data_files,
file_data_state,
master_topic_df_state,

@@ -489,10 +486,8 @@ with app:
first_loop_state,
conversation_metadata_textbox,
latest_batch_completed,
- estimated_time_taken_number,
+ estimated_time_taken_number,
initial_table_prompt_textbox,
- prompt_2_textbox,
- prompt_3_textbox,
system_prompt_textbox,
add_to_existing_topics_system_prompt_textbox,
add_to_existing_topics_prompt_textbox,

@@ -508,8 +503,18 @@ with app:
aws_secret_key_textbox,
hf_api_key_textbox,
azure_api_key_textbox,
- output_folder_state],
- outputs=[display_topic_table_markdown,
+ output_folder_state,
+ merge_sentiment_drop,
+ merge_general_topics_drop,
+ deduplicate_score_threshold,
+ summarise_format_radio,
+ random_seed,
+ log_files_output_list_state,
+ model_name_map_state,
+ usage_logs_state
+ ],
+ outputs=[
+ display_topic_table_markdown,
master_topic_df_state,
master_unique_topics_df_state,
master_reference_df_state,

@@ -529,19 +534,27 @@ with app:
input_tokens_num,
output_tokens_num,
number_of_calls_num,
- output_messages_textbox], show_progress_on=output_messages_textbox).\
- success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False).\
- success(load_in_previous_data_files, inputs=[deduplication_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
- success(deduplicate_topics, inputs=[master_reference_df_state, master_unique_topics_df_state, working_data_file_name_textbox, unique_topics_table_file_name_textbox, in_excel_sheets, merge_sentiment_drop, merge_general_topics_drop, deduplicate_score_threshold, in_data_files, in_colnames, output_folder_state], outputs=[master_reference_df_state, master_unique_topics_df_state, summarisation_input_files, log_files_output, summarised_output_markdown]).\
- success(load_in_previous_data_files, inputs=[summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
- success(sample_reference_table_summaries, inputs=[master_reference_df_state, random_seed], outputs=[summary_reference_table_sample_state, summarised_references_markdown]).\
- success(summarise_output_topics, inputs=[summary_reference_table_sample_state, master_unique_topics_df_state, master_reference_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, in_data_files, in_excel_sheets, in_colnames, log_files_output_list_state, summarise_format_radio, output_folder_state, context_textbox, aws_access_key_textbox, aws_secret_key_textbox, model_name_map_state, hf_api_key_textbox], outputs=[summary_reference_table_sample_state, master_unique_topics_df_revised_summaries_state, master_reference_df_revised_summaries_state, summary_output_files, summarised_outputs_list, latest_summary_completed_num, conversation_metadata_textbox, display_topic_table_markdown, log_files_output, overall_summarisation_input_files, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, output_messages_textbox], show_progress_on=[output_messages_textbox, summary_output_files]).\
- success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False).\
- success(load_in_previous_data_files, inputs=[overall_summarisation_input_files], outputs=[master_reference_df_state, master_unique_topics_df_state, latest_batch_completed_no_loop, deduplication_input_files_status, working_data_file_name_textbox, unique_topics_table_file_name_textbox]).\
- success(overall_summary, inputs=[master_unique_topics_df_state, model_choice, google_api_key_textbox, temperature_slide, working_data_file_name_textbox, output_folder_state, in_colnames, context_textbox, aws_access_key_textbox, aws_secret_key_textbox, model_name_map_state, hf_api_key_textbox], outputs=[overall_summary_output_files, overall_summarised_output_markdown, summarised_output_df, conversation_metadata_textbox, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, output_messages_textbox], show_progress_on=[output_messages_textbox, overall_summary_output_files]).\
- success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False).\
- success(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_state, master_unique_topics_df_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state], outputs=[overall_summary_output_files_xlsx, summary_xlsx_output_files_list]).\
- success(move_overall_summary_output_files_to_front_page, inputs=[summary_xlsx_output_files_list], outputs=[topic_extraction_output_files_xlsx])
+ output_messages_textbox,
+ summary_reference_table_sample_state,
+ summarised_references_markdown,
+ master_unique_topics_df_revised_summaries_state,
+ master_reference_df_revised_summaries_state,
+ summary_output_files,
+ summarised_outputs_list,
+ latest_summary_completed_num,
+ overall_summarisation_input_files,
+ overall_summary_output_files,
+ overall_summarised_output_markdown,
+ summarised_output_df#,
+ # overall_summary_output_files_xlsx,
+ # summary_xlsx_output_files_list,
+ # topic_extraction_output_files_xlsx
+ ],
+ show_progress_on=[output_messages_textbox], api_name="all_in_one_pipeline"
+ ).\
+ success(lambda *args: usage_callback.flag(list(args), save_to_csv=SAVE_LOGS_TO_CSV, save_to_dynamodb=SAVE_LOGS_TO_DYNAMODB, dynamodb_table_name=USAGE_LOG_DYNAMODB_TABLE_NAME, dynamodb_headers=DYNAMODB_USAGE_LOG_HEADERS, replacement_headers=CSV_USAGE_LOG_HEADERS), [session_hash_textbox, original_data_file_name_textbox, in_colnames, model_choice, conversation_metadata_textbox_placeholder, input_tokens_num, output_tokens_num, number_of_calls_num, estimated_time_taken_number, cost_code_choice_drop], None, preprocess=False).\
+ then(collect_output_csvs_and_create_excel_output, inputs=[in_data_files, in_colnames, original_data_file_name_textbox, in_group_col, model_choice, master_reference_df_state, master_unique_topics_df_state, summarised_output_df, missing_df_state, in_excel_sheets, usage_logs_state, model_name_map_state, output_folder_state], outputs=[overall_summary_output_files_xlsx, summary_xlsx_output_files_list]).\
+ success(move_overall_summary_output_files_to_front_page, inputs=[summary_xlsx_output_files_list], outputs=[topic_extraction_output_files_xlsx])

###
# CONTINUE PREVIOUS TOPIC EXTRACTION PAGE

@@ -563,7 +576,7 @@ with app:
success(load_in_data_file,
inputs = [verify_in_data_files, verify_in_colnames, batch_size_number, verify_in_excel_sheets], outputs = [file_data_state, working_data_file_name_textbox, total_number_of_batches], api_name="verify_load_data").\
success(fn=verify_titles,
- inputs=[verify_in_data_files, file_data_state, master_topic_df_state, master_reference_df_state, master_unique_topics_df_state, display_topic_table_markdown, original_data_file_name_textbox, total_number_of_batches, verify_in_api_key, temperature_slide, verify_in_colnames, verify_model_choice, candidate_topics, latest_batch_completed, display_topic_table_markdown, text_output_file_list_state, log_files_output_list_state, first_loop_state, conversation_metadata_textbox, verify_titles_prompt_textbox, prompt_2_textbox, prompt_3_textbox, verify_titles_system_prompt_textbox, verify_titles_system_prompt_textbox, verify_titles_prompt_textbox, number_of_prompts, batch_size_number, context_textbox, estimated_time_taken_number, sentiment_checkbox, force_zero_shot_radio, produce_structures_summary_radio, aws_access_key_textbox, aws_secret_key_textbox, in_excel_sheets, output_folder_state],
+ inputs=[verify_in_data_files, file_data_state, master_topic_df_state, master_reference_df_state, master_unique_topics_df_state, display_topic_table_markdown, original_data_file_name_textbox, total_number_of_batches, verify_in_api_key, temperature_slide, verify_in_colnames, verify_model_choice, candidate_topics, latest_batch_completed, display_topic_table_markdown, text_output_file_list_state, log_files_output_list_state, first_loop_state, conversation_metadata_textbox, verify_titles_prompt_textbox, verify_titles_system_prompt_textbox, verify_titles_system_prompt_textbox, verify_titles_prompt_textbox, number_of_prompts, batch_size_number, context_textbox, estimated_time_taken_number, sentiment_checkbox, force_zero_shot_radio, produce_structures_summary_radio, aws_access_key_textbox, aws_secret_key_textbox, in_excel_sheets, output_folder_state],
outputs=[verify_display_topic_table_markdown, master_topic_df_state, master_unique_topics_df_state, master_reference_df_state, verify_titles_file_output, text_output_file_list_state, latest_batch_completed, log_files_output, log_files_output_list_state, conversation_metadata_textbox, estimated_time_taken_number, deduplication_input_files, summarisation_input_files, modifiable_unique_topics_df_state, verify_modification_input_files_placeholder], api_name="verify_descriptions")

###
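Review note: the recurring wiring change in app.py above swaps `success(...)` for `then(...)` on the final Excel-collation step of each event chain. In Gradio, `.success()` listeners only fire when the previous event finished without an error, while `.then()` fires unconditionally, so `collect_output_csvs_and_create_excel_output` now runs even if an upstream step such as the usage-logging callback fails. A minimal sketch of the difference (the component and function names here are illustrative, not taken from app.py):

```python
import gradio as gr

def flaky_step():
    # Stand-in for an upstream step that can raise (e.g. a logging callback)
    raise RuntimeError("upstream failure")

def collate_outputs():
    # Stand-in for collect_output_csvs_and_create_excel_output
    return "workbook assembled"

with gr.Blocks() as demo:
    run_btn = gr.Button("Run")
    status = gr.Textbox(label="Status")
    workbook = gr.Textbox(label="Workbook")
    # .success(collate_outputs, ...) would be skipped after the exception;
    # .then(collate_outputs, ...) still runs, which matches the behaviour
    # this commit chooses for assembling the output workbook.
    run_btn.click(flaky_step, None, status).then(collate_outputs, None, workbook)

demo.launch()
```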
tools/config.py CHANGED

@@ -294,7 +294,7 @@ GEMMA3_MODEL_FILE = get_or_create_env_var("GEMMA3_MODEL_FILE", "gemma-3-270m-it-
GEMMA3_MODEL_FOLDER = get_or_create_env_var("GEMMA3_MODEL_FOLDER", "model/gemma")

GEMMA3_4B_REPO_ID = get_or_create_env_var("GEMMA3_4B_REPO_ID", "unsloth/gemma-3-4b-it-qat-GGUF")
- GEMMA3_4B_REPO_TRANSFORMERS_ID = get_or_create_env_var("GEMMA3_4B_REPO_TRANSFORMERS_ID", "google/gemma-3-4b-it")
+ GEMMA3_4B_REPO_TRANSFORMERS_ID = get_or_create_env_var("GEMMA3_4B_REPO_TRANSFORMERS_ID", "unsloth/gemma-3-4b-it-qat-unsloth-bnb-4bit") # "google/gemma-3-4b-it"
if USE_LLAMA_CPP == "False":
GEMMA3_4B_REPO_ID = GEMMA3_4B_REPO_TRANSFORMERS_ID

@@ -364,7 +364,14 @@ COMPILE_MODE = get_or_create_env_var('COMPILE_MODE', 'reduce-overhead') # altern
MODEL_DTYPE = get_or_create_env_var('MODEL_DTYPE', 'bfloat16') # alternatively 'bfloat16'
INT8_WITH_OFFLOAD_TO_CPU = get_or_create_env_var('INT8_WITH_OFFLOAD_TO_CPU', 'False') # Whether to offload to CPU

+ ###
+ # Dataset variables
+ ###
+
+ MAX_ROWS = int(get_or_create_env_var('MAX_ROWS', '5000'))
MAX_GROUPS = int(get_or_create_env_var('MAX_GROUPS', '99'))
+ MAXIMUM_ZERO_SHOT_TOPICS = int(get_or_create_env_var('MAXIMUM_ZERO_SHOT_TOPICS', '120'))
+ MAX_SPACES_GPU_RUN_TIME = int(get_or_create_env_var('MAX_SPACES_GPU_RUN_TIME', '240'))

###
# Gradio app variables

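The new `MAX_ROWS`, `MAXIMUM_ZERO_SHOT_TOPICS` and `MAX_SPACES_GPU_RUN_TIME` limits follow the same `get_or_create_env_var` pattern as the rest of tools/config.py (the helper is imported from tools/helper_functions.py in app.py), so they can be overridden per deployment without code changes. A sketch of how such a helper typically behaves; the project's actual implementation may differ:

```python
import os

def get_or_create_env_var(var_name: str, default_value: str) -> str:
    # Return the environment variable if already set; otherwise register
    # the default so that later reads see a consistent value.
    value = os.environ.get(var_name)
    if value is None:
        os.environ[var_name] = default_value
        value = default_value
    return value

# The new limits can then be overridden at deploy time, e.g.
#   MAX_ROWS=20000 MAX_SPACES_GPU_RUN_TIME=600 python app.py
MAX_ROWS = int(get_or_create_env_var('MAX_ROWS', '5000'))
MAXIMUM_ZERO_SHOT_TOPICS = int(get_or_create_env_var('MAXIMUM_ZERO_SHOT_TOPICS', '120'))
MAX_SPACES_GPU_RUN_TIME = int(get_or_create_env_var('MAX_SPACES_GPU_RUN_TIME', '240'))
```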
tools/dedup_summaries.py CHANGED

@@ -14,7 +14,7 @@ from tools.prompts import summarise_topic_descriptions_prompt, summarise_topic_d
from tools.llm_funcs import construct_gemini_generative_model, process_requests, ResponseObject, load_model, calculate_tokens_from_metadata, construct_azure_client, get_model, get_tokenizer, get_assistant_model
from tools.helper_functions import create_topic_summary_df_from_reference_table, load_in_data_file, get_basic_response_data, convert_reference_table_to_pivot_table, wrap_text, clean_column_name, get_file_name_no_ext, create_batch_file_path_details
from tools.aws_functions import connect_to_bedrock_runtime
- from tools.config import OUTPUT_FOLDER, RUN_LOCAL_MODEL, MAX_COMMENT_CHARS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, REASONING_SUFFIX, AZURE_INFERENCE_ENDPOINT
+ from tools.config import OUTPUT_FOLDER, RUN_LOCAL_MODEL, MAX_COMMENT_CHARS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, REASONING_SUFFIX, AZURE_INFERENCE_ENDPOINT, MAX_SPACES_GPU_RUN_TIME

max_tokens = MAX_TOKENS
timeout_wait = TIMEOUT_WAIT

@@ -464,7 +464,7 @@ def summarise_output_topics_query(model_choice:str, in_api_key:str, temperature:

return response_text, conversation_history, whole_conversation_metadata

- @spaces.GPU(duration=300)
+ @spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME)
def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
topic_summary_df:pd.DataFrame,
reference_table_df:pd.DataFrame,

@@ -487,9 +487,9 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
model_name_map:dict=model_name_map,
hf_api_key_textbox:str='',
reasoning_suffix:str=reasoning_suffix,
- local_model:object=list(),
- tokenizer:object=list(),
- assistant_model:object=list(),
+ local_model:object=None,
+ tokenizer:object=None,
+ assistant_model:object=None,
summarise_topic_descriptions_prompt:str=summarise_topic_descriptions_prompt,
summarise_topic_descriptions_system_prompt:str=summarise_topic_descriptions_system_prompt,
do_summaries:str="Yes",

@@ -520,8 +520,9 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,
model_name_map (dict, optional): Dictionary mapping model choices to their properties. Defaults to model_name_map.
hf_api_key_textbox (str, optional): Hugging Face API key. Defaults to empty string.
reasoning_suffix (str, optional): Suffix for reasoning. Defaults to reasoning_suffix.
- local_model (object, optional): Local model object if using local inference. Defaults to empty list.
- tokenizer (object, optional): Tokenizer object if using local inference. Defaults to empty list.
+ local_model (object, optional): Local model object if using local inference. Defaults to None.
+ tokenizer (object, optional): Tokenizer object if using local inference. Defaults to None.
+ assistant_model (object, optional): Assistant model object if using local inference. Defaults to None.
summarise_topic_descriptions_prompt (str, optional): Prompt template for topic summarization.
summarise_topic_descriptions_system_prompt (str, optional): System prompt for topic summarization.
do_summaries (str, optional): Flag to control summary generation. Defaults to "Yes".

@@ -579,7 +580,7 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,

model_source = model_name_map[model_choice]["source"]

- if (model_source == "Local") & (RUN_LOCAL_MODEL == "1") & (not local_model) & (not tokenizer):
+ if (model_source == "Local") & (RUN_LOCAL_MODEL == "1") & (not local_model):
progress(0.1, f"Using global model: {CHOSEN_LOCAL_MODEL_TYPE}")
local_model = get_model()
tokenizer = get_tokenizer()

@@ -695,7 +696,7 @@ def summarise_output_topics(sampled_reference_table_df:pd.DataFrame,

return sampled_reference_table_df, topic_summary_df_revised, reference_table_df_revised, output_files, summarised_outputs, latest_summary_completed, out_metadata_str, summarised_output_markdown, log_output_files, output_files, acc_input_tokens, acc_output_tokens, acc_number_of_calls, time_taken, out_message

- @spaces.GPU(duration=120)
+ @spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME)
def overall_summary(topic_summary_df:pd.DataFrame,
model_choice:str,
in_api_key:str,

@@ -709,9 +710,9 @@ def overall_summary(topic_summary_df:pd.DataFrame,
model_name_map:dict=model_name_map,
hf_api_key_textbox:str='',
reasoning_suffix:str=reasoning_suffix,
- local_model:object=list(),
- tokenizer:object=list(),
- assistant_model:object=list(),
+ local_model:object=None,
+ tokenizer:object=None,
+ assistant_model:object=None,
summarise_everything_prompt:str=summarise_everything_prompt,
comprehensive_summary_format_prompt:str=comprehensive_summary_format_prompt,
comprehensive_summary_format_prompt_by_group:str=comprehensive_summary_format_prompt_by_group,

@@ -735,8 +736,9 @@ def overall_summary(topic_summary_df:pd.DataFrame,
model_name_map (dict, optional): Mapping of model names. Defaults to model_name_map.
hf_api_key_textbox (str, optional): Hugging Face API key. Defaults to empty string.
reasoning_suffix (str, optional): Suffix for reasoning. Defaults to reasoning_suffix.
- local_model (object, optional): Local model object. Defaults to empty list.
- tokenizer (object, optional): Tokenizer object. Defaults to empty list.
+ local_model (object, optional): Local model object. Defaults to None.
+ tokenizer (object, optional): Tokenizer object. Defaults to None.
+ assistant_model (object, optional): Assistant model object. Defaults to None.
summarise_everything_prompt (str, optional): Prompt for overall summary
comprehensive_summary_format_prompt (str, optional): Prompt for comprehensive summary format
comprehensive_summary_format_prompt_by_group (str, optional): Prompt for group summary format

@@ -790,7 +792,7 @@ def overall_summary(topic_summary_df:pd.DataFrame,

tic = time.perf_counter()

- if (model_choice == CHOSEN_LOCAL_MODEL_TYPE) & (RUN_LOCAL_MODEL == "1") & (not local_model) & (not tokenizer):
+ if (model_choice == CHOSEN_LOCAL_MODEL_TYPE) & (RUN_LOCAL_MODEL == "1") & (not local_model):
progress(0.1, f"Using global model: {CHOSEN_LOCAL_MODEL_TYPE}")
local_model = get_model()
tokenizer = get_tokenizer()
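Review note: the `local_model:object=list()` → `local_model:object=None` change is what the commit message calls "corrected local model load": both `not list()` and `not None` are true, but a mutable default is shared across calls and makes the sentinel check fragile, while `None` is the idiomatic "not loaded yet" marker. Dropping `& (not tokenizer)` also means the global model is now fetched whenever no model object was passed in, even if a tokenizer was. A hypothetical standalone version of the corrected guard, reusing this module's own imports (not a function from the repo):

```python
from tools.llm_funcs import get_model, get_tokenizer, get_assistant_model

def ensure_local_model(model_source: str, run_local_model: str,
                       local_model=None, tokenizer=None, assistant_model=None):
    # Fall back to the process-wide model only when the caller did not
    # pass one in. `and` on plain booleans behaves like the `&` used in
    # the file, with the bonus of short-circuit evaluation.
    if model_source == "Local" and run_local_model == "1" and not local_model:
        local_model = get_model()
        tokenizer = get_tokenizer()
        assistant_model = get_assistant_model()
    return local_model, tokenizer, assistant_model
```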
 
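Relatedly, replacing the hard-coded `@spaces.GPU(duration=300)` / `duration=120` decorators with `@spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME)` makes the ZeroGPU time budget tunable per deployment. A sketch of the pattern, assuming the `spaces` package from Hugging Face and an illustrative decorated function:

```python
import os
import spaces  # Hugging Face Spaces ZeroGPU helper

# Mirrors the tools/config.py pattern: one env-configurable GPU budget.
# Note the decorator argument is evaluated once, at import time.
MAX_SPACES_GPU_RUN_TIME = int(os.environ.get('MAX_SPACES_GPU_RUN_TIME', '240'))

@spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME)
def run_summarisation(prompt: str) -> str:
    # Placeholder body; in the repo the decorated functions run the
    # local-LLM summarisation steps on the allocated GPU.
    return prompt.upper()
```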
tools/llm_api_call.py CHANGED

@@ -15,10 +15,12 @@ from io import StringIO
GradioFileData = gr.FileData

from tools.prompts import initial_table_prompt, prompt2, prompt3, initial_table_system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, force_existing_topics_prompt, allow_new_topics_prompt, force_single_topic_prompt, add_existing_topics_assistant_prefill, initial_table_assistant_prefill, structured_summary_prompt
- from tools.helper_functions import read_file, put_columns_in_df, wrap_text, initial_clean, load_in_data_file, load_in_file, create_topic_summary_df_from_reference_table, convert_reference_table_to_pivot_table, get_basic_response_data, clean_column_name, load_in_previous_data_files, create_batch_file_path_details
+ from tools.helper_functions import read_file, put_columns_in_df, wrap_text, initial_clean, load_in_data_file, load_in_file, create_topic_summary_df_from_reference_table, convert_reference_table_to_pivot_table, get_basic_response_data, clean_column_name, load_in_previous_data_files, create_batch_file_path_details, move_overall_summary_output_files_to_front_page
from tools.llm_funcs import ResponseObject, construct_gemini_generative_model, call_llm_with_markdown_table_checks, create_missing_references_df, calculate_tokens_from_metadata, construct_azure_client, get_model, get_tokenizer, get_assistant_model
- from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS, REASONING_SUFFIX, AZURE_INFERENCE_ENDPOINT
+ from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS, REASONING_SUFFIX, AZURE_INFERENCE_ENDPOINT, MAX_ROWS, MAXIMUM_ZERO_SHOT_TOPICS, MAX_SPACES_GPU_RUN_TIME
from tools.aws_functions import connect_to_bedrock_runtime
+ from tools.dedup_summaries import sample_reference_table_summaries, summarise_output_topics, deduplicate_topics, overall_summary
+ from tools.combine_sheets_into_xlsx import collect_output_csvs_and_create_excel_output

if RUN_LOCAL_MODEL == "1":
from tools.llm_funcs import load_model

@@ -32,6 +34,8 @@ deduplication_threshold = DEDUPLICATION_THRESHOLD
max_comment_character_length = MAX_COMMENT_CHARS
random_seed = LLM_SEED
reasoning_suffix = REASONING_SUFFIX
+ max_rows = MAX_ROWS
+ maximum_zero_shot_topics = MAXIMUM_ZERO_SHOT_TOPICS

### HELPER FUNCTIONS

@@ -538,7 +542,7 @@ def write_llm_output_and_logs(response_text: str,
def generate_zero_shot_topics_df(zero_shot_topics:pd.DataFrame,
force_zero_shot_radio:str="No",
create_revised_general_topics:bool=False,
- max_topic_no:int=120):
+ max_topic_no:int=maximum_zero_shot_topics):
"""
Preprocesses a DataFrame of zero-shot topics, cleaning and formatting them
for use with a large language model. It handles different column configurations

@@ -572,8 +576,9 @@ def generate_zero_shot_topics_df(zero_shot_topics:pd.DataFrame,

# Max 120 topics allowed
if zero_shot_topics.shape[0] > max_topic_no:
- print("Maximum", max_topic_no, "topics allowed to fit within large language model context limits.")
- zero_shot_topics = zero_shot_topics.iloc[:max_topic_no, :]

# Forward slashes in the topic names seems to confuse the model
if zero_shot_topics.shape[1] >= 1: # Check if there is at least one column

@@ -646,7 +651,7 @@

return zero_shot_topics_df

- @spaces.GPU(duration=300)
def extract_topics(in_data_file: GradioFileData,
file_data:pd.DataFrame,
existing_topics_table:pd.DataFrame,

@@ -667,8 +672,6 @@ def extract_topics(in_data_file: GradioFileData,
first_loop_state:bool=False,
whole_conversation_metadata_str:str="",
initial_table_prompt:str=initial_table_prompt,
- prompt2:str=prompt2,
- prompt3:str=prompt3,
initial_table_system_prompt:str=initial_table_system_prompt,
add_existing_topics_system_prompt:str=add_existing_topics_system_prompt,
add_existing_topics_prompt:str=add_existing_topics_prompt,

@@ -696,6 +699,7 @@
model:object=list(),
tokenizer:object=list(),
assistant_model:object=list(),
progress=Progress(track_tqdm=True)):

'''

@@ -722,8 +726,6 @@
- first_loop_state (bool): A flag indicating the first loop state.
- whole_conversation_metadata_str (str): A string to store whole conversation metadata.
- initial_table_prompt (str): The first prompt for the model.
- - prompt2 (str): The second prompt for the model.
- - prompt3 (str): The third prompt for the model.
- initial_table_system_prompt (str): The system prompt for the model.
- add_existing_topics_system_prompt (str): The system prompt for the summary part of the model.
- add_existing_topics_prompt (str): The prompt for the model summary.

@@ -748,19 +750,23 @@
- reasoning_suffix (str, optional): The suffix for the reasoning system prompt.
- model: Model object for local inference.
- tokenizer: Tokenizer object for local inference.
- progress (Progress): A progress tracker.

'''

tic = time.perf_counter()
google_client = list()
google_config = {}
final_time = 0.0
whole_conversation_metadata = list()
is_error = False
create_revised_general_topics = False
- local_model = list()
- tokenizer = list()
zero_shot_topics_df = pd.DataFrame()
missing_df = pd.DataFrame()
new_reference_df = pd.DataFrame(columns=["Response References", "General topic", "Subtopic", "Sentiment", "Start row of group", "Group" ,"Topic_number", "Summary"])

@@ -795,9 +801,14 @@
file_data, file_name, num_batches = load_in_data_file(in_data_file, chosen_cols, batch_size_default, in_excel_sheets)
except:
# Check if files and text exist
- out_message = "Please enter a data file to summarise."
print(out_message)
raise Exception(out_message)

model_choice_clean = model_name_map[model_choice]['short_name']
model_source = model_name_map[model_choice]["source"]

@@ -812,8 +823,8 @@
out_file_paths = list()
final_time = 0

- if (model_source == "Local") & (RUN_LOCAL_MODEL == "1") & (not model) & (not tokenizer):
- progress(0.1, f"Using global model: {CHOSEN_LOCAL_MODEL_TYPE}")
local_model = get_model()
tokenizer = get_tokenizer()
assistant_model = get_assistant_model()

@@ -868,38 +879,36 @@

# Prepare clients before query
if "Gemini" in model_source:
- print("Using Gemini model:", model_choice)
google_client, google_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_system_prompt, max_tokens=max_tokens)
elif "Azure" in model_source:
- print("Using Azure AI Inference model:", model_choice)
# If provided, set env for downstream calls too
if azure_api_key_textbox:
os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
google_client, google_config = construct_azure_client(in_api_key=azure_api_key_textbox, endpoint=AZURE_INFERENCE_ENDPOINT)
elif "anthropic.claude" in model_choice:
- print("Using AWS Bedrock model:", model_choice)
else:
- print("Using local model:", model_choice)

# Preparing candidate topics if no topics currently exist
if candidate_topics and existing_topic_summary_df.empty:
#progress(0.1, "Creating revised zero shot topics table")

# 'Zero shot topics' are those supplied by the user
- max_topic_no = 120
zero_shot_topics = read_file(candidate_topics.name)
- zero_shot_topics = zero_shot_topics.fillna("") # Replace NaN with empty string
zero_shot_topics = zero_shot_topics.astype(str)

- zero_shot_topics_df = generate_zero_shot_topics_df(zero_shot_topics, force_zero_shot_radio, create_revised_general_topics, max_topic_no)
-
-

# This part concatenates all zero shot and new topics together, so that for the next prompt the LLM will have the full list available
if not existing_topic_summary_df.empty and force_zero_shot_radio != "Yes":
existing_topic_summary_df = pd.concat([existing_topic_summary_df, zero_shot_topics_df]).drop_duplicates("Subtopic")
- else:
- existing_topic_summary_df = zero_shot_topics_df

if candidate_topics and not zero_shot_topics_df.empty:
# If you have already created revised zero shot topics, concat to the current

@@ -1064,18 +1073,12 @@
unique_topics_markdown="No suggested headings for this summary"
formatted_initial_table_prompt = structured_summary_prompt.format(response_table=normalised_simple_markdown_table, topics=unique_topics_markdown)

- if prompt2: formatted_prompt2 = prompt2.format(response_table=normalised_simple_markdown_table, sentiment_choices=sentiment_prompt)
- else: formatted_prompt2 = prompt2
-
- if prompt3: formatted_prompt3 = prompt3.format(response_table=normalised_simple_markdown_table, sentiment_choices=sentiment_prompt)
- else: formatted_prompt3 = prompt3
-
#if "Local" in model_source:
#formatted_initial_table_prompt = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_initial_table_prompt + llama_cpp_suffix
#formatted_prompt2 = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_prompt2 + llama_cpp_suffix
#formatted_prompt3 = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_prompt3 + llama_cpp_suffix

- batch_prompts = [formatted_initial_table_prompt, formatted_prompt2, formatted_prompt3][:number_of_prompts_used] # Adjust this list to send fewer requests

if "Local" in model_source and reasoning_suffix: formatted_initial_table_system_prompt = formatted_initial_table_system_prompt + "\n" + reasoning_suffix

@@ -1251,6 +1254,7 @@

return unique_table_df_display_table_markdown, existing_topics_table, existing_topic_summary_df, existing_reference_df, out_file_paths, out_file_paths, latest_batch_completed, log_files_output_paths, log_files_output_paths, whole_conversation_metadata_str, final_time, out_file_paths, out_file_paths, modifiable_topic_summary_df, out_file_paths, join_file_paths, existing_reference_df_pivot, missing_df

def wrapper_extract_topics_per_column_value(
grouping_col: str,
in_data_file: Any,

@@ -1271,12 +1275,9 @@
initial_latest_batch_completed: int = 0,
initial_time_taken: float = 0,
initial_table_prompt: str = initial_table_prompt,
- prompt2: str = prompt2,
- prompt3: str = prompt3,
initial_table_system_prompt: str = initial_table_system_prompt,
add_existing_topics_system_prompt: str = add_existing_topics_system_prompt,
add_existing_topics_prompt: str = add_existing_topics_prompt,
-
number_of_prompts_used: int = 1,
batch_size: int = 50, # Crucial for calculating num_batches per segment
context_textbox: str = "",

@@ -1296,8 +1297,10 @@
max_time_for_loop: int = max_time_for_loop, # This applies per call to extract_topics
reasoning_suffix: str = reasoning_suffix,
CHOSEN_LOCAL_MODEL_TYPE: str = CHOSEN_LOCAL_MODEL_TYPE,
- model:object=list(),
- tokenizer:object=list(),
progress=Progress(track_tqdm=True) # type: ignore
) -> Tuple: # Mimicking the return tuple structure of extract_topics
"""

@@ -1324,8 +1327,6 @@
:param initial_latest_batch_completed: The batch number completed in the previous run.
:param initial_time_taken: Initial time taken for processing.
:param initial_table_prompt: The initial prompt for table summarization.
- :param prompt2: The second prompt for LLM interaction.
- :param prompt3: The third prompt for LLM interaction.
:param initial_table_system_prompt: The initial system prompt for table summarization.
:param add_existing_topics_system_prompt: System prompt for adding existing topics.
:param add_existing_topics_prompt: Prompt for adding existing topics.

@@ -1350,6 +1351,8 @@
:param CHOSEN_LOCAL_MODEL_TYPE: Type of local model chosen.
:param model: Model object for local inference.
:param tokenizer: Tokenizer object for local inference.
:param progress: Gradio Progress object for tracking progress.
:return: A tuple containing consolidated results, mimicking the return structure of `extract_topics`.
"""

@@ -1359,6 +1362,23 @@
acc_number_of_calls = 0
out_message = list()

if grouping_col is None:
print("No grouping column found")
file_data["group_col"] = "All"

@@ -1473,8 +1493,6 @@
first_loop_state=current_first_loop_state, # True only for the very first iteration of wrapper
whole_conversation_metadata_str="", # Fresh for each call
initial_table_prompt=initial_table_prompt,
- prompt2=prompt2,
- prompt3=prompt3,
initial_table_system_prompt=initial_table_system_prompt,
add_existing_topics_system_prompt=add_existing_topics_system_prompt,
add_existing_topics_prompt=add_existing_topics_prompt,

@@ -1501,6 +1519,8 @@
reasoning_suffix=reasoning_suffix,
model=model,
tokenizer=tokenizer,
progress=progress
)

@@ -1697,3 +1717,340 @@ def modify_existing_output_tables(original_topic_summary_df:pd.DataFrame, modifi


return modifiable_topic_summary_df, reference_df, output_file_list, output_file_list, output_file_list, output_file_list, reference_table_file_name, unique_table_file_name, deduplicated_unique_table_markdown
[the ~340 added lines of the new all_in_one_pipeline function are collapsed in this view]
 GradioFileData = gr.FileData
 
 from tools.prompts import initial_table_prompt, prompt2, prompt3, initial_table_system_prompt, add_existing_topics_system_prompt, add_existing_topics_prompt, force_existing_topics_prompt, allow_new_topics_prompt, force_single_topic_prompt, add_existing_topics_assistant_prefill, initial_table_assistant_prefill, structured_summary_prompt
+from tools.helper_functions import read_file, put_columns_in_df, wrap_text, initial_clean, load_in_data_file, load_in_file, create_topic_summary_df_from_reference_table, convert_reference_table_to_pivot_table, get_basic_response_data, clean_column_name, load_in_previous_data_files, create_batch_file_path_details, move_overall_summary_output_files_to_front_page
 from tools.llm_funcs import ResponseObject, construct_gemini_generative_model, call_llm_with_markdown_table_checks, create_missing_references_df, calculate_tokens_from_metadata, construct_azure_client, get_model, get_tokenizer, get_assistant_model
+from tools.config import RUN_LOCAL_MODEL, AWS_REGION, MAX_COMMENT_CHARS, MAX_OUTPUT_VALIDATION_ATTEMPTS, MAX_TOKENS, TIMEOUT_WAIT, NUMBER_OF_RETRY_ATTEMPTS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT, DEDUPLICATION_THRESHOLD, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_GROUPS, REASONING_SUFFIX, AZURE_INFERENCE_ENDPOINT, MAX_ROWS, MAXIMUM_ZERO_SHOT_TOPICS, MAX_SPACES_GPU_RUN_TIME
 from tools.aws_functions import connect_to_bedrock_runtime
+from tools.dedup_summaries import sample_reference_table_summaries, summarise_output_topics, deduplicate_topics, overall_summary
+from tools.combine_sheets_into_xlsx import collect_output_csvs_and_create_excel_output
 
 if RUN_LOCAL_MODEL == "1":
     from tools.llm_funcs import load_model
 
 max_comment_character_length = MAX_COMMENT_CHARS
 random_seed = LLM_SEED
 reasoning_suffix = REASONING_SUFFIX
+max_rows = MAX_ROWS
+maximum_zero_shot_topics = MAXIMUM_ZERO_SHOT_TOPICS
 
 ### HELPER FUNCTIONS
 
 
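Note: MAX_ROWS and MAXIMUM_ZERO_SHOT_TOPICS are the new configuration values behind the row and topic caps. A sketch of supplying them at deploy time, assuming tools.config resolves them from the environment (the lookup mechanism itself is not shown in this diff):

    import os

    # Illustrative values only; set before tools.config is imported
    os.environ["MAX_ROWS"] = "5000"
    os.environ["MAXIMUM_ZERO_SHOT_TOPICS"] = "120"

    from tools.config import MAX_ROWS, MAXIMUM_ZERO_SHOT_TOPICS
    print(MAX_ROWS, MAXIMUM_ZERO_SHOT_TOPICS)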
 def generate_zero_shot_topics_df(zero_shot_topics:pd.DataFrame,
     force_zero_shot_radio:str="No",
     create_revised_general_topics:bool=False,
+    max_topic_no:int=maximum_zero_shot_topics):
     """
     Preprocesses a DataFrame of zero-shot topics, cleaning and formatting them
     for use with a large language model. It handles different column configurations
 
 
     # Max 120 topics allowed
     if zero_shot_topics.shape[0] > max_topic_no:
+        out_message = "Maximum " + str(max_topic_no) + " zero-shot topics allowed according to application configuration."
+        print(out_message)
+        raise Exception(out_message)
 
     # Forward slashes in the topic names seems to confuse the model
     if zero_shot_topics.shape[1] >= 1: # Check if there is at least one column
 
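Note: the guard above raises rather than truncating, so oversized topic lists fail fast. A hypothetical pre-flight check a user could run on a candidate-topics file before upload (pandas assumed; the path and limit are placeholders):

    import pandas as pd

    topics = pd.read_csv("candidate_topics.csv")
    limit = 120  # keep in step with the app's MAXIMUM_ZERO_SHOT_TOPICS
    if topics.shape[0] > limit:
        # Trim to the cap, or split the file into several runs
        topics.head(limit).to_csv("candidate_topics_trimmed.csv", index=False)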
 
     return zero_shot_topics_df
 
+
 def extract_topics(in_data_file: GradioFileData,
     file_data:pd.DataFrame,
     existing_topics_table:pd.DataFrame,
 
     first_loop_state:bool=False,
     whole_conversation_metadata_str:str="",
     initial_table_prompt:str=initial_table_prompt,
     initial_table_system_prompt:str=initial_table_system_prompt,
     add_existing_topics_system_prompt:str=add_existing_topics_system_prompt,
     add_existing_topics_prompt:str=add_existing_topics_prompt,
 
     model:object=list(),
     tokenizer:object=list(),
     assistant_model:object=list(),
+    max_rows:int=max_rows,
     progress=Progress(track_tqdm=True)):
 
     '''
 
     - first_loop_state (bool): A flag indicating the first loop state.
     - whole_conversation_metadata_str (str): A string to store whole conversation metadata.
     - initial_table_prompt (str): The first prompt for the model.
     - initial_table_system_prompt (str): The system prompt for the model.
     - add_existing_topics_system_prompt (str): The system prompt for the summary part of the model.
     - add_existing_topics_prompt (str): The prompt for the model summary.
 
     - reasoning_suffix (str, optional): The suffix for the reasoning system prompt.
     - model: Model object for local inference.
     - tokenizer: Tokenizer object for local inference.
+    - assistant_model: Assistant model object for local inference.
+    - max_rows: The maximum number of rows to process.
     - progress (Progress): A progress tracker.
 
     '''
 
     tic = time.perf_counter()
+
     google_client = list()
     google_config = {}
     final_time = 0.0
     whole_conversation_metadata = list()
     is_error = False
     create_revised_general_topics = False
+    local_model = None
+    tokenizer = None
+    assistant_model = None
     zero_shot_topics_df = pd.DataFrame()
     missing_df = pd.DataFrame()
     new_reference_df = pd.DataFrame(columns=["Response References", "General topic", "Subtopic", "Sentiment", "Start row of group", "Group" ,"Topic_number", "Summary"])
 
             file_data, file_name, num_batches = load_in_data_file(in_data_file, chosen_cols, batch_size_default, in_excel_sheets)
         except:
             # Check if files and text exist
+            out_message = "Please enter a data file to process."
             print(out_message)
             raise Exception(out_message)
+
+    if file_data.shape[0] > max_rows:
+        out_message = "Your data has more than " + str(max_rows) + " rows, which has been set as the maximum in the application configuration."
+        print(out_message)
+        raise Exception(out_message)
 
     model_choice_clean = model_name_map[model_choice]['short_name']
     model_source = model_name_map[model_choice]["source"]
 
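Note: the same shape[0] check now appears in both extract_topics and the grouped wrapper, so oversized tables are rejected at either entry point. The pattern on its own, as a standalone sketch (names are placeholders):

    import pandas as pd

    def enforce_row_cap(df: pd.DataFrame, max_rows: int) -> pd.DataFrame:
        # Fail fast before any LLM calls are made
        if df.shape[0] > max_rows:
            raise ValueError(f"{df.shape[0]} rows exceeds the configured maximum of {max_rows}")
        return df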
     out_file_paths = list()
     final_time = 0
 
+    if (model_source == "Local") & (RUN_LOCAL_MODEL == "1") & (not model):
+        progress(0.1, f"Using local model: {model_choice_clean}")
         local_model = get_model()
         tokenizer = get_tokenizer()
         assistant_model = get_assistant_model()
 
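Note: the added "(not model)" condition is what lets a model passed in from Gradio state be reused, while a missing one is loaded on demand. The underlying lazy-singleton idea, sketched generically (not the repo's actual loader):

    _model = None

    def get_model_lazily(loader):
        # Load an expensive resource on first use, then reuse it for the process lifetime
        global _model
        if _model is None:
            _model = loader()
        return _model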
 
     # Prepare clients before query
     if "Gemini" in model_source:
+        #print("Using Gemini model:", model_choice)
         google_client, google_config = construct_gemini_generative_model(in_api_key=in_api_key, temperature=temperature, model_choice=model_choice, system_prompt=formatted_system_prompt, max_tokens=max_tokens)
     elif "Azure" in model_source:
+        #print("Using Azure AI Inference model:", model_choice)
         # If provided, set env for downstream calls too
         if azure_api_key_textbox:
             os.environ["AZURE_INFERENCE_CREDENTIAL"] = azure_api_key_textbox
         google_client, google_config = construct_azure_client(in_api_key=azure_api_key_textbox, endpoint=AZURE_INFERENCE_ENDPOINT)
     elif "anthropic.claude" in model_choice:
+        #print("Using AWS Bedrock model:", model_choice)
+        pass
     else:
+        #print("Using local model:", model_choice)
+        pass
 
     # Preparing candidate topics if no topics currently exist
     if candidate_topics and existing_topic_summary_df.empty:
         #progress(0.1, "Creating revised zero shot topics table")
 
         # 'Zero shot topics' are those supplied by the user
         zero_shot_topics = read_file(candidate_topics.name)
+        zero_shot_topics = zero_shot_topics.fillna("") # Replace NaN with empty string
         zero_shot_topics = zero_shot_topics.astype(str)
 
+        zero_shot_topics_df = generate_zero_shot_topics_df(zero_shot_topics, force_zero_shot_radio, create_revised_general_topics)
 
         # This part concatenates all zero shot and new topics together, so that for the next prompt the LLM will have the full list available
         if not existing_topic_summary_df.empty and force_zero_shot_radio != "Yes":
             existing_topic_summary_df = pd.concat([existing_topic_summary_df, zero_shot_topics_df]).drop_duplicates("Subtopic")
+        else: existing_topic_summary_df = zero_shot_topics_df
 
     if candidate_topics and not zero_shot_topics_df.empty:
         # If you have already created revised zero shot topics, concat to the current
 
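Note on the new fillna("") line: without it, astype(str) renders missing cells as the literal string "nan", which would then be treated as a topic name. A toy demonstration (pandas only, not from the repo):

    import pandas as pd

    df = pd.DataFrame({"topic": ["Roads", float("nan")]})
    print(df["topic"].astype(str).tolist())              # ['Roads', 'nan']
    print(df["topic"].fillna("").astype(str).tolist())   # ['Roads', '']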
             unique_topics_markdown="No suggested headings for this summary"
         formatted_initial_table_prompt = structured_summary_prompt.format(response_table=normalised_simple_markdown_table, topics=unique_topics_markdown)
 
         #if "Local" in model_source:
             #formatted_initial_table_prompt = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_initial_table_prompt + llama_cpp_suffix
             #formatted_prompt2 = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_prompt2 + llama_cpp_suffix
             #formatted_prompt3 = llama_cpp_prefix + formatted_initial_table_system_prompt + "\n" + formatted_prompt3 + llama_cpp_suffix
 
+        batch_prompts = [formatted_initial_table_prompt]
 
         if "Local" in model_source and reasoning_suffix: formatted_initial_table_system_prompt = formatted_initial_table_system_prompt + "\n" + reasoning_suffix
 
@@ -1697,3 +1717,340 @@ def modify_existing_output_tables(original_topic_summary_df:pd.DataFrame, modifi
 
 
     return modifiable_topic_summary_df, reference_df, output_file_list, output_file_list, output_file_list, output_file_list, reference_table_file_name, unique_table_file_name, deduplicated_unique_table_markdown
+
+@spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME)
+def all_in_one_pipeline(
+    grouping_col: str,
+    in_data_files: List[str],
+    file_data: pd.DataFrame,
+    existing_topics_table: pd.DataFrame,
+    existing_reference_df: pd.DataFrame,
+    existing_topic_summary_df: pd.DataFrame,
+    unique_table_df_display_table_markdown: str,
+    original_file_name: str,
+    total_number_of_batches: int,
+    in_api_key: str,
+    temperature: float,
+    chosen_cols: List[str],
+    model_choice: str,
+    candidate_topics: GradioFileData,
+    first_loop_state: bool,
+    conversation_metadata_text: str,
+    latest_batch_completed: int,
+    time_taken_so_far: float,
+    initial_table_prompt_text: str,
+    initial_table_system_prompt_text: str,
+    add_existing_topics_system_prompt_text: str,
+    add_existing_topics_prompt_text: str,
+    number_of_prompts_used: int,
+    batch_size: int,
+    context_text: str,
+    sentiment_choice: str,
+    force_zero_shot_choice: str,
+    in_excel_sheets: List[str],
+    force_single_topic_choice: str,
+    produce_structures_summary_choice: str,
+    aws_access_key_text: str,
+    aws_secret_key_text: str,
+    hf_api_key_text: str,
+    azure_api_key_text: str,
+    output_folder: str = OUTPUT_FOLDER,
+    merge_sentiment: str = "No",
+    merge_general_topics: str = "Yes",
+    score_threshold: int = 90,
+    summarise_format: str = "",
+    random_seed: int = 0,
+    log_files_output_list_state: List[str] = list(),
+    model_name_map_state: dict = model_name_map,
+    model: object = None,
+    tokenizer: object = None,
+    assistant_model: object = None,
+    usage_logs_location: str = "",
+    max_rows: int = max_rows,
+    progress=Progress(track_tqdm=True)
+):
+    """
+    Orchestrates the full All-in-one flow: extract → deduplicate → summarise → overall summary → Excel export.
+
+    Returns a large tuple matching the UI components updated during the original chained flow.
+    """
+
+    # Load local model if it's not already loaded
+    if (model_name_map_state[model_choice]["source"] == "Local") & (RUN_LOCAL_MODEL == "1") & (not model):
+        model = get_model()
+        tokenizer = get_tokenizer()
+        assistant_model = get_assistant_model()
+
+    total_input_tokens = 0
+    total_output_tokens = 0
+    total_number_of_calls = 0
+    total_time_taken = 0
+    out_message = list()
+
+    # 1) Extract topics (group-aware)
+    (
+        display_markdown,
+        out_topics_table,
+        out_topic_summary_df,
+        out_reference_df,
+        out_file_paths_1,
+        _out_file_paths_dup,
+        out_latest_batch_completed,
+        out_log_files,
+        _out_log_files_dup,
+        out_conversation_metadata,
+        out_time_taken,
+        out_file_paths_2,
+        _out_file_paths_3,
+        out_gradio_df,
+        out_file_paths_4,
+        out_join_files,
+        out_missing_df,
+        out_input_tokens,
+        out_output_tokens,
+        out_number_of_calls,
+        out_message_text
+    ) = wrapper_extract_topics_per_column_value(
+        grouping_col=grouping_col,
+        in_data_file=in_data_files,
+        file_data=file_data,
+        initial_existing_topics_table=existing_topics_table,
+        initial_existing_reference_df=existing_reference_df,
+        initial_existing_topic_summary_df=existing_topic_summary_df,
+        initial_unique_table_df_display_table_markdown=unique_table_df_display_table_markdown,
+        original_file_name=original_file_name,
+        total_number_of_batches=total_number_of_batches,
+        in_api_key=in_api_key,
+        temperature=temperature,
+        chosen_cols=chosen_cols,
+        model_choice=model_choice,
+        candidate_topics=candidate_topics,
+        initial_first_loop_state=first_loop_state,
+        initial_whole_conversation_metadata_str=conversation_metadata_text,
+        initial_latest_batch_completed=latest_batch_completed,
+        initial_time_taken=time_taken_so_far,
+        initial_table_prompt=initial_table_prompt_text,
+        initial_table_system_prompt=initial_table_system_prompt_text,
+        add_existing_topics_system_prompt=add_existing_topics_system_prompt_text,
+        add_existing_topics_prompt=add_existing_topics_prompt_text,
+        number_of_prompts_used=number_of_prompts_used,
+        batch_size=batch_size,
+        context_textbox=context_text,
+        sentiment_checkbox=sentiment_choice,
+        force_zero_shot_radio=force_zero_shot_choice,
+        in_excel_sheets=in_excel_sheets,
+        force_single_topic_radio=force_single_topic_choice,
+        produce_structures_summary_radio=produce_structures_summary_choice,
+        aws_access_key_textbox=aws_access_key_text,
+        aws_secret_key_textbox=aws_secret_key_text,
+        hf_api_key_textbox=hf_api_key_text,
+        azure_api_key_textbox=azure_api_key_text,
+        output_folder=output_folder,
+        model_name_map=model_name_map_state,
+        model=model,
+        tokenizer=tokenizer,
+        assistant_model=assistant_model,
+        max_rows=max_rows
+    )
+
+    total_input_tokens += out_input_tokens
+    total_output_tokens += out_output_tokens
+    total_number_of_calls += out_number_of_calls
+    total_time_taken += out_time_taken
+    out_message.append(out_message_text)
+
+    # Prepare outputs after extraction, matching wrapper outputs
+    topic_extraction_output_files = out_file_paths_1
+    text_output_file_list_state = out_file_paths_1
+    log_files_output_list_state = out_log_files
+
+    # 2) Deduplication
+    (
+        ref_df_loaded,
+        unique_df_loaded,
+        latest_batch_completed_no_loop,
+        deduplication_input_files_status,
+        working_data_file_name_textbox,
+        unique_topics_table_file_name_textbox
+    ) = load_in_previous_data_files(out_file_paths_1)
+
+    ref_df_after_dedup, unique_df_after_dedup, summarisation_input_files, log_files_output_dedup, summarised_output_markdown = deduplicate_topics(
+        reference_df=ref_df_loaded if not ref_df_loaded.empty else out_reference_df,
+        topic_summary_df=unique_df_loaded if not unique_df_loaded.empty else out_topic_summary_df,
+        reference_table_file_name=working_data_file_name_textbox,
+        unique_topics_table_file_name=unique_topics_table_file_name_textbox,
+        in_excel_sheets=in_excel_sheets,
+        merge_sentiment=merge_sentiment,
+        merge_general_topics=merge_general_topics,
+        score_threshold=score_threshold,
+        in_data_files=in_data_files,
+        chosen_cols=chosen_cols,
+        output_folder=output_folder
+    )
+
+    # 3) Summarisation
+    (
+        ref_df_loaded_2,
+        unique_df_loaded_2,
+        _latest_batch_completed_no_loop_2,
+        _deduplication_input_files_status_2,
+        _working_name_2,
+        _unique_name_2
+    ) = load_in_previous_data_files(summarisation_input_files)
+
+    summary_reference_table_sample_state, summarised_references_markdown = sample_reference_table_summaries(ref_df_after_dedup, random_seed)
+
+    (
+        _summary_reference_table_sample_state,
+        master_unique_topics_df_revised_summaries_state,
+        master_reference_df_revised_summaries_state,
+        summary_output_files,
+        summarised_outputs_list,
+        latest_summary_completed_num,
+        conversation_metadata_text_updated,
+        display_markdown_updated,
+        log_files_output_after_sum,
+        overall_summarisation_input_files,
+        input_tokens_num,
+        output_tokens_num,
+        number_of_calls_num,
+        estimated_time_taken_number,
+        output_messages_textbox
+    ) = summarise_output_topics(
+        sampled_reference_table_df=summary_reference_table_sample_state,
+        topic_summary_df=unique_df_after_dedup,
+        reference_table_df=ref_df_after_dedup,
+        model_choice=model_choice,
+        in_api_key=in_api_key,
+        temperature=temperature,
+        reference_data_file_name=working_data_file_name_textbox,
+        summarised_outputs=list(),
+        latest_summary_completed=0,
+        out_metadata_str=out_conversation_metadata,
+        in_data_files=in_data_files,
+        in_excel_sheets=in_excel_sheets,
+        chosen_cols=chosen_cols,
+        log_output_files=log_files_output_list_state,
+        summarise_format_radio=summarise_format,
+        output_folder=output_folder,
+        context_textbox=context_text,
+        aws_access_key_textbox=aws_access_key_text,
+        aws_secret_key_textbox=aws_secret_key_text,
+        model_name_map=model_name_map_state,
+        hf_api_key_textbox=hf_api_key_text,
+        local_model=model,
+        tokenizer=tokenizer,
+        assistant_model=assistant_model
+    )
+
+    total_input_tokens += input_tokens_num
+    total_output_tokens += output_tokens_num
+    total_number_of_calls += number_of_calls_num
+    total_time_taken += estimated_time_taken_number
+    out_message.append(output_messages_textbox)
+
+    # 4) Overall summary
+    (
+        _ref_df_loaded_3,
+        _unique_df_loaded_3,
+        _latest_batch_completed_no_loop_3,
+        _deduplication_input_files_status_3,
+        _working_name_3,
+        _unique_name_3
+    ) = load_in_previous_data_files(overall_summarisation_input_files)
+
+    (
+        overall_summary_output_files,
+        overall_summarised_output_markdown,
+        summarised_output_df,
+        conversation_metadata_textbox,
+        input_tokens_num,
+        output_tokens_num,
+        number_of_calls_num,
+        estimated_time_taken_number,
+        output_messages_textbox
+    ) = overall_summary(
+        topic_summary_df=master_unique_topics_df_revised_summaries_state,
+        model_choice=model_choice,
+        in_api_key=in_api_key,
+        temperature=temperature,
+        reference_data_file_name=working_data_file_name_textbox,
+        output_folder=output_folder,
+        chosen_cols=chosen_cols,
+        context_textbox=context_text,
+        aws_access_key_textbox=aws_access_key_text,
+        aws_secret_key_textbox=aws_secret_key_text,
+        model_name_map=model_name_map_state,
+        hf_api_key_textbox=hf_api_key_text,
+        local_model=model,
+        tokenizer=tokenizer,
+        assistant_model=assistant_model
+    )
+
+    total_input_tokens += input_tokens_num
+    total_output_tokens += output_tokens_num
+    total_number_of_calls += number_of_calls_num
+    total_time_taken += estimated_time_taken_number
+    out_message.append(output_messages_textbox)
+
+    out_message = '\n'.join(out_message)
+    out_message = out_message + "\n" + f"Overall time for all processes: {total_time_taken:.2f}s"
+    print(out_message)
+
+    # 5) Excel export and move xlsx to front page
+    # overall_summary_output_files_xlsx, summary_xlsx_output_files_list = collect_output_csvs_and_create_excel_output(
+    #     in_data_files=in_data_files,
+    #     chosen_cols=chosen_cols,
+    #     reference_data_file_name_textbox=original_file_name,
+    #     in_group_col=grouping_col,
+    #     model_choice=model_choice,
+    #     master_reference_df_state=master_reference_df_revised_summaries_state,
+    #     master_unique_topics_df_state=master_unique_topics_df_revised_summaries_state,
+    #     summarised_output_df=summarised_output_df,
+    #     missing_df_state=out_missing_df,
+    #     excel_sheets=in_excel_sheets,
+    #     usage_logs_location=usage_logs_location,
+    #     model_name_map=model_name_map_state,
+    #     output_folder=output_folder
+    # )
+
+    # topic_extraction_output_files_xlsx = move_overall_summary_output_files_to_front_page(summary_xlsx_output_files_list)
+
+    # Map to the UI outputs list expected by the new single-call wiring
+    return (
+        display_markdown_updated if display_markdown_updated else display_markdown,
+        out_topics_table,
+        unique_df_after_dedup,
+        ref_df_after_dedup,
+        topic_extraction_output_files,
+        text_output_file_list_state,
+        out_latest_batch_completed,
+        log_files_output_after_sum if log_files_output_after_sum else out_log_files,
+        log_files_output_list_state,
+        conversation_metadata_text_updated if conversation_metadata_text_updated else out_conversation_metadata,
+        total_time_taken,
+        out_file_paths_1,
+        summarisation_input_files,
+        out_gradio_df,
+        list(), # modification_input_files placeholder
+        out_join_files,
+        out_missing_df,
+        total_input_tokens,
+        total_output_tokens,
+        total_number_of_calls,
+        out_message,
+        summary_reference_table_sample_state,
+        summarised_references_markdown,
+        master_unique_topics_df_revised_summaries_state,
+        master_reference_df_revised_summaries_state,
+        summary_output_files,
+        summarised_outputs_list,
+        latest_summary_completed_num,
+        overall_summarisation_input_files,
+        overall_summary_output_files,
+        overall_summarised_output_markdown,
+        summarised_output_df#,
+        # overall_summary_output_files_xlsx,
+        # summary_xlsx_output_files_list,
+        # topic_extraction_output_files_xlsx
+    )
tools/llm_funcs.py CHANGED
@@ -214,18 +214,11 @@ def load_model(local_model_type:str=CHOSEN_LOCAL_MODEL_TYPE,
     - tokenizer (list/transformers tokenizer): An empty list (tokenizer is not used with Llama.cpp directly in this setup), or a transformers tokenizer.
     - assistant_model (transformers model): The assistant model for speculative decoding (if USE_SPECULATIVE_DECODING is True).
     '''
-    print("Loading model ", local_model_type)
-
-    #print("model_path:", model_path)
+
+    if model:
+        return model, tokenizer, assistant_model
 
-    if model is None:
-        model = list()
-    else:
-        return model, tokenizer
-    if tokenizer is None:
-        tokenizer = list()
-    else:
-        return model, tokenizer
+    print("Loading model:", local_model_type)
 
     # Verify the device and cuda settings
     # Check if CUDA is enabled
@@ -427,7 +420,7 @@ def load_model(local_model_type:str=CHOSEN_LOCAL_MODEL_TYPE,
 def get_model():
     """Get the globally loaded model. Load it if not already loaded."""
     global _model, _tokenizer, _assistant_model
-    if _model is None and LOAD_LOCAL_MODEL_AT_START == "True":
+    if _model is None:
         _model, _tokenizer, _assistant_model = load_model(
             local_model_type=CHOSEN_LOCAL_MODEL_TYPE,
             gpu_layers=gpu_layers,
@@ -450,7 +443,7 @@ def get_model():
 def get_tokenizer():
     """Get the globally loaded tokenizer. Load it if not already loaded."""
     global _model, _tokenizer, _assistant_model
-    if _tokenizer is None and LOAD_LOCAL_MODEL_AT_START == "True":
+    if _tokenizer is None:
         _model, _tokenizer, _assistant_model = load_model(
             local_model_type=CHOSEN_LOCAL_MODEL_TYPE,
             gpu_layers=gpu_layers,
@@ -473,7 +466,7 @@ def get_tokenizer():
 def get_assistant_model():
     """Get the globally loaded assistant model. Load it if not already loaded."""
     global _model, _tokenizer, _assistant_model
-    if _assistant_model is None and LOAD_LOCAL_MODEL_AT_START == "True":
+    if _assistant_model is None:
         _model, _tokenizer, _assistant_model = load_model(
             local_model_type=CHOSEN_LOCAL_MODEL_TYPE,
             gpu_layers=gpu_layers,
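Note: load_model now returns whatever model objects it is given before doing any work, and the get_* accessors no longer depend on LOAD_LOCAL_MODEL_AT_START, so a model that was not loaded at startup is loaded on first request. A stub illustrating that early-return contract under those assumptions:

    def load_model_stub(model=None, tokenizer=None, assistant_model=None):
        if model:
            return model, tokenizer, assistant_model  # reuse what was supplied
        return "model", "tokenizer", "assistant"      # stand-in for the real load

    m, t, a = load_model_stub()
    m2, t2, a2 = load_model_stub(m, t, a)
    assert m2 is m and t2 is t and a2 is a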
tools/verify_titles.py CHANGED
@@ -12,7 +12,7 @@ GradioFileData = gr.FileData
 
 from tools.prompts import initial_table_prompt, prompt2, prompt3, system_prompt,add_existing_topics_system_prompt, add_existing_topics_prompt, initial_table_assistant_prefill, add_existing_topics_assistant_prefill
 from tools.helper_functions import put_columns_in_df, wrap_text, clean_column_name, create_batch_file_path_details
-from tools.llm_funcs import load_model, construct_gemini_generative_model, call_llm_with_markdown_table_checks, get_model, get_tokenizer
+from tools.llm_funcs import load_model, construct_gemini_generative_model, call_llm_with_markdown_table_checks, get_model, get_tokenizer, get_assistant_model
 from tools.llm_api_call import load_in_data_file, get_basic_response_data, data_file_to_markdown_table, convert_response_text_to_dataframe, ResponseObject
 from tools.config import MAX_OUTPUT_VALIDATION_ATTEMPTS, RUN_LOCAL_MODEL, model_name_map, OUTPUT_FOLDER, CHOSEN_LOCAL_MODEL_TYPE, LOCAL_REPO_ID, LOCAL_MODEL_FILE, LOCAL_MODEL_FOLDER, LLM_SEED, MAX_TOKENS, MAX_TIME_FOR_LOOP, BATCH_SIZE_DEFAULT
 from tools.aws_functions import connect_to_bedrock_runtime
@@ -225,8 +225,6 @@ def verify_titles(in_data_file,
     first_loop_state:bool=False,
     whole_conversation_metadata_str:str="",
     initial_table_prompt:str=initial_table_prompt,
-    prompt2:str=prompt2,
-    prompt3:str=prompt3,
     system_prompt:str=system_prompt,
     add_existing_topics_system_prompt:str=add_existing_topics_system_prompt,
     add_existing_topics_prompt:str=add_existing_topics_prompt,
@@ -242,7 +240,10 @@
     in_excel_sheets:List[str] = list(),
     output_folder:str=OUTPUT_FOLDER,
     max_tokens:int=max_tokens,
-    model_name_map:dict=model_name_map,
+    model_name_map:dict=model_name_map,
+    local_model:object=None,
+    tokenizer:object=None,
+    assistant_model:object=None,
     max_time_for_loop:int=max_time_for_loop,
     progress=Progress(track_tqdm=True)):
 
@@ -270,8 +271,6 @@
     - first_loop_state (bool): A flag indicating the first loop state.
     - whole_conversation_metadata_str (str): A string to store whole conversation metadata.
     - initial_table_prompt (str): The first prompt for the model.
-    - prompt2 (str): The second prompt for the model.
-    - prompt3 (str): The third prompt for the model.
     - system_prompt (str): The system prompt for the model.
     - add_existing_topics_system_prompt (str): The system prompt for the summary part of the model.
     - add_existing_topics_prompt (str): The prompt for the model summary.
@@ -288,6 +287,9 @@
     - output_folder (str): The output folder where files will be saved.
     - max_tokens (int): The maximum number of tokens for the model.
     - model_name_map (dict, optional): A dictionary mapping full model name to shortened.
+    - local_model (object, optional): Local model object if using local inference. Defaults to None.
+    - tokenizer (object, optional): Tokenizer object if using local inference. Defaults to None.
+    - assistant_model (object, optional): Assistant model object if using local inference. Defaults to None.
     - max_time_for_loop (int, optional): The number of seconds maximum that the function should run for before breaking (to run again, this is to avoid timeouts with some AWS services if deployed there).
     - progress (Progress): A progress tracker.
     '''
@@ -299,8 +301,9 @@
     whole_conversation_metadata = list()
     is_error = False
     create_revised_general_topics = False
-    local_model = list()
-    tokenizer = list()
+    local_model = None
+    tokenizer = None
+    assistant_model = None
     zero_shot_topics_df = pd.DataFrame()
     #llama_system_prefix = "<|start_header_id|>system<|end_header_id|>\n" #"<start_of_turn>user\n"
     #llama_system_suffix = "<|eot_id|>" #"<end_of_turn>\n<start_of_turn>model\n"
@@ -343,11 +346,11 @@
     out_file_paths = list()
     #print("model_choice_clean:", model_choice_clean)
 
-    if (model_choice == CHOSEN_LOCAL_MODEL_TYPE) & (RUN_LOCAL_MODEL == "1"):
+    if (model_choice == CHOSEN_LOCAL_MODEL_TYPE) & (RUN_LOCAL_MODEL == "1") & (not local_model):
         progress(0.1, f"Using global model: {CHOSEN_LOCAL_MODEL_TYPE}")
         local_model = get_model()
         tokenizer = get_tokenizer()
-        #print("Local model loaded:", local_model)
+        assistant_model = get_assistant_model()
 
     if num_batches > 0:
         progress_measure = round(latest_batch_completed / num_batches, 1)
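Note: these guards test "(not local_model)" rather than "local_model is None", which also treats the old list() defaults (empty, hence falsy) as "not yet loaded". A small illustration of the distinction:

    m = []          # old-style default: empty container is falsy, so a load is triggered
    print(not m)    # True

    m = object()    # a loaded model object: truthy, so the load is skipped
    print(not m)    # False

    m = None        # new-style default: also falsy
    print(not m)    # True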