|
import os |
|
import gradio as gr |
|
import requests |
|
import re |
|
import time |
|
import pandas as pd |
|
from typing import Dict, Tuple, List, Optional |
|
|
|
|
|
# Endpoint of the punctuation restoration service queried by restore_punctuation().
API_URL = "http://localhost:5685/punctuate"

# Punctuation marks that are evaluated, mapped to the label shown in the
# metrics table. The last entry is the danda used as the Bangla full stop.
punc_dict = {
    "!": "EXCLAMATION",
    "?": "QUESTION",
    ",": "COMMA",
    ";": "SEMICOLON",
    ":": "COLON",
    "-": "HYPHEN",
    "।": "DARI",
}

# Fast membership test for "is this token a punctuation mark we score?".
allowed_punctuations = set(punc_dict)
|
|
|
def clean_and_normalize_text(text, remove_punctuations=False):
    """Normalize Bangla text and its spacing.

    Args:
        text: The raw text to normalize.
        remove_punctuations: When true, only the marks in
            ``allowed_punctuations`` are deleted and whitespace is collapsed;
            every other character is kept.

    Returns:
        The normalized string. In the default mode, characters outside the
        Bengali Unicode block (other than whitespace and the allowed
        punctuation marks) are dropped, each punctuation mark is attached to
        the preceding word, and runs of whitespace become a single space.
    """
    punct_class = f"[{re.escape(''.join(allowed_punctuations))}]"

    if remove_punctuations:
        without_punct = re.sub(punct_class, "", text)
        return re.sub(r'\s+', ' ', without_punct).strip()

    pieces = []
    # The capturing group makes re.split keep the punctuation marks as
    # their own list items between the text segments.
    for segment in re.split(f"({punct_class})", text):
        if segment in allowed_punctuations:
            pieces.append(segment)
            continue

        bangla_only = re.sub(r"[^\u0980-\u09FF\u09E6-\u09EF\s]", "", segment)
        bangla_only = re.sub(r'\s+', ' ', bangla_only).strip()
        if bangla_only:
            # Leading space separates this segment from the previous
            # punctuation mark; the final pass trims any surplus.
            pieces.append(' ' + bangla_only)

    return re.sub(r'\s+', ' ', ''.join(pieces)).strip()
|
|
|
def restore_punctuation(text, timeout=30.0):
    """Call the punctuation restoration API.

    Args:
        text: Unpunctuated text sent to the service as ``{"text": text}``.
        timeout: Seconds to wait for the service before giving up. Without a
            timeout ``requests`` waits forever, which would freeze the UI
            when the service is unreachable or stalled.

    Returns:
        A ``(result, seconds)`` tuple. On success ``result`` is the value of
        the ``restored_text`` field of the JSON response. On failure it is a
        message starting with ``"API Error"`` or ``"Connection Error"``.
        ``seconds`` is the request duration, or ``0.0`` when the request
        itself could not be completed.
    """
    # perf_counter is monotonic, so the duration cannot be skewed by
    # wall-clock adjustments made while the request is in flight.
    started = time.perf_counter()
    try:
        response = requests.post(API_URL, json={"text": text}, timeout=timeout)
    except requests.RequestException as exc:
        return f"Connection Error: {exc}", 0.0
    api_time = time.perf_counter() - started

    if response.status_code != 200:
        return f"API Error: {response.status_code} - {response.text}", api_time

    try:
        restored_text = response.json().get("restored_text")
    except (ValueError, AttributeError) as exc:
        # ValueError: body is not JSON; AttributeError: JSON is not an object.
        return f"API Error: invalid response body - {exc}", api_time

    if restored_text is None:
        # Surface a readable error instead of handing None to the caller,
        # which would fail later when the text is tokenized.
        return "API Error: response is missing 'restored_text'", api_time

    return restored_text, api_time
