# Hugging Face Space: Dhivehi Tokenizer Comparison
# (The original upload carried a scraped "Spaces: Running" status header here.)
| import gradio as gr | |
| from transformers import AutoTokenizer, T5Tokenizer | |
| import asyncio | |
| import threading | |
| from concurrent.futures import ThreadPoolExecutor | |
| import time | |
# Fixed list of custom tokenizers (left panel).
# Maps human-readable dropdown labels to Hugging Face Hub repo IDs.
# NOTE: insertion order defines the order of choices in the dropdown,
# and "T5 Extended" is the UI's default selection — do not reorder casually.
TOKENIZER_CUSTOM = {
    "T5 Extended": "alakxender/dhivehi-T5-tokenizer-extended",
    "RoBERTa Extended": "alakxender/dhivehi-roberta-tokenizer-extended",
    "Google mT5": "google/mt5-base",
    "Google mT5 Extended": "alakxender/mt5-dhivehi-tokenizer-extended",
    "DeBERTa Extended": "alakxender/deberta-dhivehi-tokenizer-extended",
    "XLM-RoBERTa Extended": "alakxender/xlmr-dhivehi-tokenizer-extended",
    "Bert Extended": "alakxender/bert-dhivehi-tokenizer-extended",
    "Bert Extended Fast": "alakxender/bert-fast-dhivehi-tokenizer-extended"
}
# Suggested stock model paths for the right input (the dropdown also
# accepts any custom Hub path via allow_custom_value=True).
SUGGESTED_STOCK_PATHS = [
    "google/flan-t5-base",
    "t5-small",
    "t5-base",
    "t5-large",
    "google/mt5-base",
    "microsoft/trocr-base-handwritten",
    "microsoft/trocr-base-printed",
    # BUG FIX: a missing trailing comma here previously triggered implicit
    # string concatenation, merging this entry and the next into the bogus
    # path "microsoft/deberta-v3-basexlm-roberta-base" and silently dropping
    # "xlm-roberta-base" from the suggestions.
    "microsoft/deberta-v3-base",
    "xlm-roberta-base",
    "naver-clova-ix/donut-base",
    "bert-base-multilingual-cased",
]
# Module-level memo of already-loaded tokenizers, keyed by Hub path,
# so repeated comparisons don't re-download or re-initialize anything.
tokenizer_cache = {}


def load_tokenizer(tokenizer_path):
    """Load and cache a tokenizer for *tokenizer_path*.

    Tries AutoTokenizer first; if that raises and the path looks like a
    T5/mT5 model, retries with the slow T5Tokenizer. Any other failure is
    re-raised to the caller.
    """
    cached = tokenizer_cache.get(tokenizer_path)
    if cached is not None:
        return cached
    try:
        loaded = AutoTokenizer.from_pretrained(tokenizer_path)
    except Exception:
        lowered = tokenizer_path.lower()
        # Fall back to the slow sentencepiece-based tokenizer only for
        # T5-family paths; everything else propagates the original error.
        if "t5" not in lowered and "mt5" not in lowered:
            raise
        loaded = T5Tokenizer.from_pretrained(tokenizer_path)
    tokenizer_cache[tokenizer_path] = loaded
    return loaded
def tokenize_display(text, tokenizer_path):
    """Tokenize *text* with the tokenizer at *tokenizer_path*.

    Returns a (tokens, ids, decoded) triple. Never raises: on any failure
    it returns an error-marker triple instead, so callers can render the
    problem inline.
    """
    try:
        tok = load_tokenizer(tokenizer_path)
        encoding = tok(text, return_offsets_mapping=False, add_special_tokens=True)
        token_ids = encoding.input_ids
        token_strs = tok.convert_ids_to_tokens(token_ids)
        round_trip = tok.decode(token_ids, skip_special_tokens=False)
    except Exception as err:
        return [f"[ERROR] {str(err)}"], [], "[Tokenizer Error]"
    return token_strs, token_ids, round_trip
def create_token_visualization(tokens, ids):
    """Create a visual representation of tokens with colors and spacing"""
    if not (tokens and ids):
        return "❌ No tokens to display"
    # Cycle through a small emoji palette so adjacent tokens are easy to
    # tell apart in the rendered markdown.
    palette = ("🟦", "🟩", "🟨", "🟪", "🟧", "🟫")
    # Make sentencepiece/special markers readable: '▁' word-start marker,
    # plus the T5-style start/end sentinels.
    cleaned = [
        t.replace('▁', '_').replace('</s>', '[END]').replace('<s>', '[START]')
        for t in tokens
    ]
    return " ".join(
        f"{palette[i % len(palette)]} `{shown}` ({tid})"
        for i, (shown, tid) in enumerate(zip(cleaned, ids))
    )
