Twin commited on
Commit
9ef957a
Β·
1 Parent(s): fa06dc7

Add PDF drag-and-drop functionality with Mistral OCR integration

Browse files
Files changed (5) hide show
  1. app.py +124 -6
  2. inference/ocr_service.py +126 -0
  3. static/index.html +320 -36
  4. test_ocr.py +133 -0
  5. uv.lock +0 -0
app.py CHANGED
@@ -11,7 +11,7 @@ import logging
11
  from contextlib import asynccontextmanager
12
  from typing import Dict, Any, List
13
 
14
- from fastapi import FastAPI, HTTPException, Request
15
  from fastapi.staticfiles import StaticFiles
16
  from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
17
  from pydantic import BaseModel, Field
@@ -19,6 +19,7 @@ from pydantic import BaseModel, Field
19
  # Import our inference services
20
  from inference.mistral_prompting import create_mistral_service, MistralPromptingService
21
  from inference.bert_classif import create_bert_service, BERTInferenceService
 
22
 
23
  # Setup logging
24
  logging.basicConfig(
@@ -31,6 +32,7 @@ logger = logging.getLogger(__name__)
31
  mistral_base_service: MistralPromptingService = None
32
  mistral_finetuned_service: MistralPromptingService = None
33
  bert_service: BERTInferenceService = None
 
34
 
35
  # Model configurations
36
  MODELS = {
@@ -44,7 +46,7 @@ BERT_MODEL_PATH = "SoelMgd/bert-pii-detection"
44
  @asynccontextmanager
45
  async def lifespan(app: FastAPI):
46
  """Manage application lifespan - startup and shutdown."""
47
- global mistral_base_service, mistral_finetuned_service, bert_service
48
 
49
  # Startup
50
  logger.info("πŸš€ Starting PII Masking Demo application...")
@@ -73,6 +75,16 @@ async def lifespan(app: FastAPI):
73
  except Exception as e:
74
  logger.error(f"Failed to initialize BERT service: {e}")
75
  # Don't raise exception - let app start but handle gracefully in endpoints
 
 
 
 
 
 
 
 
 
 
76
 
77
  yield
78
 
@@ -271,10 +283,109 @@ async def predict(request: PredictionRequest):
271
  detail=f"Prediction failed: {str(e)}"
272
  )
273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  @app.get("/health", response_model=HealthResponse)
275
  async def health_check():
276
  """Health check endpoint."""
277
- global mistral_base_service, mistral_finetuned_service, bert_service
278
 
279
  services_status = {
280
  "mistral_base": {
@@ -294,6 +405,12 @@ async def health_check():
294
  "initialized": bert_service.is_initialized if bert_service else False,
295
  "model": BERT_MODEL_PATH,
296
  "info": bert_service.get_service_info() if bert_service else None
 
 
 
 
 
 
297
  }
298
  }
299
 
@@ -301,12 +418,13 @@ async def health_check():
301
  base_healthy = mistral_base_service and mistral_base_service.is_initialized
302
  finetuned_healthy = mistral_finetuned_service and mistral_finetuned_service.is_initialized
303
  bert_healthy = bert_service and bert_service.is_initialized
 
304
 
305
- healthy_services = sum([base_healthy, finetuned_healthy, bert_healthy])
306
 
307
- if healthy_services == 3:
308
  overall_status = "healthy"
309
- elif healthy_services >= 1:
310
  overall_status = "partial"
311
  else:
312
  overall_status = "degraded"
 
11
  from contextlib import asynccontextmanager
12
  from typing import Dict, Any, List
13
 
14
+ from fastapi import FastAPI, HTTPException, Request, File, UploadFile
15
  from fastapi.staticfiles import StaticFiles
16
  from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
17
  from pydantic import BaseModel, Field
 
19
  # Import our inference services
20
  from inference.mistral_prompting import create_mistral_service, MistralPromptingService
21
  from inference.bert_classif import create_bert_service, BERTInferenceService
22
+ from inference.ocr_service import create_ocr_service, OCRService
23
 
24
  # Setup logging
25
  logging.basicConfig(
 
32
  mistral_base_service: MistralPromptingService = None
33
  mistral_finetuned_service: MistralPromptingService = None
34
  bert_service: BERTInferenceService = None
35
+ ocr_service: OCRService = None
36
 
37
  # Model configurations
38
  MODELS = {
 
46
  @asynccontextmanager
47
  async def lifespan(app: FastAPI):
48
  """Manage application lifespan - startup and shutdown."""
49
+ global mistral_base_service, mistral_finetuned_service, bert_service, ocr_service
50
 
51
  # Startup
52
  logger.info("πŸš€ Starting PII Masking Demo application...")
 
75
  except Exception as e:
76
  logger.error(f"Failed to initialize BERT service: {e}")
77
  # Don't raise exception - let app start but handle gracefully in endpoints
78
+
79
+ try:
80
+ # Initialize OCR service
81
+ logger.info("Initializing OCR service...")
82
+ ocr_service = await create_ocr_service()
83
+ logger.info("βœ… OCR service initialized successfully")
84
+
85
+ except Exception as e:
86
+ logger.error(f"Failed to initialize OCR service: {e}")
87
+ # Don't raise exception - let app start but handle gracefully in endpoints
88
 
89
  yield
90
 
 
283
  detail=f"Prediction failed: {str(e)}"
284
  )
285
 
286
+ @app.post("/predict-pdf", response_model=PredictionResponse)
287
+ async def predict_pdf(
288
+ file: UploadFile = File(...),
289
+ method: str = "mistral",
290
+ model: str = "base",
291
+ pii_entities: str = "[]"
292
+ ):
293
+ """
294
+ Extract text from PDF using OCR, then predict PII entities and return masked text.
295
+
296
+ Supports the same methods as /predict: Mistral (base/fine-tuned) and BERT.
297
+ """
298
+ # Validate file type
299
+ if not file.filename.lower().endswith('.pdf'):
300
+ raise HTTPException(
301
+ status_code=400,
302
+ detail="Only PDF files are supported"
303
+ )
304
+
305
+ # Check OCR service availability
306
+ if ocr_service is None:
307
+ raise HTTPException(
308
+ status_code=503,
309
+ detail="OCR service not available. Please check API key configuration."
310
+ )
311
+
312
+ # Validate method
313
+ if method not in ["mistral", "bert"]:
314
+ raise HTTPException(
315
+ status_code=400,
316
+ detail=f"Method '{method}' not supported. Use 'mistral' or 'bert'."
317
+ )
318
+
319
+ try:
320
+ # Parse PII entities list
321
+ import json
322
+ pii_entities_list = json.loads(pii_entities) if pii_entities else []
323
+
324
+ start_time = time.time()
325
+
326
+ # Read PDF content
327
+ pdf_content = await file.read()
328
+ logger.info(f"πŸ“„ Received PDF file: {file.filename} ({len(pdf_content)} bytes)")
329
+
330
+ # Extract text using OCR
331
+ logger.info("πŸ” Extracting text from PDF using Mistral OCR...")
332
+ extracted_text = await ocr_service.extract_text_from_pdf(pdf_content)
333
+
334
+ if not extracted_text or len(extracted_text.strip()) < 10:
335
+ raise HTTPException(
336
+ status_code=400,
337
+ detail="Could not extract sufficient text from PDF"
338
+ )
339
+
340
+ logger.info(f"πŸ“ Extracted {len(extracted_text)} characters from PDF")
341
+
342
+ # Now process the extracted text with the selected method
343
+ if method == "mistral":
344
+ # Get the appropriate Mistral service
345
+ service = get_mistral_service(model)
346
+ prediction = await service.predict(extracted_text, pii_entities_list)
347
+ method_used = f"{method}-{model}"
348
+
349
+ elif method == "bert":
350
+ # Check BERT service availability
351
+ if bert_service is None:
352
+ raise HTTPException(
353
+ status_code=503,
354
+ detail="BERT service not available. Please check model configuration."
355
+ )
356
+
357
+ prediction = await bert_service.predict(extracted_text, pii_entities_list)
358
+ method_used = "bert"
359
+
360
+ processing_time = time.time() - start_time
361
+
362
+ # Count total entities
363
+ num_entities = sum(len(entities) for entities in prediction.entities.values())
364
+
365
+ logger.info(f"βœ… PDF processing completed in {processing_time:.3f}s - found {num_entities} entities")
366
+
367
+ return PredictionResponse(
368
+ masked_text=prediction.masked_text,
369
+ entities=prediction.entities,
370
+ processing_time=processing_time,
371
+ method_used=f"pdf-{method_used}",
372
+ num_entities=num_entities,
373
+ selected_entities=pii_entities_list
374
+ )
375
+
376
+ except HTTPException:
377
+ raise
378
+ except Exception as e:
379
+ logger.error(f"❌ PDF processing failed: {e}")
380
+ raise HTTPException(
381
+ status_code=500,
382
+ detail=f"PDF processing failed: {str(e)}"
383
+ )
384
+
385
  @app.get("/health", response_model=HealthResponse)
386
  async def health_check():
387
  """Health check endpoint."""
388
+ global mistral_base_service, mistral_finetuned_service, bert_service, ocr_service
389
 
390
  services_status = {
391
  "mistral_base": {
 
405
  "initialized": bert_service.is_initialized if bert_service else False,
406
  "model": BERT_MODEL_PATH,
407
  "info": bert_service.get_service_info() if bert_service else None
408
+ },
409
+ "ocr": {
410
+ "available": ocr_service is not None,
411
+ "initialized": ocr_service.is_initialized if ocr_service else False,
412
+ "model": "mistral-ocr-latest",
413
+ "info": ocr_service.get_service_info() if ocr_service else None
414
  }
415
  }
416
 
 
418
  base_healthy = mistral_base_service and mistral_base_service.is_initialized
419
  finetuned_healthy = mistral_finetuned_service and mistral_finetuned_service.is_initialized
420
  bert_healthy = bert_service and bert_service.is_initialized
421
+ ocr_healthy = ocr_service and ocr_service.is_initialized
422
 
423
+ healthy_services = sum([base_healthy, finetuned_healthy, bert_healthy, ocr_healthy])
424
 
425
+ if healthy_services == 4:
426
  overall_status = "healthy"
427
+ elif healthy_services >= 2:
428
  overall_status = "partial"
429
  else:
430
  overall_status = "degraded"
inference/ocr_service.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ OCR Service for PDF processing using Mistral OCR API.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import logging
9
+ import base64
10
+ from pathlib import Path
11
+ from typing import Optional, Dict, Any
12
+ from dotenv import load_dotenv
13
+
14
+ # Load environment variables
15
+ load_dotenv()
16
+
17
+ from mistralai import Mistral
18
+
19
+ # Setup logging
20
+ logger = logging.getLogger(__name__)
21
+
22
+ class OCRService:
23
+ """
24
+ OCR service for extracting text from PDF documents using Mistral OCR API.
25
+ """
26
+
27
+ def __init__(self, api_key: Optional[str] = None):
28
+ """
29
+ Initialize the OCR service.
30
+
31
+ Args:
32
+ api_key: Mistral API key (if None, will read from environment)
33
+ """
34
+ self.api_key = api_key or os.environ.get("MISTRAL_API_KEY")
35
+ if not self.api_key:
36
+ raise ValueError("MISTRAL_API_KEY not found in environment variables")
37
+
38
+ self.client = Mistral(api_key=self.api_key)
39
+ self.is_initialized = True
40
+
41
+ logger.info("πŸ”§ OCR service initialized with Mistral API")
42
+
43
+ async def extract_text_from_pdf(self, pdf_content: bytes) -> str:
44
+ """
45
+ Extract text from PDF content using Mistral OCR.
46
+
47
+ Args:
48
+ pdf_content: Raw PDF file content as bytes
49
+
50
+ Returns:
51
+ Extracted text content
52
+ """
53
+ try:
54
+ # Encode PDF content to base64
55
+ base64_pdf = base64.b64encode(pdf_content).decode('utf-8')
56
+
57
+ logger.info(f"πŸ“„ Processing PDF ({len(pdf_content)} bytes) with Mistral OCR...")
58
+
59
+ # Process the PDF with OCR
60
+ ocr_response = self.client.ocr.process(
61
+ model="mistral-ocr-latest",
62
+ document={
63
+ "type": "document_url",
64
+ "document_url": f"data:application/pdf;base64,{base64_pdf}"
65
+ },
66
+ include_image_base64=False # Don't include images to save bandwidth
67
+ )
68
+
69
+ logger.info("βœ… OCR processing completed")
70
+
71
+ # Extract text from all pages
72
+ extracted_text = ""
73
+
74
+ if hasattr(ocr_response, 'pages') and ocr_response.pages:
75
+ logger.info(f"πŸ“„ Found {len(ocr_response.pages)} pages")
76
+
77
+ for i, page in enumerate(ocr_response.pages):
78
+ if hasattr(page, 'markdown') and page.markdown:
79
+ page_text = page.markdown
80
+ extracted_text += page_text + "\n\n"
81
+ logger.debug(f"πŸ“ Page {i+1}: {len(page_text)} characters")
82
+
83
+ logger.info(f"πŸ“„ Total extracted text: {len(extracted_text)} characters")
84
+
85
+ if not extracted_text.strip():
86
+ logger.warning("⚠️ No text extracted from PDF")
87
+ return "No text could be extracted from this PDF."
88
+
89
+ return extracted_text.strip()
90
+
91
+ else:
92
+ logger.warning("⚠️ No pages found in OCR response")
93
+ return "No text could be extracted from this PDF."
94
+
95
+ except Exception as e:
96
+ logger.error(f"❌ OCR processing failed: {e}")
97
+ raise RuntimeError(f"Failed to extract text from PDF: {str(e)}")
98
+
99
+ def get_service_info(self) -> Dict[str, Any]:
100
+ """Get service information for monitoring."""
101
+ return {
102
+ "service_name": "OCRService",
103
+ "is_initialized": self.is_initialized,
104
+ "api_provider": "Mistral",
105
+ "model": "mistral-ocr-latest",
106
+ "description": "PDF text extraction using Mistral OCR API"
107
+ }
108
+
109
+ # Factory function for easy initialization
110
+ async def create_ocr_service(api_key: Optional[str] = None) -> OCRService:
111
+ """
112
+ Factory function to create and initialize OCR service.
113
+
114
+ Args:
115
+ api_key: Mistral API key (if None, will read from environment)
116
+
117
+ Returns:
118
+ Initialized OCRService
119
+ """
120
+ try:
121
+ service = OCRService(api_key)
122
+ logger.info("βœ… OCR service created successfully")
123
+ return service
124
+ except Exception as e:
125
+ logger.error(f"❌ Failed to create OCR service: {e}")
126
+ raise
static/index.html CHANGED
@@ -298,6 +298,131 @@
298
  background: #4338ca;
299
  }
300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  .error {
302
  background: #fef2f2;
303
  color: #dc2626;
@@ -331,16 +456,43 @@
331
 
332
  <div class="content">
333
  <form id="piiForm">
334
- <div class="form-group">
335
- <label for="inputText">Enter your text:</label>
336
- <textarea
337
- id="inputText"
338
- class="input-textarea"
339
- placeholder="Enter text containing PII information...
 
 
 
 
 
 
 
 
 
 
340
 
341
  Example: Hi, my name is John Smith and my email is [email protected]. Call me at 555-1234."
342
- required
343
- ></textarea>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  </div>
345
 
346
  <div class="form-group">
@@ -349,14 +501,14 @@ Example: Hi, my name is John Smith and my email is [email protected]. Call
349
  <div class="method-option">
350
  <input type="radio" id="mistral" name="method" value="mistral" class="method-radio" checked onchange="toggleMistralModelSelection()">
351
  <label for="mistral" class="method-label">
352
- <div class="method-title">🧠 Mistral AI</div>
353
  <div class="method-desc">High accuracy via API</div>
354
  </label>
355
  </div>
356
  <div class="method-option">
357
  <input type="radio" id="bert" name="method" value="bert" class="method-radio" onchange="toggleMistralModelSelection()">
358
  <label for="bert" class="method-label">
359
- <div class="method-title">πŸ€– BERT</div>
360
  <div class="method-desc">Fast local processing</div>
361
  </label>
362
  </div>
@@ -369,14 +521,14 @@ Example: Hi, my name is John Smith and my email is [email protected]. Call
369
  <div class="method-option">
370
  <input type="radio" id="base" name="model" value="base" class="method-radio" checked>
371
  <label for="base" class="method-label">
372
- <div class="method-title">🎯 Base Model</div>
373
  <div class="method-desc">mistral-large-latest with detailed prompting</div>
374
  </label>
375
  </div>
376
  <div class="method-option">
377
  <input type="radio" id="finetuned" name="model" value="finetuned" class="method-radio">
378
  <label for="finetuned" class="method-label">
379
- <div class="method-title">🎯 Fine-tuned Model</div>
380
  <div class="method-desc">Specialized PII detection model</div>
381
  </label>
382
  </div>
@@ -384,7 +536,7 @@ Example: Hi, my name is John Smith and my email is [email protected]. Call
384
  </div>
385
 
386
  <div class="form-group pii-selection">
387
- <label>🎯 Select PII entities to mask:</label>
388
  <div class="pii-controls">
389
  <button type="button" class="pii-btn primary" onclick="selectAllPII()">Select All</button>
390
  <button type="button" class="pii-btn" onclick="selectNonePII()">Select None</button>
@@ -396,7 +548,7 @@ Example: Hi, my name is John Smith and my email is [email protected]. Call
396
  </div>
397
 
398
  <button type="submit" class="process-btn" id="processBtn">
399
- πŸš€ Process Text
400
  </button>
401
  </form>
402
 
@@ -407,12 +559,12 @@ Example: Hi, my name is John Smith and my email is [email protected]. Call
407
 
408
  <div class="result" id="result">
409
  <div class="result-section">
410
- <div class="result-title">🎭 Masked Text</div>
411
  <div class="result-content" id="maskedText"></div>
412
  </div>
413
 
414
  <div class="result-section">
415
- <div class="result-title">πŸ“‹ Detected Entities</div>
416
  <div class="result-content" id="entities"></div>
417
  </div>
418
 
@@ -462,9 +614,14 @@ Example: Hi, my name is John Smith and my email is [email protected]. Call
462
  'ADDRESS', 'CITY', 'STATE', 'ZIPCODE', 'DOB', 'AGE', 'IP'
463
  ];
464
 
 
 
 
 
465
  // Initialize PII selection on page load
466
  document.addEventListener('DOMContentLoaded', function() {
467
  initializePIISelection();
 
468
  });
469
 
470
  function initializePIISelection() {
@@ -519,19 +676,114 @@ Example: Hi, my name is John Smith and my email is [email protected]. Call
519
  return Array.from(checkboxes).map(cb => cb.value);
520
  }
521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
  document.getElementById('piiForm').addEventListener('submit', async function(e) {
523
  e.preventDefault();
524
 
525
- const text = document.getElementById('inputText').value.trim();
526
  const method = document.querySelector('input[name="method"]:checked').value;
527
  const model = document.querySelector('input[name="model"]:checked').value;
528
  const selectedPIIEntities = getSelectedPIIEntities();
529
 
530
- if (!text) {
531
- showError('Please enter some text to analyze.');
532
- return;
533
- }
534
-
535
  if (selectedPIIEntities.length === 0) {
536
  showError('Please select at least one PII entity to mask.');
537
  return;
@@ -543,18 +795,50 @@ Example: Hi, my name is John Smith and my email is [email protected]. Call
543
  hideResult();
544
 
545
  try {
546
- const response = await fetch('/predict', {
547
- method: 'POST',
548
- headers: {
549
- 'Content-Type': 'application/json',
550
- },
551
- body: JSON.stringify({
552
- text: text,
553
- method: method,
554
- model: model,
555
- pii_entities: selectedPIIEntities
556
- })
557
- });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
 
559
  const result = await response.json();
560
 
@@ -578,11 +862,11 @@ Example: Hi, my name is John Smith and my email is [email protected]. Call
578
  if (loading) {
579
  loadingEl.style.display = 'block';
580
  processBtn.disabled = true;
581
- processBtn.textContent = '⏳ Processing...';
582
  } else {
583
  loadingEl.style.display = 'none';
584
  processBtn.disabled = false;
585
- processBtn.textContent = 'πŸš€ Process Text';
586
  }
587
  }
588
 
 
298
  background: #4338ca;
299
  }
300
 
301
+ .input-method-selection {
302
+ margin-bottom: 25px;
303
+ }
304
+
305
+ .input-tabs {
306
+ display: flex;
307
+ border-bottom: 2px solid #e5e7eb;
308
+ margin-bottom: 20px;
309
+ }
310
+
311
+ .input-tab {
312
+ padding: 12px 24px;
313
+ background: none;
314
+ border: none;
315
+ cursor: pointer;
316
+ font-size: 16px;
317
+ font-weight: 500;
318
+ color: #6b7280;
319
+ border-bottom: 2px solid transparent;
320
+ transition: all 0.3s ease;
321
+ }
322
+
323
+ .input-tab.active {
324
+ color: #4f46e5;
325
+ border-bottom-color: #4f46e5;
326
+ }
327
+
328
+ .input-tab:hover {
329
+ color: #4f46e5;
330
+ }
331
+
332
+ .input-content {
333
+ display: none;
334
+ }
335
+
336
+ .input-content.active {
337
+ display: block;
338
+ }
339
+
340
+ .dropzone {
341
+ border: 2px dashed #d1d5db;
342
+ border-radius: 12px;
343
+ padding: 40px 20px;
344
+ text-align: center;
345
+ background: #f9fafb;
346
+ transition: all 0.3s ease;
347
+ cursor: pointer;
348
+ }
349
+
350
+ .dropzone.dragover {
351
+ border-color: #4f46e5;
352
+ background: #f0f9ff;
353
+ }
354
+
355
+ .dropzone-icon {
356
+ font-size: 48px;
357
+ color: #9ca3af;
358
+ margin-bottom: 16px;
359
+ }
360
+
361
+ .dropzone.dragover .dropzone-icon {
362
+ color: #4f46e5;
363
+ }
364
+
365
+ .dropzone-text {
366
+ color: #374151;
367
+ font-size: 16px;
368
+ margin-bottom: 8px;
369
+ }
370
+
371
+ .dropzone-subtext {
372
+ color: #6b7280;
373
+ font-size: 14px;
374
+ }
375
+
376
+ .file-info {
377
+ display: none;
378
+ background: #f0f9ff;
379
+ border: 1px solid #bfdbfe;
380
+ border-radius: 8px;
381
+ padding: 12px 16px;
382
+ margin-top: 12px;
383
+ }
384
+
385
+ .file-info.show {
386
+ display: flex;
387
+ align-items: center;
388
+ gap: 12px;
389
+ }
390
+
391
+ .file-icon {
392
+ color: #3b82f6;
393
+ font-size: 20px;
394
+ }
395
+
396
+ .file-details {
397
+ flex: 1;
398
+ }
399
+
400
+ .file-name {
401
+ font-weight: 500;
402
+ color: #1f2937;
403
+ }
404
+
405
+ .file-size {
406
+ font-size: 14px;
407
+ color: #6b7280;
408
+ }
409
+
410
+ .file-remove {
411
+ background: none;
412
+ border: none;
413
+ color: #6b7280;
414
+ cursor: pointer;
415
+ font-size: 18px;
416
+ padding: 4px;
417
+ border-radius: 4px;
418
+ transition: all 0.2s ease;
419
+ }
420
+
421
+ .file-remove:hover {
422
+ background: #fee2e2;
423
+ color: #dc2626;
424
+ }
425
+
426
  .error {
427
  background: #fef2f2;
428
  color: #dc2626;
 
456
 
457
  <div class="content">
458
  <form id="piiForm">
459
+ <div class="form-group input-method-selection">
460
+ <label>Choose input method:</label>
461
+ <div class="input-tabs">
462
+ <button type="button" class="input-tab active" onclick="switchInputMethod('text')">
463
+ πŸ“ Text Input
464
+ </button>
465
+ <button type="button" class="input-tab" onclick="switchInputMethod('pdf')">
466
+ πŸ“„ PDF Upload
467
+ </button>
468
+ </div>
469
+
470
+ <div id="textInput" class="input-content active">
471
+ <textarea
472
+ id="inputText"
473
+ class="input-textarea"
474
+ placeholder="Enter text containing PII information...
475
 
476
  Example: Hi, my name is John Smith and my email is [email protected]. Call me at 555-1234."
477
+ ></textarea>
478
+ </div>
479
+
480
+ <div id="pdfInput" class="input-content">
481
+ <div class="dropzone" id="dropzone" onclick="document.getElementById('fileInput').click()">
482
+ <div class="dropzone-icon">πŸ“„</div>
483
+ <div class="dropzone-text">Drop your PDF here or click to browse</div>
484
+ <div class="dropzone-subtext">Supports PDF files up to 10MB</div>
485
+ </div>
486
+ <input type="file" id="fileInput" accept=".pdf" style="display: none;">
487
+ <div id="fileInfo" class="file-info">
488
+ <div class="file-icon">πŸ“„</div>
489
+ <div class="file-details">
490
+ <div class="file-name" id="fileName"></div>
491
+ <div class="file-size" id="fileSize"></div>
492
+ </div>
493
+ <button type="button" class="file-remove" onclick="removeFile()">βœ•</button>
494
+ </div>
495
+ </div>
496
  </div>
497
 
498
  <div class="form-group">
 
501
  <div class="method-option">
502
  <input type="radio" id="mistral" name="method" value="mistral" class="method-radio" checked onchange="toggleMistralModelSelection()">
503
  <label for="mistral" class="method-label">
504
+ <div class="method-title">Mistral AI</div>
505
  <div class="method-desc">High accuracy via API</div>
506
  </label>
507
  </div>
508
  <div class="method-option">
509
  <input type="radio" id="bert" name="method" value="bert" class="method-radio" onchange="toggleMistralModelSelection()">
510
  <label for="bert" class="method-label">
511
+ <div class="method-title">BERT</div>
512
  <div class="method-desc">Fast local processing</div>
513
  </label>
514
  </div>
 
521
  <div class="method-option">
522
  <input type="radio" id="base" name="model" value="base" class="method-radio" checked>
523
  <label for="base" class="method-label">
524
+ <div class="method-title">Base Model</div>
525
  <div class="method-desc">mistral-large-latest with detailed prompting</div>
526
  </label>
527
  </div>
528
  <div class="method-option">
529
  <input type="radio" id="finetuned" name="model" value="finetuned" class="method-radio">
530
  <label for="finetuned" class="method-label">
531
+ <div class="method-title">Fine-tuned Model</div>
532
  <div class="method-desc">Specialized PII detection model</div>
533
  </label>
534
  </div>
 
536
  </div>
537
 
538
  <div class="form-group pii-selection">
539
+ <label>Select PII entities to mask:</label>
540
  <div class="pii-controls">
541
  <button type="button" class="pii-btn primary" onclick="selectAllPII()">Select All</button>
542
  <button type="button" class="pii-btn" onclick="selectNonePII()">Select None</button>
 
548
  </div>
549
 
550
  <button type="submit" class="process-btn" id="processBtn">
551
+ Process Text
552
  </button>
553
  </form>
554
 
 
559
 
560
  <div class="result" id="result">
561
  <div class="result-section">
562
+ <div class="result-title">Masked Text</div>
563
  <div class="result-content" id="maskedText"></div>
564
  </div>
565
 
566
  <div class="result-section">
567
+ <div class="result-title">Detected Entities</div>
568
  <div class="result-content" id="entities"></div>
569
  </div>
570
 
 
614
  'ADDRESS', 'CITY', 'STATE', 'ZIPCODE', 'DOB', 'AGE', 'IP'
615
  ];
616
 
617
+ // Global variables
618
+ let selectedFile = null;
619
+ let currentInputMethod = 'text';
620
+
621
  // Initialize PII selection on page load
622
  document.addEventListener('DOMContentLoaded', function() {
623
  initializePIISelection();
624
+ initializeFileUpload();
625
  });
626
 
627
  function initializePIISelection() {
 
676
  return Array.from(checkboxes).map(cb => cb.value);
677
  }
678
 
679
+ function switchInputMethod(method) {
680
+ currentInputMethod = method;
681
+
682
+ // Update tab appearance
683
+ document.querySelectorAll('.input-tab').forEach(tab => tab.classList.remove('active'));
684
+ event.target.classList.add('active');
685
+
686
+ // Update content visibility
687
+ document.querySelectorAll('.input-content').forEach(content => content.classList.remove('active'));
688
+ document.getElementById(method + 'Input').classList.add('active');
689
+ }
690
+
691
+ function initializeFileUpload() {
692
+ const dropzone = document.getElementById('dropzone');
693
+ const fileInput = document.getElementById('fileInput');
694
+
695
+ // Prevent default drag behaviors
696
+ ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => {
697
+ dropzone.addEventListener(eventName, preventDefaults, false);
698
+ document.body.addEventListener(eventName, preventDefaults, false);
699
+ });
700
+
701
+ // Highlight drop area when item is dragged over it
702
+ ['dragenter', 'dragover'].forEach(eventName => {
703
+ dropzone.addEventListener(eventName, highlight, false);
704
+ });
705
+
706
+ ['dragleave', 'drop'].forEach(eventName => {
707
+ dropzone.addEventListener(eventName, unhighlight, false);
708
+ });
709
+
710
+ // Handle dropped files
711
+ dropzone.addEventListener('drop', handleDrop, false);
712
+
713
+ // Handle file input change
714
+ fileInput.addEventListener('change', function(e) {
715
+ handleFiles(e.target.files);
716
+ });
717
+
718
+ function preventDefaults(e) {
719
+ e.preventDefault();
720
+ e.stopPropagation();
721
+ }
722
+
723
+ function highlight(e) {
724
+ dropzone.classList.add('dragover');
725
+ }
726
+
727
+ function unhighlight(e) {
728
+ dropzone.classList.remove('dragover');
729
+ }
730
+
731
+ function handleDrop(e) {
732
+ const dt = e.dataTransfer;
733
+ const files = dt.files;
734
+ handleFiles(files);
735
+ }
736
+
737
+ function handleFiles(files) {
738
+ if (files.length > 0) {
739
+ const file = files[0];
740
+
741
+ // Validate file type
742
+ if (!file.name.toLowerCase().endsWith('.pdf')) {
743
+ showError('Please select a PDF file.');
744
+ return;
745
+ }
746
+
747
+ // Validate file size (10MB limit)
748
+ const maxSize = 10 * 1024 * 1024; // 10MB
749
+ if (file.size > maxSize) {
750
+ showError('File size must be less than 10MB.');
751
+ return;
752
+ }
753
+
754
+ selectedFile = file;
755
+ showFileInfo(file);
756
+ }
757
+ }
758
+ }
759
+
760
+ function showFileInfo(file) {
761
+ document.getElementById('fileName').textContent = file.name;
762
+ document.getElementById('fileSize').textContent = formatFileSize(file.size);
763
+ document.getElementById('fileInfo').classList.add('show');
764
+ }
765
+
766
+ function removeFile() {
767
+ selectedFile = null;
768
+ document.getElementById('fileInfo').classList.remove('show');
769
+ document.getElementById('fileInput').value = '';
770
+ }
771
+
772
+ function formatFileSize(bytes) {
773
+ if (bytes === 0) return '0 Bytes';
774
+ const k = 1024;
775
+ const sizes = ['Bytes', 'KB', 'MB', 'GB'];
776
+ const i = Math.floor(Math.log(bytes) / Math.log(k));
777
+ return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];
778
+ }
779
+
780
  document.getElementById('piiForm').addEventListener('submit', async function(e) {
781
  e.preventDefault();
782
 
 
783
  const method = document.querySelector('input[name="method"]:checked').value;
784
  const model = document.querySelector('input[name="model"]:checked').value;
785
  const selectedPIIEntities = getSelectedPIIEntities();
786
 
 
 
 
 
 
787
  if (selectedPIIEntities.length === 0) {
788
  showError('Please select at least one PII entity to mask.');
789
  return;
 
795
  hideResult();
796
 
797
  try {
798
+ let response;
799
+
800
+ if (currentInputMethod === 'text') {
801
+ // Text input processing
802
+ const text = document.getElementById('inputText').value.trim();
803
+
804
+ if (!text) {
805
+ showError('Please enter some text to analyze.');
806
+ setLoading(false);
807
+ return;
808
+ }
809
+
810
+ response = await fetch('/predict', {
811
+ method: 'POST',
812
+ headers: {
813
+ 'Content-Type': 'application/json',
814
+ },
815
+ body: JSON.stringify({
816
+ text: text,
817
+ method: method,
818
+ model: model,
819
+ pii_entities: selectedPIIEntities
820
+ })
821
+ });
822
+
823
+ } else if (currentInputMethod === 'pdf') {
824
+ // PDF processing
825
+ if (!selectedFile) {
826
+ showError('Please select a PDF file to analyze.');
827
+ setLoading(false);
828
+ return;
829
+ }
830
+
831
+ const formData = new FormData();
832
+ formData.append('file', selectedFile);
833
+ formData.append('method', method);
834
+ formData.append('model', model);
835
+ formData.append('pii_entities', JSON.stringify(selectedPIIEntities));
836
+
837
+ response = await fetch('/predict-pdf', {
838
+ method: 'POST',
839
+ body: formData
840
+ });
841
+ }
842
 
843
  const result = await response.json();
844
 
 
862
  if (loading) {
863
  loadingEl.style.display = 'block';
864
  processBtn.disabled = true;
865
+ processBtn.textContent = 'Processing...';
866
  } else {
867
  loadingEl.style.display = 'none';
868
  processBtn.disabled = false;
869
+ processBtn.textContent = 'Process Text';
870
  }
871
  }
872
 
test_ocr.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script for Mistral OCR functionality.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import asyncio
9
+ import logging
10
+ from pathlib import Path
11
+ from dotenv import load_dotenv
12
+
13
+ # Load environment variables from .env file
14
+ load_dotenv()
15
+
16
+ # Add the inference directory to Python path
17
+ sys.path.insert(0, str(Path(__file__).parent / "inference"))
18
+
19
+ from mistralai import Mistral
20
+
21
+ # Setup logging
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+ async def test_mistral_ocr():
26
+ """Test Mistral OCR with a sample PDF."""
27
+
28
+ # Check if API key is available
29
+ api_key = os.environ.get("MISTRAL_API_KEY")
30
+ if not api_key:
31
+ logger.error("❌ MISTRAL_API_KEY not found in environment variables")
32
+ return
33
+
34
+ try:
35
+ # Initialize Mistral client
36
+ client = Mistral(api_key=api_key)
37
+ logger.info("βœ… Mistral client initialized")
38
+
39
+ # Test with a sample PDF from arXiv (Mistral paper)
40
+ test_pdf_url = "https://arxiv.org/pdf/2201.04234"
41
+
42
+ logger.info(f"πŸ” Testing OCR with PDF: {test_pdf_url}")
43
+
44
+ # Process the PDF with OCR
45
+ ocr_response = client.ocr.process(
46
+ model="mistral-ocr-latest",
47
+ document={
48
+ "type": "document_url",
49
+ "document_url": test_pdf_url
50
+ },
51
+ include_image_base64=False # Don't include images for testing
52
+ )
53
+
54
+ logger.info("βœ… OCR processing completed")
55
+
56
+ # Extract the text content
57
+ logger.info(f"πŸ“Š OCR Response structure:")
58
+ logger.info(f" - Type: {type(ocr_response)}")
59
+ logger.info(f" - Has pages: {hasattr(ocr_response, 'pages')}")
60
+ logger.info(f" - Has content: {hasattr(ocr_response, 'content')}")
61
+
62
+ extracted_text = ""
63
+
64
+ if hasattr(ocr_response, 'pages') and ocr_response.pages:
65
+ logger.info(f"πŸ“„ Found {len(ocr_response.pages)} pages")
66
+
67
+ # Extract text from all pages
68
+ for i, page in enumerate(ocr_response.pages):
69
+ logger.info(f"πŸ“ƒ Page {i+1} structure: {dir(page)}")
70
+
71
+ if hasattr(page, 'markdown') and page.markdown:
72
+ page_text = page.markdown
73
+ extracted_text += page_text + "\n\n"
74
+ logger.info(f"πŸ“ Page {i+1} text length: {len(page_text)} characters")
75
+ logger.info(f"πŸ“ Page {i+1} preview: {page_text[:200]}...")
76
+ elif hasattr(page, 'content'):
77
+ page_text = page.content
78
+ extracted_text += page_text + "\n\n"
79
+ logger.info(f"πŸ“ Page {i+1} text length: {len(page_text)} characters")
80
+ logger.info(f"πŸ“ Page {i+1} preview: {page_text[:200]}...")
81
+ elif hasattr(page, 'text'):
82
+ page_text = page.text
83
+ extracted_text += page_text + "\n\n"
84
+ logger.info(f"πŸ“ Page {i+1} text length: {len(page_text)} characters")
85
+ logger.info(f"πŸ“ Page {i+1} preview: {page_text[:200]}...")
86
+ else:
87
+ logger.info(f"⚠️ Page {i+1} attributes: {[attr for attr in dir(page) if not attr.startswith('_')]}")
88
+
89
+ if extracted_text:
90
+ logger.info(f"πŸ“„ Total extracted text length: {len(extracted_text)} characters")
91
+
92
+ # Test if we can use this text for PII detection
93
+ if len(extracted_text) > 50:
94
+ logger.info("βœ… OCR extraction successful - text is suitable for PII detection")
95
+
96
+ # Try to import and test PII detection
97
+ try:
98
+ from mistral_prompting import create_mistral_service
99
+
100
+ logger.info("πŸ” Testing PII detection on OCR text...")
101
+ service = await create_mistral_service()
102
+
103
+ # Use a small sample of the text for testing
104
+ sample_text = extracted_text[:500] # First 500 characters
105
+ prediction = await service.predict(sample_text)
106
+
107
+ logger.info(f"πŸ“Š PII detection results:")
108
+ logger.info(f" - Entities found: {len(prediction.entities)}")
109
+ logger.info(f" - Spans detected: {len(prediction.spans)}")
110
+ logger.info(f" - Masked text preview: {prediction.masked_text[:100]}...")
111
+
112
+ except Exception as e:
113
+ logger.warning(f"⚠️ PII detection test failed: {e}")
114
+ logger.info("πŸ’‘ OCR works, but PII detection needs API key setup")
115
+ else:
116
+ logger.warning("⚠️ Extracted text too short")
117
+ else:
118
+ logger.warning("⚠️ No text extracted from pages")
119
+
120
+ elif hasattr(ocr_response, 'content'):
121
+ extracted_text = ocr_response.content
122
+ logger.info(f"πŸ“„ Extracted text length: {len(extracted_text)} characters")
123
+ logger.info(f"πŸ“ First 200 characters: {extracted_text[:200]}...")
124
+ else:
125
+ logger.error("❌ No content found in OCR response")
126
+ logger.info(f"Available attributes: {[attr for attr in dir(ocr_response) if not attr.startswith('_')]}")
127
+
128
+ except Exception as e:
129
+ logger.error(f"❌ OCR test failed: {e}")
130
+ logger.info("πŸ’‘ Make sure MISTRAL_API_KEY is set correctly")
131
+
132
+ if __name__ == "__main__":
133
+ asyncio.run(test_mistral_ocr())
uv.lock ADDED
The diff for this file is too large to render. See raw diff