|  |  | 
					
						
						|  | """ | 
					
						
						|  | Camie-Tagger-V2 Application | 
					
						
						|  | A Streamlit web app for tagging images using an AI model. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | import streamlit as st | 
					
						
						|  | import os | 
					
						
						|  | import sys | 
					
						
						|  | import traceback | 
					
						
						|  | import tempfile | 
					
						
						|  | import time | 
					
						
						|  | import platform | 
					
						
						|  | import subprocess | 
					
						
						|  | import webbrowser | 
					
						
						|  | import glob | 
					
						
						|  | import numpy as np | 
					
						
						|  | import matplotlib.pyplot as plt | 
					
						
						|  | import io | 
					
						
						|  | import base64 | 
					
						
						|  | import json | 
					
						
						|  | from matplotlib.colors import LinearSegmentedColormap | 
					
						
						|  | from PIL import Image | 
					
						
						|  | from pathlib import Path | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from utils.image_processing import process_image, batch_process_images | 
					
						
						|  | from utils.file_utils import save_tags_to_file, get_default_save_locations | 
					
						
						|  | from utils.ui_components import display_progress_bar, show_example_images, display_batch_results | 
					
						
						|  | from utils.onnx_processing import batch_process_images_onnx | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | 
					
						
						|  | print(f"Using model directory: {MODEL_DIR}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# One-line description of each threshold profile, keyed by the profile name
# used in the UI ("Micro Optimized", "Macro Optimized", "Balanced", "Overall",
# "Category-specific" -- the same names get_profile_metrics and
# on_threshold_profile_change dispatch on).
threshold_profile_descriptions = {
    "Micro Optimized": "Maximizes micro-averaged F1 score (best for dominant classes). Optimal for overall prediction quality.",
    "Macro Optimized": "Maximizes macro-averaged F1 score (equal weight to all classes). Better for balanced performance across all tags.",
    "Balanced": "Provides a trade-off between precision and recall with moderate thresholds. Good general-purpose setting.",
    "Overall": "Uses a single threshold value across all categories. Simplest approach for consistent behavior.",
    "Category-specific": "Uses different optimal thresholds for each category. Best for fine-tuning results."
}

# Long-form Markdown explanation for each threshold profile, with the same
# keys as threshold_profile_descriptions.
# NOTE(review): the validation figures below are hard-coded text; they are not
# read from full_validation_results.json and must be kept in sync by hand.
# NOTE(review): the "Balanced" entry calls itself "Same as Micro Optimized" and
# repeats the Micro Optimized numbers (~67.3% / ~46.3% / ~0.614), while its
# one-line description promises "moderate thresholds" -- confirm which is meant.
# NOTE(review): the "Overall" entry states a default of 0.5, but
# on_threshold_profile_change uses the overall 'balanced' threshold when one
# is present and 0.5 only as a fallback.
threshold_profile_explanations = {
    "Micro Optimized": """
### Micro Optimized Profile

**Technical definition**: Maximizes micro-averaged F1 score, which calculates metrics globally across all predictions.

**When to use**: When you want the best overall accuracy, especially for common tags and dominant categories.

**Effects**:
- Optimizes performance for the most frequent tags
- Gives more weight to categories with many examples (like 'character' and 'general')
- Provides higher precision in most common use cases

**Performance from validation**:
- Micro F1: ~67.3%
- Macro F1: ~46.3%
- Threshold: ~0.614
""",

    "Macro Optimized": """
### Macro Optimized Profile

**Technical definition**: Maximizes macro-averaged F1 score, which gives equal weight to all categories regardless of size.

**When to use**: When balanced performance across all categories is important, including rare tags.

**Effects**:
- More balanced performance across all tag categories
- Better at detecting rare or unusual tags
- Generally has lower thresholds than micro-optimized

**Performance from validation**:
- Micro F1: ~60.9%
- Macro F1: ~50.6%
- Threshold: ~0.492
""",

    "Balanced": """
### Balanced Profile

**Technical definition**: Same as Micro Optimized but provides a good reference point for manual adjustment.

**When to use**: For general-purpose tagging when you don't have specific recall or precision requirements.

**Effects**:
- Good middle ground between precision and recall
- Works well for most common use cases
- Default choice for most users

**Performance from validation**:
- Micro F1: ~67.3%
- Macro F1: ~46.3%
- Threshold: ~0.614
""",

    "Overall": """
### Overall Profile

**Technical definition**: Uses a single threshold value across all categories.

**When to use**: When you want consistent behavior across all categories and a simple approach.

**Effects**:
- Consistent tagging threshold for all categories
- Simpler to understand than category-specific thresholds
- User-adjustable with a single slider

**Default threshold value**: 0.5 (user-adjustable)

**Note**: The threshold value is user-adjustable with the slider below.
""",

    "Category-specific": """
### Category-specific Profile

**Technical definition**: Uses different optimal thresholds for each category, allowing fine-tuning.

**When to use**: When you want to customize tagging sensitivity for different categories.

**Effects**:
- Each category has its own independent threshold
- Full control over category sensitivity
- Best for fine-tuning results when some categories need different treatment

**Default threshold values**: Starts with balanced thresholds for each category