|
|
|
def dummy_restore_punctuation(text, delay=0.5):
    """Stand-in for the real API, for demonstrations when it is unavailable.

    Args:
        text: Whitespace-separated input text.
        delay: Simulated latency in seconds. The function sleeps for this
            long and reports it as the call duration. Must be non-negative
            (``time.sleep`` rejects negative values).

    Returns:
        A ``(text, seconds)`` tuple. Inputs longer than five words get a
        comma after the third word and a question mark at the end; inputs of
        three to five words get a trailing exclamation mark; shorter inputs
        are returned with only their whitespace normalized.
    """
    time.sleep(delay)

    words = text.split()
    if len(words) > 5:
        words[2] += ','
        words[-1] += '?'
    elif len(words) > 2:
        words[-1] += '!'

    return ' '.join(words), delay
|
|
|
def tokenize_with_punctuation(text):
    """Split text into word tokens and standalone punctuation tokens.

    Each mark listed in ``allowed_punctuations`` becomes its own token; the
    text between marks is split on whitespace. Token order follows the text.
    """
    splitter = f"([{re.escape(''.join(allowed_punctuations))}])"

    tokens = []
    for segment in re.split(splitter, text):
        if not segment.strip():
            # Empty or whitespace-only gap between two punctuation marks.
            continue
        if segment in allowed_punctuations:
            tokens.append(segment)
        else:
            tokens.extend(segment.split())
    return tokens
|
|
|
def _split_words_and_punctuation(tokens):
    """Separate tokens into words and the punctuation attached to each word.

    Returns ``(words, punct_map)`` where ``punct_map[i]`` lists, in order, the
    punctuation marks that follow word ``i``. Marks that occur before the
    first word are stored under key ``-1``.
    """
    words = []
    punct_map = {}
    for token in tokens:
        if token in allowed_punctuations:
            punct_map.setdefault(len(words) - 1, []).append(token)
        else:
            words.append(token)
    return words, punct_map


def _compare_punctuation(gt_puncs, pred_puncs, comparison_result, correct_puncs, wrong_puncs):
    """Compare the marks at one word position, updating the result and tallies in place."""
    for j in range(max(len(gt_puncs), len(pred_puncs))):
        gt_punc = gt_puncs[j] if j < len(gt_puncs) else None
        pred_punc = pred_puncs[j] if j < len(pred_puncs) else None

        if gt_punc is None:
            # Extra predictions are shown but not tallied: the tallies are
            # keyed by the ground-truth mark, and there is none here.
            comparison_result.append((f"''→{pred_punc}", "extra_punct", "red"))
            continue

        punc_name = punc_dict[gt_punc]
        if gt_punc == pred_punc:
            correct_puncs[punc_name] = correct_puncs.get(punc_name, 0) + 1
            comparison_result.append((gt_punc, "correct", "green"))
        elif pred_punc is None:
            wrong_puncs[punc_name] = wrong_puncs.get(punc_name, 0) + 1
            comparison_result.append((f"{gt_punc}→''", "missing_punct", "red"))
        else:
            wrong_puncs[punc_name] = wrong_puncs.get(punc_name, 0) + 1
            comparison_result.append((f"{gt_punc}→{pred_punc}", "wrong_punct", "red"))


