# abdullahalmunem's picture
# model added
# f81cfe2
import os
import gradio as gr
import requests
import re
import time
import pandas as pd
from typing import Dict, Tuple, List, Optional
# Configuration
# Endpoint the app posts unpunctuated text to.
API_URL = "http://localhost:5685/punctuate"

# Punctuation marks the evaluator recognises, mapped to the label used for
# each one in the evaluation table.
punc_dict = {
    '!': 'EXCLAMATION',
    '?': 'QUESTION',
    ',': 'COMMA',
    ';': 'SEMICOLON',
    ':': 'COLON',
    '-': 'HYPHEN',
    '।': 'DARI',
}

# Set form of the marks above, for fast membership tests.
allowed_punctuations = set(punc_dict)
def clean_and_normalize_text(text, remove_punctuations=False):
    """Normalize Bangla text and its spacing.

    Args:
        text: Raw input string.
        remove_punctuations: When true, delete every mark listed in
            ``allowed_punctuations`` and collapse whitespace; nothing else
            is filtered.

    Returns:
        With ``remove_punctuations`` false (the default): the text reduced to
        characters in the U+0980-U+09FF range, whitespace and the allowed
        punctuation marks, with whitespace runs collapsed to single spaces
        and no leading or trailing whitespace.
    """
    punct_class = f"[{re.escape(''.join(allowed_punctuations))}]"

    if remove_punctuations:
        without_marks = re.sub(punct_class, "", text)
        return re.sub(r'\s+', ' ', without_marks).strip()

    pieces = []
    # The capturing group makes re.split keep the punctuation marks as their
    # own list items, interleaved with the text between them.
    for segment in re.split(f"({punct_class})", text):
        if segment in allowed_punctuations:
            pieces.append(segment)
            continue
        # Drop every character outside the kept range, then squeeze spacing.
        kept = re.sub(rf"[^\u0980-\u09FF\u09E6-\u09EF\s]", "", segment)
        kept = re.sub(r'\s+', ' ', kept).strip()
        if kept:
            # The leading space separates the words from a preceding mark.
            pieces.append(' ' + kept)

    return re.sub(r'\s+', ' ', ''.join(pieces)).strip()
def restore_punctuation(text, timeout=30):
    """Call the punctuation restoration API.

    Args:
        text: Text to punctuate; sent as ``{"text": text}`` to ``API_URL``.
        timeout: Seconds to wait for the service. Without a timeout
            ``requests`` waits indefinitely, which would freeze the caller
            if the service stalls.

    Returns:
        A ``(result, seconds)`` tuple. ``result`` is the ``restored_text``
        field of the JSON response on success; otherwise it is a message
        starting with ``"API Error:"`` or ``"Connection Error:"``.
        ``seconds`` is the elapsed request time, or ``0.0`` when the request
        itself raised.
    """
    try:
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments.
        started = time.perf_counter()
        response = requests.post(API_URL, json={"text": text}, timeout=timeout)
        api_time = time.perf_counter() - started

        if response.status_code != 200:
            return f"API Error: {response.status_code} - {response.text}", api_time

        restored_text = response.json().get("restored_text")
        if restored_text is None:
            # Report a malformed payload as an error instead of handing
            # None to callers that expect text.
            return "API Error: response has no 'restored_text' field", api_time
        return restored_text, api_time
    except Exception as e:
        # Deliberately broad: network failures, timeouts and undecodable
        # JSON are all reported to the caller as an error string.
        return f"Connection Error: {str(e)}", 0.0
def dummy_restore_punctuation(text, delay=0.5):
    """Stand-in for the real API, for demos when the service is unavailable.

    Args:
        text: Text to "punctuate".
        delay: Seconds to sleep to simulate API latency. Pass ``0`` to skip
            the sleep entirely.

    Returns:
        A ``(punctuated_text, delay)`` tuple shaped like the result of
        ``restore_punctuation``. Texts longer than five words get a comma
        after the third word and a trailing ``?``; texts of three to five
        words get a trailing ``!``; shorter texts are only whitespace
        normalized.
    """
    if delay > 0:
        time.sleep(delay)  # Simulate API delay

    words = text.split()
    if len(words) > 5:
        words[2] += ','
        words[-1] += '?'
    elif len(words) > 2:
        words[-1] += '!'
    return ' '.join(words), delay
def tokenize_with_punctuation(text):
    """Split text into tokens, with each allowed punctuation mark on its own.

    Returns a list in which every mark from ``allowed_punctuations`` is a
    separate token and the text between marks is split on whitespace.
    """
    splitter = f"([{re.escape(''.join(allowed_punctuations))}])"
    tokens = []
    # The capturing group makes re.split return the marks as list items too.
    for piece in re.split(splitter, text):
        if piece in allowed_punctuations:
            tokens.append(piece)
        else:
            # str.split() with no argument yields nothing for empty or
            # whitespace-only pieces and never yields padded words.
            tokens.extend(piece.split())
    return tokens