**Note**: Use the category sliders below to adjust thresholds for individual categories.
"""
}
					
						
						|  |  | 
					
						
						|  | def load_validation_results(results_path): | 
					
						
						|  | """Load validation results from JSON file""" | 
					
						
						|  | try: | 
					
						
						|  | with open(results_path, 'r') as f: | 
					
						
						|  | data = json.load(f) | 
					
						
						|  | return data | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"Error loading validation results: {e}") | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  | def extract_thresholds_from_results(validation_data): | 
					
						
						|  | """Extract threshold information from validation results""" | 
					
						
						|  | if not validation_data or 'results' not in validation_data: | 
					
						
						|  | return {} | 
					
						
						|  |  | 
					
						
						|  | thresholds = { | 
					
						
						|  | 'overall': {}, | 
					
						
						|  | 'categories': {} | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for result in validation_data['results']: | 
					
						
						|  | category = result['CATEGORY'].lower() | 
					
						
						|  | profile = result['PROFILE'].lower().replace(' ', '_') | 
					
						
						|  | threshold = result['THRESHOLD'] | 
					
						
						|  | micro_f1 = result['MICRO-F1'] | 
					
						
						|  | macro_f1 = result['MACRO-F1'] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if profile == 'micro_opt': | 
					
						
						|  | profile = 'micro_optimized' | 
					
						
						|  | elif profile == 'macro_opt': | 
					
						
						|  | profile = 'macro_optimized' | 
					
						
						|  |  | 
					
						
						|  | threshold_info = { | 
					
						
						|  | 'threshold': threshold, | 
					
						
						|  | 'micro_f1': micro_f1, | 
					
						
						|  | 'macro_f1': macro_f1 | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | if category == 'overall': | 
					
						
						|  | thresholds['overall'][profile] = threshold_info | 
					
						
						|  | else: | 
					
						
						|  | if category not in thresholds['categories']: | 
					
						
						|  | thresholds['categories'][category] = {} | 
					
						
						|  | thresholds['categories'][category][profile] = threshold_info | 
					
						
						|  |  | 
					
						
						|  | return thresholds | 
					
						
						|  |  | 
					
						
						|  | def load_model_and_metadata(): | 
					
						
						|  | """Load model and metadata from available files""" | 
					
						
						|  |  | 
					
						
						|  | safetensors_path = os.path.join(MODEL_DIR, "camie-tagger-v2.safetensors") | 
					
						
						|  | safetensors_metadata_path = os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | onnx_path = os.path.join(MODEL_DIR, "camie-tagger-v2.onnx") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | validation_results_path = os.path.join(MODEL_DIR, "full_validation_results.json") | 
					
						
						|  |  | 
					
						
						|  | model_info = { | 
					
						
						|  | 'safetensors_available': os.path.exists(safetensors_path) and os.path.exists(safetensors_metadata_path), | 
					
						
						|  | 'onnx_available': os.path.exists(onnx_path) and os.path.exists(safetensors_metadata_path), | 
					
						
						|  | 'validation_results_available': os.path.exists(validation_results_path) | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | metadata = None | 
					
						
						|  | if os.path.exists(safetensors_metadata_path): | 
					
						
						|  | try: | 
					
						
						|  | with open(safetensors_metadata_path, 'r') as f: | 
					
						
						|  | metadata = json.load(f) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"Error loading metadata: {e}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | thresholds = {} | 
					
						
						|  | if model_info['validation_results_available']: | 
					
						
						|  | validation_data = load_validation_results(validation_results_path) | 
					
						
						|  | if validation_data: | 
					
						
						|  | thresholds = extract_thresholds_from_results(validation_data) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if not thresholds: | 
					
						
						|  | thresholds = { | 
					
						
						|  | 'overall': { | 
					
						
						|  | 'balanced': {'threshold': 0.5, 'micro_f1': 0, 'macro_f1': 0}, | 
					
						
						|  | 'micro_optimized': {'threshold': 0.6, 'micro_f1': 0, 'macro_f1': 0}, | 
					
						
						|  | 'macro_optimized': {'threshold': 0.4, 'micro_f1': 0, 'macro_f1': 0} | 
					
						
						|  | }, | 
					
						
						|  | 'categories': {} | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | return model_info, metadata, thresholds | 
					
						
						|  |  | 
					
						
						|  | def load_safetensors_model(safetensors_path, metadata_path): | 
					
						
						|  | """Load SafeTensors model""" | 
					
						
						|  | try: | 
					
						
						|  | from safetensors.torch import load_file | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with open(metadata_path, 'r') as f: | 
					
						
						|  | metadata = json.load(f) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from utils.model_loader import ImageTagger | 
					
						
						|  |  | 
					
						
						|  | model_info = metadata['model_info'] | 
					
						
						|  | dataset_info = metadata['dataset_info'] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | model = ImageTagger( | 
					
						
						|  | total_tags=dataset_info['total_tags'], | 
					
						
						|  | dataset=None, | 
					
						
						|  | model_name=model_info['backbone'], | 
					
						
						|  | num_heads=model_info['num_attention_heads'], | 
					
						
						|  | dropout=0.0, | 
					
						
						|  | pretrained=False, | 
					
						
						|  | tag_context_size=model_info['tag_context_size'], | 
					
						
						|  | use_gradient_checkpointing=False, | 
					
						
						|  | img_size=model_info['img_size'] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | state_dict = load_file(safetensors_path) | 
					
						
						|  | model.load_state_dict(state_dict) | 
					
						
						|  | model.eval() | 
					
						
						|  |  | 
					
						
						|  | return model, metadata | 
					
						
						|  | except Exception as e: | 
					
						
						|  | raise Exception(f"Failed to load SafeTensors model: {e}") | 
					
						
						|  |  | 
					
						
						|  | def get_profile_metrics(thresholds, profile_name): | 
					
						
						|  | """Extract metrics for the given profile from the thresholds dictionary""" | 
					
						
						|  | profile_key = None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if profile_name == "Micro Optimized": | 
					
						
						|  | profile_key = "micro_optimized" | 
					
						
						|  | elif profile_name == "Macro Optimized": | 
					
						
						|  | profile_key = "macro_optimized" | 
					
						
						|  | elif profile_name == "Balanced": | 
					
						
						|  | profile_key = "balanced" | 
					
						
						|  | elif profile_name in ["Overall", "Category-specific"]: | 
					
						
						|  | profile_key = "macro_optimized" | 
					
						
						|  |  | 
					
						
						|  | if profile_key and 'overall' in thresholds and profile_key in thresholds['overall']: | 
					
						
						|  | return thresholds['overall'][profile_key] | 
					
						
						|  |  | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  | def on_threshold_profile_change(): | 
					
						
						|  | """Handle threshold profile changes""" | 
					
						
						|  | new_profile = st.session_state.threshold_profile | 
					
						
						|  |  | 
					
						
						|  | if hasattr(st.session_state, 'thresholds') and hasattr(st.session_state, 'settings'): | 
					
						
						|  |  | 
					
						
						|  | if st.session_state.settings['active_category_thresholds'] is None: | 
					
						
						|  | st.session_state.settings['active_category_thresholds'] = {} | 
					
						
						|  |  | 
					
						
						|  | current_thresholds = st.session_state.settings['active_category_thresholds'] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | profile_key = None | 
					
						
						|  | if new_profile == "Micro Optimized": | 
					
						
						|  | profile_key = "micro_optimized" | 
					
						
						|  | elif new_profile == "Macro Optimized": | 
					
						
						|  | profile_key = "macro_optimized" | 
					
						
						|  | elif new_profile == "Balanced": | 
					
						
						|  | profile_key = "balanced" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if profile_key and 'overall' in st.session_state.thresholds and profile_key in st.session_state.thresholds['overall']: | 
					
						
						|  | st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall'][profile_key]['threshold'] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for category in st.session_state.categories: | 
					
						
						|  | if category in st.session_state.thresholds['categories'] and profile_key in st.session_state.thresholds['categories'][category]: | 
					
						
						|  | current_thresholds[category] = st.session_state.thresholds['categories'][category][profile_key]['threshold'] | 
					
						
						|  | else: | 
					
						
						|  | current_thresholds[category] = st.session_state.settings['active_threshold'] | 
					
						
						|  |  | 
					
						
						|  | elif new_profile == "Overall": | 
					
						
						|  |  | 
					
						
						|  | if 'overall' in st.session_state.thresholds and 'balanced' in st.session_state.thresholds['overall']: | 
					
						
						|  | st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall']['balanced']['threshold'] | 
					
						
						|  | else: | 
					
						
						|  | st.session_state.settings['active_threshold'] = 0.5 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | st.session_state.settings['active_category_thresholds'] = {} | 
					
						
						|  |  | 
					
						
						|  | elif new_profile == "Category-specific": | 
					
						
						|  |  | 
					
						
						|  | if 'overall' in st.session_state.thresholds and 'balanced' in st.session_state.thresholds['overall']: | 
					
						
						|  | st.session_state.settings['active_threshold'] = st.session_state.thresholds['overall']['balanced']['threshold'] | 
					
						
						|  | else: | 
					
						
						|  | st.session_state.settings['active_threshold'] = 0.5 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for category in st.session_state.categories: | 
					
						
						|  | if category in st.session_state.thresholds['categories'] and 'balanced' in st.session_state.thresholds['categories'][category]: | 
					
						
						|  | current_thresholds[category] = st.session_state.thresholds['categories'][category]['balanced']['threshold'] | 
					
						
						|  | else: | 
					
						
						|  | current_thresholds[category] = st.session_state.settings['active_threshold'] | 
					
						
						|  |  | 
					
						
						|  | def apply_thresholds(all_probs, threshold_profile, active_threshold, active_category_thresholds, min_confidence, selected_categories): | 
					
						
						|  | """Apply thresholds to raw probabilities and return filtered tags""" | 
					
						
						|  | tags = {} | 
					
						
						|  | all_tags = [] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | active_category_thresholds = active_category_thresholds or {} | 
					
						
						|  |  | 
					
						
						|  | for category, cat_probs in all_probs.items(): | 
					
						
						|  |  | 
					
						
						|  | threshold = active_category_thresholds.get(category, active_threshold) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | tags[category] = [(tag, prob) for tag, prob in cat_probs if prob >= threshold] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if selected_categories.get(category, True): | 
					
						
						|  | for tag, prob in tags[category]: | 
					
						
						|  | all_tags.append(tag) | 
					
						
						|  |  | 
					
						
						|  | return tags, all_tags | 
					
						
						|  |  | 
					
						
						|  | def image_tagger_app(): | 
					
						
						|  | """Main Streamlit application for image tagging.""" | 
					
						
						|  | st.set_page_config(layout="wide", page_title="Camie Tagger", page_icon="🖼️") | 
					
						
						|  |  | 
					
						
						|  | st.title("Camie-Tagger-v2 Interface") | 
					
						
						|  | st.markdown("---") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if 'settings' not in st.session_state: | 
					
						
						|  | st.session_state.settings = { | 
					
						
						|  | 'show_all_tags': False, | 
					
						
						|  | 'compact_view': True, | 
					
						
						|  | 'min_confidence': 0.01, | 
					
						
						|  | 'threshold_profile': "Macro", | 
					
						
						|  | 'active_threshold': 0.5, | 
					
						
						|  | 'active_category_thresholds': {}, | 
					
						
						|  | 'selected_categories': {}, | 
					
						
						|  | 'replace_underscores': False | 
					
						
						|  | } | 
					
						
						|  | st.session_state.show_profile_help = False | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if 'model_loaded' not in st.session_state: | 
					
						
						|  | st.session_state.model_loaded = False | 
					
						
						|  | st.session_state.model = None | 
					
						
						|  | st.session_state.thresholds = None | 
					
						
						|  | st.session_state.metadata = None | 
					
						
						|  | st.session_state.model_type = "onnx" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with st.sidebar: | 
					
						
						|  |  | 
					
						
						|  | st.subheader("💡 Notes") | 
					
						
						|  |  | 
					
						
						|  | st.markdown(""" | 
					
						
						|  | This tagger was trained on a subset of the available data due to hardware limitations. | 
					
						
						|  |  | 
					
						
						|  | A more comprehensive model trained on the full 3+ million image dataset would provide: | 
					
						
						|  | - More recent characters and tags. | 
					
						
						|  | - Improved accuracy. | 
					
						
						|  |  | 
					
						
						|  | If you find this tool useful and would like to support future development: | 
					
						
						|  | """) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | st.markdown(""" | 
					
						
						|  | <style> | 
					
						
						|  | @keyframes coffee-button-glow { | 
					
						
						|  | 0% { box-shadow: 0 0 5px #FFD700; } | 
					
						
						|  | 50% { box-shadow: 0 0 15px #FFD700; } | 
					
						
						|  | 100% { box-shadow: 0 0 5px #FFD700; } | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | .coffee-button { | 
					
						
						|  | display: inline-block; | 
					
						
						|  | animation: coffee-button-glow 2s infinite; | 
					
						
						|  | border-radius: 5px; | 
					
						
						|  | transition: transform 0.3s ease; | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | .coffee-button:hover { | 
					
						
						|  | transform: scale(1.05); | 
					
						
						|  | } | 
					
						
						|  | </style> | 
					
						
						|  |  | 
					
						
						|  | <a href="https://ko-fi.com/camais" target="_blank" class="coffee-button"> | 
					
						
						|  | <img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" | 
					
						
						|  | alt="Buy Me A Coffee" | 
					
						
						|  | style="height: 45px; width: 162px; border-radius: 5px;" /> | 
					
						
						|  | </a> | 
					
						
						|  | """, unsafe_allow_html=True) | 
					
						
						|  |  | 
					
						
						|  | st.markdown(""" | 
					
						
						|  | Your support helps with: | 
					
						
						|  | - GPU costs for training | 
					
						
						|  | - Storage for larger datasets | 
					
						
						|  | - Development of new features | 
					
						
						|  | - Future projects | 
					
						
						|  |  | 
					
						
						|  | Thank you! 🙏 | 
					
						
						|  |  | 
					
						
						|  | Full Details: https://huggingface.co/Camais03/camie-tagger | 
					
						
						|  | """) | 
					
						
						|  |  | 
					
						
						|  | st.header("Model Selection") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | model_info, metadata, thresholds = load_model_and_metadata() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | model_options = [] | 
					
						
						|  | if model_info['onnx_available']: | 
					
						
						|  | model_options.append("ONNX (Recommended)") | 
					
						
						|  | if model_info['safetensors_available']: | 
					
						
						|  | model_options.append("SafeTensors (PyTorch)") | 
					
						
						|  |  | 
					
						
						|  | if not model_options: | 
					
						
						|  | st.error("No model files found!") | 
					
						
						|  | st.info(f"Looking for models in: {MODEL_DIR}") | 
					
						
						|  | st.info("Expected files:") | 
					
						
						|  | st.info("- camie-tagger-v2.onnx") | 
					
						
						|  | st.info("- camie-tagger-v2.safetensors") | 
					
						
						|  | st.info("- camie-tagger-v2-metadata.json") | 
					
						
						|  | st.stop() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | default_index = 0 if model_info['onnx_available'] else 0 | 
					
						
						|  | model_type = st.radio( | 
					
						
						|  | "Select Model Type:", | 
					
						
						|  | model_options, | 
					
						
						|  | index=default_index, | 
					
						
						|  | help="ONNX: Optimized for speed and compatibility\nSafeTensors: Native PyTorch format" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if model_type == "ONNX (Recommended)": | 
					
						
						|  | selected_model_type = "onnx" | 
					
						
						|  | else: | 
					
						
						|  | selected_model_type = "safetensors" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if selected_model_type != st.session_state.model_type: | 
					
						
						|  | st.session_state.model_loaded = False | 
					
						
						|  | st.session_state.model_type = selected_model_type | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if st.button("Reload Model") and st.session_state.model_loaded: | 
					
						
						|  | st.session_state.model_loaded = False | 
					
						
						|  | st.info("Reloading model...") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if not st.session_state.model_loaded: | 
					
						
						|  | try: | 
					
						
						|  | with st.spinner(f"Loading {st.session_state.model_type.upper()} model..."): | 
					
						
						|  | if st.session_state.model_type == "onnx": | 
					
						
						|  |  | 
					
						
						|  | import onnxruntime as ort | 
					
						
						|  |  | 
					
						
						|  | onnx_path = os.path.join(MODEL_DIR, "camie-tagger-v2.onnx") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | providers = ort.get_available_providers() | 
					
						
						|  | gpu_available = any('CUDA' in provider for provider in providers) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | session = ort.InferenceSession(onnx_path, providers=providers) | 
					
						
						|  |  | 
					
						
						|  | st.session_state.model = session | 
					
						
						|  | st.session_state.device = f"ONNX Runtime ({'GPU' if gpu_available else 'CPU'})" | 
					
						
						|  | st.session_state.param_dtype = "float32" | 
					
						
						|  |  | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | safetensors_path = os.path.join(MODEL_DIR, "camie-tagger-v2.safetensors") | 
					
						
						|  | metadata_path = os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json") | 
					
						
						|  |  | 
					
						
						|  | model, loaded_metadata = load_safetensors_model(safetensors_path, metadata_path) | 
					
						
						|  |  | 
					
						
						|  | st.session_state.model = model | 
					
						
						|  | device = next(model.parameters()).device | 
					
						
						|  | param_dtype = next(model.parameters()).dtype | 
					
						
						|  | st.session_state.device = device | 
					
						
						|  | st.session_state.param_dtype = param_dtype | 
					
						
						|  | metadata = loaded_metadata | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | st.session_state.thresholds = thresholds | 
					
						
						|  | st.session_state.metadata = metadata | 
					
						
						|  | st.session_state.model_loaded = True | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if metadata and 'dataset_info' in metadata: | 
					
						
						|  | tag_mapping = metadata['dataset_info']['tag_mapping'] | 
					
						
						|  | categories = list(set(tag_mapping['tag_to_category'].values())) | 
					
						
						|  | st.session_state.categories = categories | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if not st.session_state.settings['selected_categories']: | 
					
						
						|  | st.session_state.settings['selected_categories'] = {cat: True for cat in categories} | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if 'overall' in thresholds and 'balanced' in thresholds['overall']: | 
					
						
						|  | st.session_state.settings['active_threshold'] = thresholds['overall']['macro_optimized']['threshold'] | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | st.error(f"Error loading model: {str(e)}") | 
					
						
						|  | st.code(traceback.format_exc()) | 
					
						
						|  | st.stop() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with st.sidebar: | 
					
						
						|  | st.header("Model Information") | 
					
						
						|  | if st.session_state.model_loaded: | 
					
						
						|  | if st.session_state.model_type == "onnx": | 
					
						
						|  | st.success("Using ONNX Model") | 
					
						
						|  | else: | 
					
						
						|  | st.success("Using SafeTensors Model") | 
					
						
						|  |  | 
					
						
						|  | st.write(f"Device: {st.session_state.device}") | 
					
						
						|  | st.write(f"Precision: {st.session_state.param_dtype}") | 
					
						
						|  |  | 
					
						
						|  | if st.session_state.metadata: | 
					
						
						|  | if 'dataset_info' in st.session_state.metadata: | 
					
						
						|  | total_tags = st.session_state.metadata['dataset_info']['total_tags'] | 
					
						
						|  | st.write(f"Total tags: {total_tags}") | 
					
						
						|  | elif 'total_tags' in st.session_state.metadata: | 
					
						
						|  | st.write(f"Total tags: {st.session_state.metadata['total_tags']}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with st.expander("Available Categories"): | 
					
						
						|  | for category in sorted(st.session_state.categories): | 
					
						
						|  | st.write(f"- {category.capitalize()}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with st.expander("About this app"): | 
					
						
						|  | st.write(""" | 
					
						
						|  | This app uses a trained image tagging model to analyze and tag images. | 
					
						
						|  |  | 
					
						
						|  | **Model Options**: | 
					
						
						|  | - **ONNX (Recommended)**: Optimized for inference speed with broad compatibility | 
					
						
						|  | - **SafeTensors**: Native PyTorch format for advanced users | 
					
						
						|  |  | 
					
						
						|  | **Features**: | 
					
						
						|  | - Upload or process images in batches | 
					
						
						|  | - Multiple threshold profiles based on validation results | 
					
						
						|  | - Category-specific threshold adjustment | 
					
						
						|  | - Export tags in various formats | 
					
						
						|  | - Fast inference with GPU acceleration (when available) | 
					
						
						|  |  | 
					
						
						|  | **Threshold Profiles**: | 
					
						
						|  | - **Micro Optimized**: Best overall F1 score (67.3% micro F1) | 
					
						
						|  | - **Macro Optimized**: Balanced across categories (50.6% macro F1) | 
					
						
						|  | - **Balanced**: Good general-purpose setting | 
					
						
						|  | - **Overall**: Single adjustable threshold | 
					
						
						|  | - **Category-specific**: Fine-tune each category individually | 
					
						
						|  | """) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | col1, col2 = st.columns([1, 1.5]) | 
					
						
						|  |  | 
					
						
						|  | with col1: | 
					
						
						|  | st.header("Image") | 
					
						
						|  |  | 
					
						
						|  | upload_tab, batch_tab = st.tabs(["Upload Image", "Batch Processing"]) | 
					
						
						|  |  | 
					
						
						|  | image_path = None | 
					
						
						|  |  | 
					
						
						|  | with upload_tab: | 
					
						
						|  | uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) | 
					
						
						|  |  | 
					
						
						|  | if uploaded_file: | 
					
						
						|  |  | 
					
						
						|  | with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file: | 
					
						
						|  | tmp_file.write(uploaded_file.getvalue()) | 
					
						
						|  | image_path = tmp_file.name | 
					
						
						|  |  | 
					
						
						|  | st.session_state.original_filename = uploaded_file.name | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | image = Image.open(uploaded_file) | 
					
						
						|  | st.image(image, use_container_width=True) | 
					
						
						|  |  | 
					
						
						|  | with batch_tab: | 
					
						
						|  | st.subheader("Batch Process Images") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | batch_folder = st.text_input("Enter folder path containing images:", "") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | save_options = st.radio( | 
					
						
						|  | "Where to save tag files:", | 
					
						
						|  | ["Same folder as images", "Custom location", "Default save folder"], | 
					
						
						|  | index=0 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | st.subheader("Performance Options") | 
					
						
						|  | batch_size = st.number_input("Batch size", min_value=1, max_value=32, value=4, | 
					
						
						|  | help="Higher values may improve speed but use more memory") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | enable_category_limits = st.checkbox("Limit tags per category in batch output", value=False) | 
					
						
						|  |  | 
					
						
						|  | if enable_category_limits and hasattr(st.session_state, 'categories'): | 
					
						
						|  | if 'category_limits' not in st.session_state: | 
					
						
						|  | st.session_state.category_limits = {} | 
					
						
						|  |  | 
					
						
						|  | st.markdown("**Limit Values:** -1 = no limit, 0 = exclude, N = top N tags") | 
					
						
						|  |  | 
					
						
						|  | limit_cols = st.columns(2) | 
					
						
						|  | for i, category in enumerate(sorted(st.session_state.categories)): | 
					
						
						|  | col_idx = i % 2 | 
					
						
						|  | with limit_cols[col_idx]: | 
					
						
						|  | current_limit = st.session_state.category_limits.get(category, -1) | 
					
						
						|  | new_limit = st.number_input( | 
					
						
						|  | f"{category.capitalize()}:", | 
					
						
						|  | value=current_limit, | 
					
						
						|  | min_value=-1, | 
					
						
						|  | step=1, | 
					
						
						|  | key=f"limit_{category}" | 
					
						
						|  | ) | 
					
						
						|  | st.session_state.category_limits[category] = new_limit | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if batch_folder and os.path.isdir(batch_folder): | 
					
						
						|  | image_files = [] | 
					
						
						|  | for ext in ['*.jpg', '*.jpeg', '*.png']: | 
					
						
						|  | image_files.extend(glob.glob(os.path.join(batch_folder, ext))) | 
					
						
						|  | image_files.extend(glob.glob(os.path.join(batch_folder, ext.upper()))) | 
					
						
						|  |  | 
					
						
						|  | if image_files: | 
					
						
						|  | st.write(f"Found {len(image_files)} images") | 
					
						
						|  |  | 
					
						
						|  | if st.button("🔄 Process All Images", type="primary"): | 
					
						
						|  | if not st.session_state.model_loaded: | 
					
						
						|  | st.error("Model not loaded") | 
					
						
						|  | else: | 
					
						
						|  | with st.spinner("Processing images..."): | 
					
						
						|  | progress_bar = st.progress(0) | 
					
						
						|  | status_text = st.empty() | 
					
						
						|  |  | 
					
						
						|  | def update_progress(current, total, image_path): | 
					
						
						|  | progress = current / total if total > 0 else 0 | 
					
						
						|  | progress_bar.progress(progress) | 
					
						
						|  | status_text.text(f"Processing {current}/{total}: {os.path.basename(image_path) if image_path else 'Complete'}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if save_options == "Same folder as images": | 
					
						
						|  | save_dir = batch_folder | 
					
						
						|  | elif save_options == "Custom location": | 
					
						
						|  | save_dir = st.text_input("Custom save directory:", batch_folder) | 
					
						
						|  | else: | 
					
						
						|  | save_dir = os.path.join(os.path.dirname(__file__), "saved_tags") | 
					
						
						|  | os.makedirs(save_dir, exist_ok=True) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | category_limits = st.session_state.category_limits if enable_category_limits else None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if st.session_state.model_type == "onnx": | 
					
						
						|  | batch_results = batch_process_images_onnx( | 
					
						
						|  | folder_path=batch_folder, | 
					
						
						|  | model_path=os.path.join(MODEL_DIR, "camie-tagger-v2.onnx"), | 
					
						
						|  | metadata_path=os.path.join(MODEL_DIR, "camie-tagger-v2-metadata.json"), | 
					
						
						|  | threshold_profile=st.session_state.settings['threshold_profile'], | 
					
						
						|  | active_threshold=st.session_state.settings['active_threshold'], | 
					
						
						|  | active_category_thresholds=st.session_state.settings['active_category_thresholds'], | 
					
						
						|  | save_dir=save_dir, | 
					
						
						|  | progress_callback=update_progress, | 
					
						
						|  | min_confidence=st.session_state.settings['min_confidence'], | 
					
						
						|  | batch_size=batch_size, | 
					
						
						|  | category_limits=category_limits | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | st.error("SafeTensors batch processing not implemented yet") | 
					
						
						|  | batch_results = None | 
					
						
						|  |  | 
					
						
						|  | if batch_results: | 
					
						
						|  | display_batch_results(batch_results) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with col2: | 
					
						
						|  | st.header("Tagging Controls") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | all_profiles = [ | 
					
						
						|  | "Micro Optimized", | 
					
						
						|  | "Macro Optimized", | 
					
						
						|  | "Balanced", | 
					
						
						|  | "Overall", | 
					
						
						|  | "Category-specific" | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  | profile_col1, profile_col2 = st.columns([3, 1]) | 
					
						
						|  |  | 
					
						
						|  | with profile_col1: | 
					
						
						|  | threshold_profile = st.selectbox( | 
					
						
						|  | "Select threshold profile", | 
					
						
						|  | options=all_profiles, | 
					
						
						|  | index=1, | 
					
						
						|  | key="threshold_profile", | 
					
						
						|  | on_change=on_threshold_profile_change | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | with profile_col2: | 
					
						
						|  | if st.button("ℹ️ Help", key="profile_help"): | 
					
						
						|  | st.session_state.show_profile_help = not st.session_state.get('show_profile_help', False) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if st.session_state.get('show_profile_help', False): | 
					
						
						|  | st.markdown(threshold_profile_explanations[threshold_profile]) | 
					
						
						|  | else: | 
					
						
						|  | st.info(threshold_profile_descriptions[threshold_profile]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if st.session_state.model_loaded: | 
					
						
						|  | metrics = get_profile_metrics(st.session_state.thresholds, threshold_profile) | 
					
						
						|  |  | 
					
						
						|  | if metrics: | 
					
						
						|  | metrics_cols = st.columns(3) | 
					
						
						|  |  | 
					
						
						|  | with metrics_cols[0]: | 
					
						
						|  | st.metric("Threshold", f"{metrics['threshold']:.3f}") | 
					
						
						|  |  | 
					
						
						|  | with metrics_cols[1]: | 
					
						
						|  | st.metric("Micro F1", f"{metrics['micro_f1']:.1f}%") | 
					
						
						|  |  | 
					
						
						|  | with metrics_cols[2]: | 
					
						
						|  | st.metric("Macro F1", f"{metrics['macro_f1']:.1f}%") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if st.session_state.model_loaded: | 
					
						
						|  | active_threshold = st.session_state.settings.get('active_threshold', 0.5) | 
					
						
						|  | active_category_thresholds = st.session_state.settings.get('active_category_thresholds', {}) | 
					
						
						|  |  | 
					
						
						|  | if threshold_profile in ["Micro Optimized", "Macro Optimized", "Balanced"]: | 
					
						
						|  |  | 
					
						
						|  | st.slider( | 
					
						
						|  | "Threshold (from validation)", | 
					
						
						|  | min_value=0.01, | 
					
						
						|  | max_value=1.0, | 
					
						
						|  | value=float(active_threshold), | 
					
						
						|  | step=0.01, | 
					
						
						|  | disabled=True, | 
					
						
						|  | help="This threshold is optimized from validation results" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | elif threshold_profile == "Overall": | 
					
						
						|  |  | 
					
						
						|  | active_threshold = st.slider( | 
					
						
						|  | "Overall threshold", | 
					
						
						|  | min_value=0.01, | 
					
						
						|  | max_value=1.0, | 
					
						
						|  | value=float(active_threshold), | 
					
						
						|  | step=0.01 | 
					
						
						|  | ) | 
					
						
						|  | st.session_state.settings['active_threshold'] = active_threshold | 
					
						
						|  |  | 
					
						
						|  | elif threshold_profile == "Category-specific": | 
					
						
						|  |  | 
					
						
						|  | st.slider( | 
					
						
						|  | "Overall threshold (reference)", | 
					
						
						|  | min_value=0.01, | 
					
						
						|  | max_value=1.0, | 
					
						
						|  | value=float(active_threshold), | 
					
						
						|  | step=0.01, | 
					
						
						|  | disabled=True | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | st.write("Adjust thresholds for individual categories:") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | slider_cols = st.columns(2) | 
					
						
						|  |  | 
					
						
						|  | if not active_category_thresholds: | 
					
						
						|  | active_category_thresholds = {} | 
					
						
						|  |  | 
					
						
						|  | for i, category in enumerate(sorted(st.session_state.categories)): | 
					
						
						|  | col_idx = i % 2 | 
					
						
						|  | with slider_cols[col_idx]: | 
					
						
						|  | default_val = active_category_thresholds.get(category, active_threshold) | 
					
						
						|  | new_threshold = st.slider( | 
					
						
						|  | f"{category.capitalize()}", | 
					
						
						|  | min_value=0.01, | 
					
						
						|  | max_value=1.0, | 
					
						
						|  | value=float(default_val), | 
					
						
						|  | step=0.01, | 
					
						
						|  | key=f"slider_{category}" | 
					
						
						|  | ) | 
					
						
						|  | active_category_thresholds[category] = new_threshold | 
					
						
						|  |  | 
					
						
						|  | st.session_state.settings['active_category_thresholds'] = active_category_thresholds | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with st.expander("Display Options", expanded=False): | 
					
						
						|  | col1, col2 = st.columns(2) | 
					
						
						|  | with col1: | 
					
						
						|  | show_all_tags = st.checkbox("Show all tags (including below threshold)", | 
					
						
						|  | value=st.session_state.settings['show_all_tags']) | 
					
						
						|  | compact_view = st.checkbox("Compact view (hide progress bars)", | 
					
						
						|  | value=st.session_state.settings['compact_view']) | 
					
						
						|  | replace_underscores = st.checkbox("Replace underscores with spaces", | 
					
						
						|  | value=st.session_state.settings.get('replace_underscores', False)) | 
					
						
						|  |  | 
					
						
						|  | with col2: | 
					
						
						|  | min_confidence = st.slider("Minimum confidence to display", 0.0, 0.5, | 
					
						
						|  | st.session_state.settings['min_confidence'], 0.01) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | st.session_state.settings.update({ | 
					
						
						|  | 'show_all_tags': show_all_tags, | 
					
						
						|  | 'compact_view': compact_view, | 
					
						
						|  | 'min_confidence': min_confidence, | 
					
						
						|  | 'replace_underscores': replace_underscores | 
					
						
						|  | }) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | st.write("Categories to include in 'All Tags' section:") | 
					
						
						|  |  | 
					
						
						|  | category_cols = st.columns(3) | 
					
						
						|  | selected_categories = {} | 
					
						
						|  |  | 
					
						
						|  | if hasattr(st.session_state, 'categories'): | 
					
						
						|  | for i, category in enumerate(sorted(st.session_state.categories)): | 
					
						
						|  | col_idx = i % 3 | 
					
						
						|  | with category_cols[col_idx]: | 
					
						
						|  | default_val = st.session_state.settings['selected_categories'].get(category, True) | 
					
						
						|  | selected_categories[category] = st.checkbox( | 
					
						
						|  | f"{category.capitalize()}", | 
					
						
						|  | value=default_val, | 
					
						
						|  | key=f"cat_select_{category}" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | st.session_state.settings['selected_categories'] = selected_categories | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if image_path and st.button("Run Tagging"): | 
					
						
						|  | if not st.session_state.model_loaded: | 
					
						
						|  | st.error("Model not loaded") | 
					
						
						|  | else: | 
					
						
						|  | with st.spinner("Analyzing image..."): | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | if st.session_state.model_type == "onnx": | 
					
						
						|  | from utils.onnx_processing import process_single_image_onnx | 
					
						
						|  |  | 
					
						
						|  | result = process_single_image_onnx( | 
					
						
						|  | image_path=image_path, | 
					
						
						|  | model_path=os.path.join(MODEL_DIR, "camie-tagger-v2.onnx"), | 
					
						
						|  | metadata=st.session_state.metadata, | 
					
						
						|  | threshold_profile=threshold_profile, | 
					
						
						|  | active_threshold=st.session_state.settings['active_threshold'], | 
					
						
						|  | active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}), | 
					
						
						|  | min_confidence=st.session_state.settings['min_confidence'] | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | result = process_image( | 
					
						
						|  | image_path=image_path, | 
					
						
						|  | model=st.session_state.model, | 
					
						
						|  | thresholds=st.session_state.thresholds, | 
					
						
						|  | metadata=st.session_state.metadata, | 
					
						
						|  | threshold_profile=threshold_profile, | 
					
						
						|  | active_threshold=st.session_state.settings['active_threshold'], | 
					
						
						|  | active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}), | 
					
						
						|  | min_confidence=st.session_state.settings['min_confidence'] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if result['success']: | 
					
						
						|  | st.session_state.all_probs = result['all_probs'] | 
					
						
						|  | st.session_state.tags = result['tags'] | 
					
						
						|  | st.session_state.all_tags = result['all_tags'] | 
					
						
						|  | st.success("Analysis completed!") | 
					
						
						|  | else: | 
					
						
						|  | st.error(f"Analysis failed: {result.get('error', 'Unknown error')}") | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | st.error(f"Error during analysis: {str(e)}") | 
					
						
						|  | st.code(traceback.format_exc()) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if image_path and hasattr(st.session_state, 'all_probs'): | 
					
						
						|  | st.header("Predictions") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | filtered_tags, current_all_tags = apply_thresholds( | 
					
						
						|  | st.session_state.all_probs, | 
					
						
						|  | threshold_profile, | 
					
						
						|  | st.session_state.settings['active_threshold'], | 
					
						
						|  | st.session_state.settings.get('active_category_thresholds', {}), | 
					
						
						|  | st.session_state.settings['min_confidence'], | 
					
						
						|  | st.session_state.settings['selected_categories'] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | all_tags = [] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for category in sorted(st.session_state.all_probs.keys()): | 
					
						
						|  | all_tags_in_category = st.session_state.all_probs.get(category, []) | 
					
						
						|  | filtered_tags_in_category = filtered_tags.get(category, []) | 
					
						
						|  |  | 
					
						
						|  | if all_tags_in_category: | 
					
						
						|  | expander_label = f"{category.capitalize()} ({len(filtered_tags_in_category)} tags)" | 
					
						
						|  |  | 
					
						
						|  | with st.expander(expander_label, expanded=True): | 
					
						
						|  |  | 
					
						
						|  | active_category_thresholds = st.session_state.settings.get('active_category_thresholds') or {} | 
					
						
						|  | threshold = active_category_thresholds.get(category, st.session_state.settings['active_threshold']) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if st.session_state.settings['show_all_tags']: | 
					
						
						|  | tags_to_display = all_tags_in_category | 
					
						
						|  | else: | 
					
						
						|  | tags_to_display = [(tag, prob) for tag, prob in all_tags_in_category if prob >= threshold] | 
					
						
						|  |  | 
					
						
						|  | if not tags_to_display: | 
					
						
						|  | st.info(f"No tags above {st.session_state.settings['min_confidence']:.2f} confidence") | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if st.session_state.settings['compact_view']: | 
					
						
						|  |  | 
					
						
						|  | tag_list = [] | 
					
						
						|  | replace_underscores = st.session_state.settings.get('replace_underscores', False) | 
					
						
						|  |  | 
					
						
						|  | for tag, prob in tags_to_display: | 
					
						
						|  | percentage = int(prob * 100) | 
					
						
						|  | display_tag = tag.replace('_', ' ') if replace_underscores else tag | 
					
						
						|  | tag_list.append(f"{display_tag} ({percentage}%)") | 
					
						
						|  |  | 
					
						
						|  | if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True): | 
					
						
						|  | all_tags.append(tag) | 
					
						
						|  |  | 
					
						
						|  | st.markdown(", ".join(tag_list)) | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | for tag, prob in tags_to_display: | 
					
						
						|  | replace_underscores = st.session_state.settings.get('replace_underscores', False) | 
					
						
						|  | display_tag = tag.replace('_', ' ') if replace_underscores else tag | 
					
						
						|  |  | 
					
						
						|  | if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True): | 
					
						
						|  | all_tags.append(tag) | 
					
						
						|  | tag_display = f"**{display_tag}**" | 
					
						
						|  | else: | 
					
						
						|  | tag_display = display_tag | 
					
						
						|  |  | 
					
						
						|  | st.write(tag_display) | 
					
						
						|  | st.markdown(display_progress_bar(prob), unsafe_allow_html=True) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | st.markdown("---") | 
					
						
						|  | st.subheader(f"All Tags ({len(all_tags)} total)") | 
					
						
						|  | if all_tags: | 
					
						
						|  | replace_underscores = st.session_state.settings.get('replace_underscores', False) | 
					
						
						|  | if replace_underscores: | 
					
						
						|  | display_tags = [tag.replace('_', ' ') for tag in all_tags] | 
					
						
						|  | st.write(", ".join(display_tags)) | 
					
						
						|  | else: | 
					
						
						|  | st.write(", ".join(all_tags)) | 
					
						
						|  | else: | 
					
						
						|  | st.info("No tags detected above the threshold.") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | st.markdown("---") | 
					
						
						|  | st.subheader("Save Tags") | 
					
						
						|  |  | 
					
						
						|  | if 'custom_folders' not in st.session_state: | 
					
						
						|  | st.session_state.custom_folders = get_default_save_locations() | 
					
						
						|  |  | 
					
						
						|  | selected_folder = st.selectbox( | 
					
						
						|  | "Select save location:", | 
					
						
						|  | options=st.session_state.custom_folders, | 
					
						
						|  | format_func=lambda x: os.path.basename(x) if os.path.basename(x) else x | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if st.button("💾 Save to Selected Location"): | 
					
						
						|  | try: | 
					
						
						|  | original_filename = st.session_state.original_filename if hasattr(st.session_state, 'original_filename') else None | 
					
						
						|  |  | 
					
						
						|  | saved_path = save_tags_to_file( | 
					
						
						|  | image_path=image_path, | 
					
						
						|  | all_tags=all_tags, | 
					
						
						|  | original_filename=original_filename, | 
					
						
						|  | custom_dir=selected_folder, | 
					
						
						|  | overwrite=True | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | st.success(f"Tags saved to: {os.path.basename(saved_path)}") | 
					
						
						|  | st.info(f"Full path: {saved_path}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with st.expander("File Contents", expanded=True): | 
					
						
						|  | with open(saved_path, 'r', encoding='utf-8') as f: | 
					
						
						|  | content = f.read() | 
					
						
						|  | st.code(content, language='text') | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | st.error(f"Error saving tags: {str(e)}") | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | image_tagger_app() |