def compare_texts(ground_truth, predicted):
    """Compare ground truth and predicted text token by token.

    Words are aligned by position (first with first, and so on) and the
    punctuation following each word is compared in order.

    Args:
        ground_truth: Reference text with the expected punctuation.
        predicted: Text produced by the restoration model.

    Returns:
        A tuple ``(comparison_result, correct_puncs, wrong_puncs,
        gt_punc_counts)``:

        * ``comparison_result`` - list of ``(display_text, status, color)``
          tuples in reading order.
        * ``correct_puncs`` / ``wrong_puncs`` - per-label counts of
          ground-truth marks that were matched / mismatched or missing.
        * ``gt_punc_counts`` - per-label counts of marks in the ground truth.
    """
    gt_tokens = tokenize_with_punctuation(ground_truth)
    pred_tokens = tokenize_with_punctuation(predicted)

    gt_punc_counts = {}
    for token in gt_tokens:
        if token in allowed_punctuations:
            punc_name = punc_dict[token]
            gt_punc_counts[punc_name] = gt_punc_counts.get(punc_name, 0) + 1

    gt_words, gt_punct_map = _split_words_and_punctuation(gt_tokens)
    pred_words, pred_punct_map = _split_words_and_punctuation(pred_tokens)

    comparison_result = []
    correct_puncs = {}
    wrong_puncs = {}

    # Punctuation that precedes the first word lives under index -1. It is
    # included in gt_punc_counts, so it must be scored as well; otherwise
    # the correct/wrong columns would not add up to the ground-truth total.
    _compare_punctuation(
        gt_punct_map.get(-1, []),
        pred_punct_map.get(-1, []),
        comparison_result,
        correct_puncs,
        wrong_puncs,
    )

    for i in range(max(len(gt_words), len(pred_words))):
        has_gt = i < len(gt_words)
        has_pred = i < len(pred_words)

        if has_gt and has_pred:
            if gt_words[i] == pred_words[i]:
                comparison_result.append((gt_words[i], "correct", "black"))
            else:
                comparison_result.append(
                    (f"{gt_words[i]}→{pred_words[i]}", "word_diff", "orange")
                )
        elif has_gt:
            comparison_result.append((f"{gt_words[i]}→''", "missing_word", "red"))
        else:
            comparison_result.append((f"''→{pred_words[i]}", "extra_word", "red"))

        _compare_punctuation(
            gt_punct_map.get(i, []),
            pred_punct_map.get(i, []),
            comparison_result,
            correct_puncs,
            wrong_puncs,
        )

    return comparison_result, correct_puncs, wrong_puncs, gt_punc_counts
|
|
|
def create_evaluation_table(correct_puncs, wrong_puncs, gt_punc_counts):
    """Build the per-punctuation metrics table.

    One row is produced for every label in ``gt_punc_counts`` (in its
    iteration order); labels absent from ``correct_puncs`` or ``wrong_puncs``
    are reported as 0.

    Returns:
        A ``pandas.DataFrame`` with the columns "Punctuation Name",
        "Correctly Classified", "Wrongly Classified" and
        "Count in Ground Truth".
    """
    columns = [
        "Punctuation Name",
        "Correctly Classified",
        "Wrongly Classified",
        "Count in Ground Truth",
    ]
    rows = [
        [label, correct_puncs.get(label, 0), wrong_puncs.get(label, 0), total]
        for label, total in gt_punc_counts.items()
    ]
    return pd.DataFrame(rows, columns=columns)