def _index_punctuation(tokens):
    """Split tokens into words and a map of word index -> marks following it."""
    words = []
    punct_after = {}
    for token in tokens:
        if token in allowed_punctuations:
            # Keyed by the index of the most recent word; -1 means the mark
            # appears before any word.
            punct_after.setdefault(len(words) - 1, []).append(token)
        else:
            words.append(token)
    return words, punct_after


def compare_texts(ground_truth, predicted):
    """Compare ground truth and predicted text token by token.

    Words are aligned by position; the punctuation marks following each word
    (and any marks preceding the first word) are then compared pairwise.

    Args:
        ground_truth: Reference text with punctuation.
        predicted: Text produced by the restoration model.

    Returns:
        A tuple ``(comparison_result, correct_puncs, wrong_puncs,
        gt_punc_counts)``:

        * ``comparison_result``: list of ``(token, status, color)`` tuples in
          display order. Mismatches are rendered as ``gt→pred``, with ``''``
          standing for an absent side.
        * ``correct_puncs``: label -> number of ground-truth marks predicted
          correctly.
        * ``wrong_puncs``: label -> number of ground-truth marks that were
          replaced or missing. Extra predicted marks are not counted.
        * ``gt_punc_counts``: label -> occurrences in the ground truth.
    """
    gt_tokens = tokenize_with_punctuation(ground_truth)
    pred_tokens = tokenize_with_punctuation(predicted)

    # Count punctuations in ground truth
    gt_punc_counts = {}
    for token in gt_tokens:
        if token in allowed_punctuations:
            punc_name = punc_dict[token]
            gt_punc_counts[punc_name] = gt_punc_counts.get(punc_name, 0) + 1

    gt_words, gt_punct_map = _index_punctuation(gt_tokens)
    pred_words, pred_punct_map = _index_punctuation(pred_tokens)

    comparison_result = []
    correct_puncs = {}
    wrong_puncs = {}

    # Start at -1 so marks that precede the first word are compared as well;
    # they are included in gt_punc_counts and must not be silently skipped.
    for i in range(-1, max(len(gt_words), len(pred_words))):
        if i >= 0:
            gt_word = gt_words[i] if i < len(gt_words) else None
            pred_word = pred_words[i] if i < len(pred_words) else None
            if gt_word is None:
                comparison_result.append((f"''→{pred_word}", "extra_word", "red"))
            elif pred_word is None:
                comparison_result.append((f"{gt_word}→''", "missing_word", "red"))
            elif gt_word == pred_word:
                comparison_result.append((gt_word, "correct", "black"))
            else:
                # Same "gt→pred" shape as the missing/extra cases so the two
                # words are not glued together in the display.
                comparison_result.append((f"{gt_word}→{pred_word}", "word_diff", "orange"))

        # Compare punctuations after this word
        gt_puncs = gt_punct_map.get(i, [])
        pred_puncs = pred_punct_map.get(i, [])
        for j in range(max(len(gt_puncs), len(pred_puncs))):
            gt_punc = gt_puncs[j] if j < len(gt_puncs) else None
            pred_punc = pred_puncs[j] if j < len(pred_puncs) else None

            if gt_punc is None:
                # Extra punctuation (not counted in wrong_puncs since it's not in GT)
                comparison_result.append((f"''→{pred_punc}", "extra_punct", "red"))
                continue

            punc_name = punc_dict[gt_punc]
            if pred_punc is None:
                # Missing punctuation
                wrong_puncs[punc_name] = wrong_puncs.get(punc_name, 0) + 1
                comparison_result.append((f"{gt_punc}→''", "missing_punct", "red"))
            elif gt_punc == pred_punc:
                correct_puncs[punc_name] = correct_puncs.get(punc_name, 0) + 1
                comparison_result.append((gt_punc, "correct", "green"))
            else:
                # Wrong punctuation
                wrong_puncs[punc_name] = wrong_puncs.get(punc_name, 0) + 1
                comparison_result.append((f"{gt_punc}→{pred_punc}", "wrong_punct", "red"))

    return comparison_result, correct_puncs, wrong_puncs, gt_punc_counts
def create_evaluation_table(correct_puncs, wrong_puncs, gt_punc_counts):
    """Build the per-punctuation evaluation table.

    Produces one row per label in ``gt_punc_counts`` (in that dict's order)
    holding the label, its correct count, its wrong count and its
    ground-truth total. Labels absent from ``correct_puncs`` or
    ``wrong_puncs`` count as 0.

    Returns:
        A pandas DataFrame with the four evaluation columns.
    """
    column_names = [
        "Punctuation Name",
        "Correctly Classified",
        "Wrongly Classified",
        "Count in Ground Truth",
    ]
    rows = [
        [label, correct_puncs.get(label, 0), wrong_puncs.get(label, 0), total]
        for label, total in gt_punc_counts.items()
    ]
    return pd.DataFrame(rows, columns=column_names)