# Blocking comparison handler with Gradio progress updates.
def compare_side_by_side_with_progress(dv_text, en_text, custom_label, stock_path, progress=gr.Progress()):
    # NOTE: gr.Progress() as a default argument is the Gradio convention for
    # progress injection (Gradio substitutes a live tracker per call), so the
    # usual mutable-default pitfall does not apply here.
    def format_block(title, tokenizer_path):
        # Build one markdown report for a single tokenizer: token counts,
        # colored token blocks, raw ids and the decode round-trip for both
        # the Dhivehi and English inputs.
        dv_tokens, dv_ids, dv_decoded = tokenize_display(dv_text, tokenizer_path)
        en_tokens, en_ids, en_decoded = tokenize_display(en_text, tokenizer_path)
        return f"""\
## 🔤 {title}
### 🈁 Dhivehi: `{dv_text}`
**🎯 Tokens:** {len(dv_tokens) if dv_ids else 'N/A'} tokens
{create_token_visualization(dv_tokens, dv_ids)}
**🔢 Token IDs:** `{dv_ids if dv_ids else '[ERROR]'}`
**🔄 Decoded:** `{dv_decoded}`
---
### 🇬🇧 English: `{en_text}`
**🎯 Tokens:** {len(en_tokens) if en_ids else 'N/A'} tokens
{create_token_visualization(en_tokens, en_ids)}
**🔢 Token IDs:** `{en_ids if en_ids else '[ERROR]'}`
**🔄 Decoded:** `{en_decoded}`
---
"""
    # Resolve the dropdown label to a Hub path; an unknown label is reported
    # in the left pane and the right pane is left empty.
    try:
        custom_path = TOKENIZER_CUSTOM[custom_label]
    except KeyError:
        return "[ERROR] Invalid custom tokenizer selected", ""
    # Show loading progress
    progress(0.1, desc="Loading custom tokenizer...")
    # Load custom tokenizer. tokenize_display already swallows tokenizer
    # errors, so this guard mainly covers failures while building the report.
    try:
        custom_result = format_block("Custom Tokenizer", custom_path)
        progress(0.5, desc="Custom tokenizer loaded. Loading stock tokenizer...")
    except Exception as e:
        custom_result = f"[ERROR] Failed to load custom tokenizer: {str(e)}"
        progress(0.5, desc="Custom tokenizer failed. Loading stock tokenizer...")
    # Load stock tokenizer; errors here don't affect the custom result above.
    try:
        stock_result = format_block("Stock Tokenizer", stock_path)
        progress(1.0, desc="Complete!")
    except Exception as e:
        stock_result = f"[ERROR] Failed to load stock tokenizer: {str(e)}"
        progress(1.0, desc="Complete with errors!")
    return custom_result, stock_result
# Generator-style handler: yields a placeholder first, then the real results,
# so the UI shows immediate feedback while tokenizers download.
# NOTE(review): this function is currently unused — the button below is wired
# directly to compare_side_by_side_with_progress; confirm which is intended.
def compare_tokenizers_async(dv_text, en_text, custom_label, stock_path):
    # Return immediate loading message
    loading_msg = """
## ⏳ Loading Tokenizer...
🚀 **Status:** Downloading and initializing tokenizer...
*This may take a moment for first-time downloads*
"""
    # Use ThreadPoolExecutor for non-blocking execution.
    # NOTE(review): max_workers=2 but only one task is ever submitted; also,
    # the gr.Progress() default of the worker function presumably cannot
    # report progress from this background thread — verify against Gradio docs.
    with ThreadPoolExecutor(max_workers=2) as executor:
        future = executor.submit(compare_side_by_side_with_progress, dv_text, en_text, custom_label, stock_path)
        # Return loading state first
        yield loading_msg, loading_msg
        # Then return actual results
        try:
            custom_result, stock_result = future.result(timeout=120)  # 2 minute timeout
            yield custom_result, stock_result
        except Exception as e:
            error_msg = f"## ❌ Error\n\n**Failed to load tokenizers:** {str(e)}"
            yield error_msg, error_msg
# Gradio UI: two text inputs, two tokenizer pickers, side-by-side markdown output.
with gr.Blocks(title="Dhivehi Tokenizer Comparison Tool", theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🧠 Dhivehi Tokenizer Comparison")
    gr.Markdown("Compare how different tokenizers process Dhivehi and English input text.")
    # Input texts: Dhivehi (rendered right-to-left) and English, side by side.
    with gr.Row():
        dhivehi_text = gr.Textbox(
            label="Dhivehi Text",
            lines=2,
            value="އީދުގެ ހަރަކާތްތައް ފެށުމަށް މިރޭ ހުޅުމާލޭގައި އީދު މަޅި ރޯކުރަނީ",
            rtl=True,
            placeholder="Enter Dhivehi text here..."
        )
        english_text = gr.Textbox(
            label="English Text",
            lines=2,
            value="The quick brown fox jumps over the lazy dog",
            placeholder="Enter English text here..."
        )
    # Tokenizer selection: fixed custom list (left) vs free-form Hub path (right).
    with gr.Row():
        tokenizer_a = gr.Dropdown(
            label="Select Custom Tokenizer",
            choices=list(TOKENIZER_CUSTOM.keys()),
            value="T5 Extended",
            info="Pre-trained Dhivehi tokenizers (or paste a path)"
        )
        tokenizer_b = gr.Dropdown(
            label="Enter or Select Stock Tokenizer Path",
            choices=SUGGESTED_STOCK_PATHS,
            value="google/flan-t5-base",
            allow_custom_value=True,
            info="Standard HuggingFace tokenizers (or paste a path)"
        )
    compare_button = gr.Button("🔄 Compare Tokenizers", variant="primary", size="lg")
    with gr.Row():
        output_custom = gr.Markdown(label="Custom Tokenizer Output", height=400)
        output_stock = gr.Markdown(label="Stock Tokenizer Output", height=400)
    # Wire the button to the blocking progress-reporting handler.
    # NOTE(review): compare_tokenizers_async (the generator variant) exists
    # above but is not used here — confirm whether that was intentional.
    compare_button.click(
        compare_side_by_side_with_progress,
        inputs=[dhivehi_text, english_text, tokenizer_a, tokenizer_b],
        outputs=[output_custom, output_stock],
        show_progress=True
    )

if __name__ == "__main__":
    demo.launch()