|
|
|
def format_comparison_html(comparison_result):
    """Render the comparison result as an HTML snippet followed by a legend.

    Args:
        comparison_result: Iterable of ``(token, status, color)`` tuples as
            produced by ``compare_texts``.

    Returns:
        An HTML string: one styled ``<span>`` per tuple inside a bordered
        ``<div>``, then a legend block. Token text is HTML-escaped because it
        comes from user input and from the API response and must not be
        interpreted as markup.
    """
    # Function-scope import: the local naming in this module uses ``html``
    # for markup strings, so only the helper is brought in.
    from html import escape

    green_style = "background-color: #d4edda; color: #155724; padding: 2px 4px; margin: 1px; border-radius: 3px; font-weight: bold;"
    red_style = "background-color: #f8d7da; color: #721c24; padding: 2px 4px; margin: 1px; border-radius: 3px; font-weight: bold;"
    orange_style = "background-color: #fff3cd; color: #856404; padding: 2px 4px; margin: 1px; border-radius: 3px;"
    plain_style = "padding: 2px 4px; margin: 1px;"

    parts = [
        "<div style='font-family: monospace; font-size: 16px; line-height: 1.8; padding: 20px; border: 1px solid #ddd; border-radius: 5px;'>"
    ]

    for token, status, color in comparison_result:
        label = token
        if status == "correct" and color == "green":
            style = green_style
        elif color == "red":
            style = red_style
            if "→''" in token:
                # Missing in the prediction: show "item→∅".
                label = token.split("→")[0] + "→∅"
            elif "''→" in token:
                # Extra in the prediction: show "∅→item".
                label = "∅→" + token.split("→")[1]
        elif color == "orange":
            style = orange_style
        else:
            style = plain_style

        parts.append(f"<span style='{style}'>{escape(str(label))}</span>")
        parts.append(" ")

    parts.append("</div>")

    parts.append(
        "\n"
        "<div style='margin-top: 15px; padding: 10px; background-color: #f8f9fa; border-radius: 5px; font-size: 14px;'>\n"
        "<strong>Legend:</strong><br>\n"
        "<span style='background-color: #d4edda; color: #155724; padding: 1px 3px; border-radius: 2px; margin: 2px;'>✓</span> Correct punctuation\n"
        "<span style='background-color: #f8d7da; color: #721c24; padding: 1px 3px; border-radius: 2px; margin: 2px;'>✗</span> Wrong/Missing/Extra punctuation\n"
        "<span style='background-color: #fff3cd; color: #856404; padding: 1px 3px; border-radius: 2px; margin: 2px;'>~</span> Word difference\n"
        "<span style='padding: 1px 3px; margin: 2px;'>◦</span> Correct word<br>\n"
        "<strong>∅</strong> = Empty/Missing\n"
        "</div>\n"
    )

    # Joining once avoids repeatedly copying a growing string.
    return "".join(parts)