def format_comparison_html(comparison_result):
    """Render a token comparison as an HTML fragment followed by a legend.

    Args:
        comparison_result: Iterable of ``(token, status, color)`` tuples.
            Red tokens ending in ``→''`` are shown as ``item→∅`` (missing)
            and red tokens starting with ``''→`` as ``∅→item`` (extra).

    Returns:
        An HTML string: one styled ``<span>`` per token inside a ``<div>``,
        then the legend ``<div>``. Token text is HTML-escaped because it
        originates from user input and the API response.
    """
    # Imported here so the function carries its own dependency.
    from html import escape

    green_style = (
        "background-color: #d4edda; color: #155724; padding: 2px 4px; "
        "margin: 1px; border-radius: 3px; font-weight: bold;"
    )
    red_style = (
        "background-color: #f8d7da; color: #721c24; padding: 2px 4px; "
        "margin: 1px; border-radius: 3px; font-weight: bold;"
    )
    orange_style = (
        "background-color: #fff3cd; color: #856404; padding: 2px 4px; "
        "margin: 1px; border-radius: 3px;"
    )
    plain_style = "padding: 2px 4px; margin: 1px;"

    missing_suffix = "→''"
    extra_prefix = "''→"

    parts = [
        "<div style='font-family: monospace; font-size: 16px; line-height: 1.8; "
        "padding: 20px; border: 1px solid #ddd; border-radius: 5px;'>"
    ]
    for token, status, color in comparison_result:
        if status == "correct" and color == "green":
            # Correct punctuation
            style, text = green_style, escape(token)
        elif color == "red":
            # Incorrect, missing, or extra punctuation/word
            style = red_style
            if token.endswith(missing_suffix):
                # Slice instead of split("→") so an item that itself
                # contains an arrow is not truncated.
                text = escape(token[:-len(missing_suffix)]) + "→∅"
            elif token.startswith(extra_prefix):
                text = "∅→" + escape(token[len(extra_prefix):])
            else:
                text = escape(token)
        elif color == "orange":
            # Word difference
            style, text = orange_style, escape(token)
        else:
            # Correct word
            style, text = plain_style, escape(token)
        # Trailing space separates consecutive tokens.
        parts.append(f"<span style='{style}'>{text}</span> ")
    parts.append("</div>")

    # Add legend
    parts.append(
        "\n"
        "<div style='margin-top: 15px; padding: 10px; background-color: #f8f9fa; border-radius: 5px; font-size: 14px;'>\n"
        "<strong>Legend:</strong><br>\n"
        "<span style='background-color: #d4edda; color: #155724; padding: 1px 3px; border-radius: 2px; margin: 2px;'>✓</span> Correct punctuation &nbsp;\n"
        "<span style='background-color: #f8d7da; color: #721c24; padding: 1px 3px; border-radius: 2px; margin: 2px;'>✗</span> Wrong/Missing/Extra punctuation &nbsp;\n"
        "<span style='background-color: #fff3cd; color: #856404; padding: 1px 3px; border-radius: 2px; margin: 2px;'>~</span> Word difference &nbsp;\n"
        "<span style='padding: 1px 3px; margin: 2px;'>◦</span> Correct word<br>\n"
        "<strong>∅</strong> = Empty/Missing\n"
        "</div>\n"
    )
    return "".join(parts)
