import os
import gradio as gr
import requests
import re
import time
import pandas as pd
from typing import Dict, Tuple, List, Optional

# Configuration
API_URL = "http://localhost:5685/punctuate"
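# restore_punctuation() below POSTs {"text": ...} to this URL and expects a
# JSON response that contains a "restored_text" field.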


# Punctuation marks the evaluator handles, mapped to the label names shown
# in the evaluation table ('।' is the Bangla dari, i.e. the sentence-ending full stop).
punc_dict = {
    '!': 'EXCLAMATION', 
    '?': 'QUESTION', 
    ',': 'COMMA', 
    ';': 'SEMICOLON', 
    ':': 'COLON', 
    '-': 'HYPHEN', 
    '।': 'DARI', 
}

allowed_punctuations = set(punc_dict.keys())

def clean_and_normalize_text(text, remove_punctuations=False): 
    """Clean and normalize Bangla text with correct spacing""" 
    if remove_punctuations: 
        # Remove all allowed punctuations 
        cleaned_text = re.sub(f"[{re.escape(''.join(allowed_punctuations))}]", "", text) 
        # Normalize spaces 
        cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip() 
        return cleaned_text 
    else: 
        # Keep only allowed punctuations and Bangla letters/digits 
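        # Splitting on a capturing group keeps each punctuation mark as its
        # own element of the resulting list, so the marks survive the split.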
        chunks = re.split(f"([{re.escape(''.join(allowed_punctuations))}])", text) 
        filtered_chunks = [] 
 
        for chunk in chunks: 
            if chunk in allowed_punctuations: 
                filtered_chunks.append(chunk) 
            else: 
                # Clean text and preserve word boundaries 
                clean_chunk = re.sub(rf"[^\u0980-\u09FF\u09E6-\u09EF\s]", "", chunk) 
                clean_chunk = re.sub(r'\s+', ' ', clean_chunk)  # Normalize internal spacing 
                clean_chunk = clean_chunk.strip() 
                if clean_chunk: 
                    filtered_chunks.append(' ' + clean_chunk)  # Add space before word chunks 
 
        # Join and clean up spacing 
        result = ''.join(filtered_chunks) 
        result = re.sub(r'\s+', ' ', result).strip() 
        return result 

def restore_punctuation(text):
    """Call the punctuation restoration API"""
    try:
        payload = {"text": text}
        start_time = time.time()
        response = requests.post(API_URL, json=payload)
        end_time = time.time()
        
        api_time = end_time - start_time
        
        if response.status_code == 200:
            restored_text = response.json().get("restored_text")
            return restored_text, api_time
        else:
            return f"API Error: {response.status_code} - {response.text}", api_time
    except Exception as e:
        return f"Connection Error: {str(e)}", 0.0

def dummy_restore_punctuation(text):
    """Dummy API call for demonstration when real API is not available"""
    time.sleep(0.5)  # Simulate API delay
    
    # Simple dummy logic - append a few punctuation marks at fixed positions for the demo
    words = text.split()
    if len(words) > 5:
        words[2] = words[2] + ','
        words[-1] = words[-1] + '?'
    elif len(words) > 2:
        words[-1] = words[-1] + '!'
    
    return ' '.join(words), 0.5

def tokenize_with_punctuation(text):
    """Tokenize text keeping punctuation separate using chunk-based approach"""
    tokens = []
    chunks = re.split(f"([{re.escape(''.join(allowed_punctuations))}])", text)
    
    for chunk in chunks:
        if not chunk.strip(): 
            continue
        
        if chunk in allowed_punctuations:
            # This is a punctuation
            tokens.append(chunk)
        else:
            # This is text, split into words
            words = chunk.strip().split()
            for word in words:
                if word.strip():
                    tokens.append(word.strip())
    
    return tokens

def compare_texts(ground_truth, predicted):
    """Compare ground truth and predicted text token by token with proper alignment"""
    gt_tokens = tokenize_with_punctuation(ground_truth)
    pred_tokens = tokenize_with_punctuation(predicted)
    
    comparison_result = []
    correct_puncs = {}
    wrong_puncs = {}
    gt_punc_counts = {}
    
    # Count punctuations in ground truth
    for token in gt_tokens:
        if token in allowed_punctuations:
            punc_name = punc_dict[token]
            gt_punc_counts[punc_name] = gt_punc_counts.get(punc_name, 0) + 1
    
    # Separate words and punctuations for better alignment
    gt_words = [token for token in gt_tokens if token not in allowed_punctuations]
    pred_words = [token for token in pred_tokens if token not in allowed_punctuations]
    
    # Create position maps for punctuations
    gt_punct_map = {}  # word_index -> [punctuations after this word]
    pred_punct_map = {}  # word_index -> [punctuations after this word]
    
    # Build ground truth punctuation map
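    # word_idx starts at -1, so punctuation appearing before the first word is
    # keyed under -1; the comparison loop below only visits indices >= 0, so such
    # leading punctuation is counted in gt_punc_counts but never compared.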
    word_idx = -1
    for i, token in enumerate(gt_tokens):
        if token not in allowed_punctuations:
            word_idx += 1
        else:
            if word_idx not in gt_punct_map:
                gt_punct_map[word_idx] = []
            gt_punct_map[word_idx].append(token)
    
    # Build predicted punctuation map
    word_idx = -1
    for i, token in enumerate(pred_tokens):
        if token not in allowed_punctuations:
            word_idx += 1
        else:
            if word_idx not in pred_punct_map:
                pred_punct_map[word_idx] = []
            pred_punct_map[word_idx].append(token)
    
    # Compare words and punctuations
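    # Alignment is purely positional: the i-th ground-truth word is compared with
    # the i-th predicted word, so a single inserted or dropped word shifts every
    # comparison after it.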
    max_words = max(len(gt_words), len(pred_words))
    
    for i in range(max_words):
        # Add word
        if i < len(gt_words) and i < len(pred_words):
            if gt_words[i] == pred_words[i]:
                comparison_result.append((gt_words[i], "correct", "black"))
            else:
                comparison_result.append((f"{gt_words[i]}→{pred_words[i]}", "word_diff", "orange"))
        elif i < len(gt_words):
            comparison_result.append((f"{gt_words[i]}→''", "missing_word", "red"))
        elif i < len(pred_words):
            comparison_result.append((f"''→{pred_words[i]}", "extra_word", "red"))
        
        # Compare punctuations after this word
        gt_puncs = gt_punct_map.get(i, [])
        pred_puncs = pred_punct_map.get(i, [])
        
        # Handle punctuation comparison
        max_puncs = max(len(gt_puncs), len(pred_puncs))
        
        for j in range(max_puncs):
            if j < len(gt_puncs) and j < len(pred_puncs):
                gt_punc = gt_puncs[j]
                pred_punc = pred_puncs[j]
                
                if gt_punc == pred_punc:
                    punc_name = punc_dict[gt_punc]
                    correct_puncs[punc_name] = correct_puncs.get(punc_name, 0) + 1
                    comparison_result.append((gt_punc, "correct", "green"))
                else:
                    # Wrong punctuation
                    punc_name = punc_dict[gt_punc]
                    wrong_puncs[punc_name] = wrong_puncs.get(punc_name, 0) + 1
                    comparison_result.append((f"{gt_punc}→{pred_punc}", "wrong_punct", "red"))
            
            elif j < len(gt_puncs):
                # Missing punctuation
                gt_punc = gt_puncs[j]
                punc_name = punc_dict[gt_punc]
                wrong_puncs[punc_name] = wrong_puncs.get(punc_name, 0) + 1
                comparison_result.append((f"{gt_punc}→''", "missing_punct", "red"))
            
            elif j < len(pred_puncs):
                # Extra punctuation (not counted in wrong_puncs since it's not in GT)
                pred_punc = pred_puncs[j]
                comparison_result.append((f"''→{pred_punc}", "extra_punct", "red"))
    
    return comparison_result, correct_puncs, wrong_puncs, gt_punc_counts

def create_evaluation_table(correct_puncs, wrong_puncs, gt_punc_counts):
    """Create evaluation table"""
    table_data = []
    
    for punc_name in gt_punc_counts.keys():
        correct_count = correct_puncs.get(punc_name, 0)
        wrong_count = wrong_puncs.get(punc_name, 0)
        total_count = gt_punc_counts[punc_name]
        
        table_data.append([
            punc_name,
            correct_count,
            wrong_count,
            total_count
        ])
    
    df = pd.DataFrame(table_data, columns=[
        "Punctuation Name", 
        "Correctly Classified", 
        "Wrongly Classified", 
        "Count in Ground Truth"
    ])
    
    return df

def format_comparison_html(comparison_result):
    """Format comparison result as HTML with improved display"""
    html = "<div style='font-family: monospace; font-size: 16px; line-height: 1.8; padding: 20px; border: 1px solid #ddd; border-radius: 5px;'>"
    
    for token, status, color in comparison_result:
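        # compare_texts encodes mismatches as "gt→pred", missing items as "x→''"
        # and spurious items as "''→x"; choose the styling for each case below.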
        if status == "correct" and color == "green":
            # Correct punctuation
            html += f"<span style='background-color: #d4edda; color: #155724; padding: 2px 4px; margin: 1px; border-radius: 3px; font-weight: bold;'>{token}</span>"
        elif color == "red":
            # Incorrect, missing, or extra punctuation/word
            if "→''" in token:
                # Missing punctuation or word
                missing_item = token.split("→")[0]
                html += f"<span style='background-color: #f8d7da; color: #721c24; padding: 2px 4px; margin: 1px; border-radius: 3px; font-weight: bold;'>{missing_item}→∅</span>"
            elif "''→" in token:
                # Extra punctuation or word
                extra_item = token.split("→")[1]
                html += f"<span style='background-color: #f8d7da; color: #721c24; padding: 2px 4px; margin: 1px; border-radius: 3px; font-weight: bold;'>∅→{extra_item}</span>"
            else:
                # Wrong punctuation/word
                html += f"<span style='background-color: #f8d7da; color: #721c24; padding: 2px 4px; margin: 1px; border-radius: 3px; font-weight: bold;'>{token}</span>"
        elif color == "orange":
            # Word difference
            html += f"<span style='background-color: #fff3cd; color: #856404; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token}</span>"
        else:
            # Correct word
            html += f"<span style='padding: 2px 4px; margin: 1px;'>{token}</span>"
        
        # Add space after each token
        html += " "
    
    html += "</div>"
    
    # Add legend
    html += """
    <div style='margin-top: 15px; padding: 10px; background-color: #f8f9fa; border-radius: 5px; font-size: 14px;'>
        <strong>Legend:</strong><br>
        <span style='background-color: #d4edda; color: #155724; padding: 1px 3px; border-radius: 2px; margin: 2px;'>✓</span> Correct punctuation &nbsp;
        <span style='background-color: #f8d7da; color: #721c24; padding: 1px 3px; border-radius: 2px; margin: 2px;'>✗</span> Wrong/Missing/Extra punctuation &nbsp;
        <span style='background-color: #fff3cd; color: #856404; padding: 1px 3px; border-radius: 2px; margin: 2px;'>~</span> Word difference &nbsp;
        <span style='padding: 1px 3px; margin: 2px;'>◦</span> Correct word<br>
        <strong>∅</strong> = Empty/Missing
    </div>
    """
    
    return html

def process_punctuation_restoration(input_text, ground_truth=""):
    """Main processing function"""
    if not input_text.strip():
        return "Please enter input text", "", "", None, ""
    
    # Call the punctuation restoration API; on failure, surface an error marker
    # and a sentinel time (the dummy fallback is kept below, commented out, for demos)
    try:
        predicted_text, api_time = restore_punctuation(input_text)
        if "Error" in str(predicted_text):
            # predicted_text, api_time = dummy_restore_punctuation(input_text)
            predicted_text, api_time = f"Error : {input_text}", 999999
    except Exception:
        # predicted_text, api_time = dummy_restore_punctuation(input_text)
        predicted_text, api_time = f"Error : {input_text}", 999999
    
    time_info = f"API call completed in {api_time:.3f} seconds"
    
    predicted_text = predicted_text[0] if isinstance(predicted_text, list) else predicted_text
    
    print(f"input_text: {input_text}", flush=True)
    print(f"predicted_text: {predicted_text}", flush=True)
    if not ground_truth.strip():
        return predicted_text, "", time_info, None, ""
    
    # Normalize ground truth
    ground_truth_normalized = clean_and_normalize_text(ground_truth)
    
    # Compare texts
    comparison_result, correct_puncs, wrong_puncs, gt_punc_counts = compare_texts(
        ground_truth_normalized, predicted_text
    )
    
    # Create comparison HTML
    comparison_html = format_comparison_html(comparison_result)
    
    # Create evaluation table
    eval_table = create_evaluation_table(correct_puncs, wrong_puncs, gt_punc_counts)
    
    return predicted_text, comparison_html, time_info, eval_table, f"Normalized Ground Truth: {ground_truth_normalized}"

# Create Gradio interface
def create_interface():
    with gr.Blocks(title="Punctuation Restoration Evaluator", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🔤 Punctuation Restoration Evaluator")
        gr.Markdown("Enter text to restore punctuation. Optionally provide ground truth for evaluation.")
        
        with gr.Row():
            with gr.Column(scale=1):
                input_text = gr.Textbox(
                    label="Input Text (without punctuation)",
                    placeholder="পুরুষের সংখ্যা মোট জনসংখ্যার ৫২ এবং নারীর সংখ্যা ৪৮ শহরের সাক্ষরতার হার কত",
                    lines=4
                )
                
                ground_truth = gr.Textbox(
                    label="Ground Truth (optional)",
                    placeholder="পুরুষের সংখ্যা মোট জনসংখ্যার ৫২, এবং নারীর সংখ্যা ৪৮। শহরের সাক্ষরতার হার কত?",
                    lines=4
                )
                
                submit_btn = gr.Button("🚀 Restore Punctuation", variant="primary")
        
            # Output panel, shown beside the input column inside the same Row
            with gr.Column(scale=2):
                api_time = gr.Textbox(label="⏱️ API Response Time", interactive=False)

                predicted_output = gr.Textbox(
                    label="📝 Predicted Output",
                    lines=3,
                    interactive=False
                )

                normalized_gt = gr.Textbox(
                    label="📋 Normalized Ground Truth",
                    lines=2,
                    interactive=False
                )

                comparison_output = gr.HTML(
                    label="🔍 Token-wise Comparison",
                    value="<p>Comparison will appear here after processing with ground truth.</p>"
                )

                evaluation_table = gr.Dataframe(
                    label="📊 Punctuation Evaluation Metrics",
                    headers=["Punctuation Name", "Correctly Classified", "Wrongly Classified", "Count in Ground Truth"],
                    interactive=False
                )
        
        # Legend
        gr.Markdown("""
        ### 🎨 Color Legend:
        - 🟢 **Green**: Correctly predicted punctuation
        - 🔴 **Red**: Incorrectly predicted, missing, or extra punctuation/word
        - 🟡 **Orange**: Word-level differences
        - ⚫ **Black**: Correct words/tokens
        - **∅**: Empty/Missing (instead of showing word→word or punct→word)
        """)
        
        submit_btn.click(
            fn=process_punctuation_restoration,
            inputs=[input_text, ground_truth],
            outputs=[predicted_output, comparison_output, api_time, evaluation_table, normalized_gt]
        )
        
        # Example section
        gr.Markdown("### 📚 Example")
        gr.Examples(
            examples=[
                [
                    "পুরুষের সংখ্যা মোট জনসংখ্যার ৫২ এবং নারীর সংখ্যা ৪৮ শহরের সাক্ষরতার হার কত",
                    "পুরুষের সংখ্যা মোট জনসংখ্যার ৫২, এবং নারীর সংখ্যা ৪৮। শহরের সাক্ষরতার হার কত?"
                ],
                [
                    "ক্রিকেট বিশ্বের কাছে নিজের আগামীবার তা ভালোভাবেই পৌঁছে দিলেন পাকিস্তানের পেসার আমের জামান",
                    ""
                ]
            ],
            inputs=[input_text, ground_truth]
        )
    
    return app

if __name__ == "__main__":
    app = create_interface() 
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True
    )