|
|
|
def process_punctuation_restoration(input_text, ground_truth=""):
    """Restore punctuation for the input and optionally evaluate the result.

    Args:
        input_text: Unpunctuated text entered by the user.
        ground_truth: Optional reference text. When blank, only the
            prediction and timing are returned.

    Returns:
        A 5-tuple matching the UI outputs: ``(predicted_text,
        comparison_html, time_info, eval_table, normalized_ground_truth)``.
        When the restoration call fails, ``predicted_text`` is
        ``"Error : <input_text>"`` and the reported time is 999999 seconds.
    """
    # Treat a missing value like an empty textbox instead of raising.
    input_text = input_text or ""
    ground_truth = ground_truth or ""

    if not input_text.strip():
        return "Please enter input text", "", "", None, ""

    failed_api_time = 999999  # sentinel duration shown when the call fails

    try:
        predicted_text, api_time = restore_punctuation(input_text)
    except Exception as exc:
        # UI boundary: report the failure in the interface rather than
        # crashing the handler. Unlike a bare ``except``, this leaves
        # KeyboardInterrupt and SystemExit alone.
        print(f"restore_punctuation raised: {exc!r}", flush=True)
        predicted_text, api_time = None, failed_api_time

    # The API may wrap the prediction in a list; an empty list is a failure.
    if isinstance(predicted_text, list):
        predicted_text = predicted_text[0] if predicted_text else None

    if not isinstance(predicted_text, str) or "Error" in predicted_text:
        if predicted_text is not None:
            # Keep the underlying reason visible in the logs; the UI only
            # shows the generic placeholder below.
            print(f"punctuation restoration failed: {predicted_text}", flush=True)
        predicted_text, api_time = f"Error : {input_text}", failed_api_time

    time_info = f"API call completed in {api_time:.3f} seconds"

    print(f"input_text: {input_text}", flush=True)
    print(f"predicted_text: {predicted_text}", flush=True)

    if not ground_truth.strip():
        return predicted_text, "", time_info, None, ""

    ground_truth_normalized = clean_and_normalize_text(ground_truth)

    comparison_result, correct_puncs, wrong_puncs, gt_punc_counts = compare_texts(
        ground_truth_normalized, predicted_text
    )
    comparison_html = format_comparison_html(comparison_result)
    eval_table = create_evaluation_table(correct_puncs, wrong_puncs, gt_punc_counts)

    return (
        predicted_text,
        comparison_html,
        time_info,
        eval_table,
        f"Normalized Ground Truth: {ground_truth_normalized}",
    )
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks app for punctuation restoration and evaluation.

    The left column holds the input and optional ground-truth textboxes plus
    the submit button; the right column shows timing, prediction, normalized
    ground truth, token-wise comparison and the metrics table. The submit
    button is wired to ``process_punctuation_restoration``.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    sample_input = "পুরুষের সংখ্যা মোট জনসংখ্যার ৫২ এবং নারীর সংখ্যা ৪৮ শহরের সাক্ষরতার হার কত"
    sample_truth = "পুরুষের সংখ্যা মোট জনসংখ্যার ৫২, এবং নারীর সংখ্যা ৪৮। শহরের সাক্ষরতার হার কত?"
    sample_input_only = "ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান"

    metric_headers = [
        "Punctuation Name",
        "Correctly Classified",
        "Wrongly Classified",
        "Count in Ground Truth",
    ]

    color_legend = (
        "\n"
        "### 🎨 Color Legend:\n"
        "- 🟢 **Green**: Correctly predicted punctuation\n"
        "- 🔴 **Red**: Incorrectly predicted, missing, or extra punctuation/word\n"
        "- 🟡 **Orange**: Word-level differences\n"
        "- ⚫ **Black**: Correct words/tokens\n"
        "- **∅**: Empty/Missing (instead of showing word→word or punct→word)\n"
    )

    with gr.Blocks(title="Punctuation Restoration Evaluator", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🔤 Punctuation Restoration Evaluator")
        gr.Markdown("Enter text to restore punctuation. Optionally provide ground truth for evaluation.")

        with gr.Row():
            # Left column: user inputs.
            with gr.Column(scale=1):
                input_text = gr.Textbox(
                    label="Input Text (without punctuation)",
                    placeholder=sample_input,
                    lines=4,
                )
                ground_truth = gr.Textbox(
                    label="Ground Truth (optional)",
                    placeholder=sample_truth,
                    lines=4,
                )
                submit_btn = gr.Button("🚀 Restore Punctuation", variant="primary")

            # Right column: results.
            with gr.Column(scale=2):
                api_time = gr.Textbox(label="⏱️ API Response Time", interactive=False)
                predicted_output = gr.Textbox(
                    label="📝 Predicted Output",
                    lines=3,
                    interactive=False,
                )
                normalized_gt = gr.Textbox(
                    label="📋 Normalized Ground Truth",
                    lines=2,
                    interactive=False,
                )
                comparison_output = gr.HTML(
                    label="🔍 Token-wise Comparison",
                    value="<p>Comparison will appear here after processing with ground truth.</p>",
                )
                evaluation_table = gr.Dataframe(
                    label="📊 Punctuation Evaluation Metrics",
                    headers=metric_headers,
                    interactive=False,
                )

        gr.Markdown(color_legend)

        # Output order must match the tuple returned by the handler.
        submit_btn.click(
            fn=process_punctuation_restoration,
            inputs=[input_text, ground_truth],
            outputs=[
                predicted_output,
                comparison_output,
                api_time,
                evaluation_table,
                normalized_gt,
            ],
        )

        gr.Markdown("### 📚 Example")
        gr.Examples(
            examples=[
                [sample_input, sample_truth],
                [sample_input_only, ""],
            ],
            inputs=[input_text, ground_truth],
        )

    return app
|
|
|
if __name__ == "__main__":
    # Serve on all interfaces at port 7860, without a public share link,
    # with Gradio's debug mode enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": True,
    }
    app = create_interface()
    app.launch(**launch_options)