def process_punctuation_restoration(input_text, ground_truth=""):
    """Restore punctuation for ``input_text`` and optionally evaluate it.

    Args:
        input_text: Text to send to the restoration API.
        ground_truth: Optional reference text. When non-blank, the prediction
            is compared against its normalized form.

    Returns:
        A 5-tuple ``(predicted_text, comparison_html, time_info, eval_table,
        normalized_ground_truth_label)``. Without ground truth the comparison
        HTML and label are empty strings and the table is ``None``. For blank
        input the first element is a prompt to enter text.
    """
    if not input_text or not input_text.strip():
        return "Please enter input text", "", "", None, ""

    # Placeholder shown whenever the API call fails; the huge time value
    # makes the failure obvious in the response-time box.
    failure = (f"Error : {input_text}", 999999)

    try:
        predicted_text, api_time = restore_punctuation(input_text)
    except Exception:
        predicted_text, api_time = failure
    else:
        # restore_punctuation reports failures as strings containing "Error".
        if "Error" in str(predicted_text):
            predicted_text, api_time = failure

    # The API may return a list of results; use the first one.
    if isinstance(predicted_text, list):
        predicted_text = predicted_text[0] if predicted_text else None
    if not isinstance(predicted_text, str):
        # None or another non-text payload cannot be displayed or compared.
        predicted_text, api_time = failure

    time_info = f"API call completed in {api_time:.3f} seconds"
    print(f"input_text: {input_text}", flush=True)
    print(f"predicted_text: {predicted_text}", flush=True)

    if not ground_truth or not ground_truth.strip():
        return predicted_text, "", time_info, None, ""

    # Normalize ground truth
    ground_truth_normalized = clean_and_normalize_text(ground_truth)

    # Compare texts
    comparison_result, correct_puncs, wrong_puncs, gt_punc_counts = compare_texts(
        ground_truth_normalized, predicted_text
    )

    # Create comparison HTML
    comparison_html = format_comparison_html(comparison_result)

    # Create evaluation table
    eval_table = create_evaluation_table(correct_puncs, wrong_puncs, gt_punc_counts)

    return (
        predicted_text,
        comparison_html,
        time_info,
        eval_table,
        f"Normalized Ground Truth: {ground_truth_normalized}",
    )
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the punctuation restoration evaluator.

    Creates the two input boxes and submit button, the output widgets, a
    colour legend and an examples section, and wires the button to
    ``process_punctuation_restoration``.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    # NOTE(review): the nesting of the widgets below was reconstructed; the
    # placement of the comparison HTML and evaluation table inside the right
    # column (rather than below the row) is an assumption - confirm.
    with gr.Blocks(title="Punctuation Restoration Evaluator", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🔤 Punctuation Restoration Evaluator")
        gr.Markdown("Enter text to restore punctuation. Optionally provide ground truth for evaluation.")
        with gr.Row():
            # Left column: user inputs and the submit button.
            with gr.Column(scale=1):
                input_text = gr.Textbox(
                    label="Input Text (without punctuation)",
                    placeholder="পুরুষের সংখ্যা মোট জনসংখ্যার ৫২ এবং নারীর সংখ্যা ৪৮ শহরের সাক্ষরতার হার কত",
                    lines=4
                )
                ground_truth = gr.Textbox(
                    label="Ground Truth (optional)",
                    placeholder="পুরুষের সংখ্যা মোট জনসংখ্যার ৫২, এবং নারীর সংখ্যা ৪৮। শহরের সাক্ষরতার হার কত?",
                    lines=4
                )
                submit_btn = gr.Button("🚀 Restore Punctuation", variant="primary")
            # Right column: read-only results.
            with gr.Column(scale=2):
                api_time = gr.Textbox(label="⏱️ API Response Time", interactive=False)
                predicted_output = gr.Textbox(
                    label="📝 Predicted Output",
                    lines=3,
                    interactive=False
                )
                normalized_gt = gr.Textbox(
                    label="📋 Normalized Ground Truth",
                    lines=2,
                    interactive=False
                )
                comparison_output = gr.HTML(
                    label="🔍 Token-wise Comparison",
                    value="<p>Comparison will appear here after processing with ground truth.</p>"
                )
                evaluation_table = gr.Dataframe(
                    label="📊 Punctuation Evaluation Metrics",
                    headers=["Punctuation Name", "Correctly Classified", "Wrongly Classified", "Count in Ground Truth"],
                    interactive=False
                )
        # Legend
        gr.Markdown("""
### 🎨 Color Legend:
- 🟢 **Green**: Correctly predicted punctuation
- 🔴 **Red**: Incorrectly predicted, missing, or extra punctuation/word
- 🟡 **Orange**: Word-level differences
- ⚫ **Black**: Correct words/tokens
- **∅**: Empty/Missing (instead of showing word→word or punct→word)
""")
        # The outputs list follows the order of the 5-tuple returned by
        # process_punctuation_restoration.
        submit_btn.click(
            fn=process_punctuation_restoration,
            inputs=[input_text, ground_truth],
            outputs=[predicted_output, comparison_output, api_time, evaluation_table, normalized_gt]
        )
        # Example section
        gr.Markdown("### 📚 Example")
        gr.Examples(
            examples=[
                # Example with ground truth.
                [
                    "পুরুষের সংখ্যা মোট জনসংখ্যার ৫২ এবং নারীর সংখ্যা ৪৮ শহরের সাক্ষরতার হার কত",
                    "পুরুষের সংখ্যা মোট জনসংখ্যার ৫২, এবং নারীর সংখ্যা ৪৮। শহরের সাক্ষরতার হার কত?"
                ],
                # Example without ground truth.
                [
                    "ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান",
                    ""
                ]
            ],
            inputs=[input_text, ground_truth]
        )
    return app
if __name__ == "__main__":
    # Settings for serving the app when the file is run as a script.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": True,
    }
    app = create_interface()
    app.launch(**launch_options)