Twin committed on
Commit
1ac13f0
·
1 Parent(s): 92539a7

Complete BERT integration with frontend support

Browse files

✨ Features:
- Add BERT option in frontend method selection
- Add conditional Mistral model selection (hidden when BERT selected)
- Complete BERT integration in FastAPI app
- Update health check to include BERT service
- Update API info with BERT support

🔧 Technical:
- BERT service initialization in lifespan
- BERT prediction route support
- Frontend JavaScript for method switching
- Updated request/response models

Files changed (2) hide show
  1. app.py +58 -20
  2. static/index.html +20 -2
app.py CHANGED
@@ -44,7 +44,7 @@ BERT_MODEL_PATH = "SoelMgd/bert-pii-detection"
44
  @asynccontextmanager
45
  async def lifespan(app: FastAPI):
46
  """Manage application lifespan - startup and shutdown."""
47
- global mistral_base_service, mistral_finetuned_service
48
 
49
  # Startup
50
  logger.info("🚀 Starting PII Masking Demo application...")
@@ -61,7 +61,17 @@ async def lifespan(app: FastAPI):
61
  logger.info("✅ Fine-tuned Mistral service initialized successfully")
62
 
63
  except Exception as e:
64
- logger.error(f"Failed to initialize services: {e}")
 
 
 
 
 
 
 
 
 
 
65
  # Don't raise exception - let app start but handle gracefully in endpoints
66
 
67
  yield
@@ -80,8 +90,8 @@ app = FastAPI(
80
  # Request/Response models
81
  class PredictionRequest(BaseModel):
82
  text: str = Field(..., description="Text to analyze for PII", min_length=1, max_length=5000)
83
- method: str = Field(default="mistral", description="Method to use (currently only 'mistral')")
84
- model: str = Field(default="base", description="Model to use: 'base' for mistral-large-latest or 'finetuned' for fine-tuned model")
85
 
86
  class PredictionResponse(BaseModel):
87
  masked_text: str = Field(description="Text with PII entities masked")
@@ -200,26 +210,41 @@ async def predict(request: PredictionRequest):
200
  """
201
  Predict PII entities and return masked text.
202
 
203
- Supports both base and fine-tuned Mistral models.
204
  """
205
  # Validate method
206
- if request.method != "mistral":
207
  raise HTTPException(
208
  status_code=400,
209
- detail=f"Method '{request.method}' not supported. Currently only 'mistral' is available."
210
  )
211
 
212
- # Get the appropriate service
213
- service = get_mistral_service(request.model)
214
-
215
  start_time = time.time()
216
 
217
  try:
218
- model_type = "Fine-tuned" if request.model == "finetuned" else "Base"
219
- logger.info(f"🔍 Processing text with {model_type} Mistral model: {request.text[:100]}...")
220
-
221
- # Call Mistral service
222
- prediction = await service.predict(request.text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
  processing_time = time.time() - start_time
225
 
@@ -232,7 +257,7 @@ async def predict(request: PredictionRequest):
232
  masked_text=prediction.masked_text,
233
  entities=prediction.entities,
234
  processing_time=processing_time,
235
- method_used=f"{request.method}-{request.model}",
236
  num_entities=num_entities
237
  )
238
 
@@ -246,7 +271,7 @@ async def predict(request: PredictionRequest):
246
  @app.get("/health", response_model=HealthResponse)
247
  async def health_check():
248
  """Health check endpoint."""
249
- global mistral_base_service, mistral_finetuned_service
250
 
251
  services_status = {
252
  "mistral_base": {
@@ -260,16 +285,25 @@ async def health_check():
260
  "initialized": mistral_finetuned_service.is_initialized if mistral_finetuned_service else False,
261
  "model": MODELS["finetuned"],
262
  "info": mistral_finetuned_service.get_service_info() if mistral_finetuned_service else None
 
 
 
 
 
 
263
  }
264
  }
265
 
266
  # Overall status
267
  base_healthy = mistral_base_service and mistral_base_service.is_initialized
268
  finetuned_healthy = mistral_finetuned_service and mistral_finetuned_service.is_initialized
 
 
 
269
 
270
- if base_healthy and finetuned_healthy:
271
  overall_status = "healthy"
272
- elif base_healthy or finetuned_healthy:
273
  overall_status = "partial"
274
  else:
275
  overall_status = "degraded"
@@ -287,7 +321,7 @@ async def api_info():
287
  "name": "PII Masking Demo API",
288
  "version": "1.0.0",
289
  "description": "Personal Identifiable Information masking using Mistral AI",
290
- "available_methods": ["mistral"],
291
  "available_models": {
292
  "base": {
293
  "name": MODELS["base"],
@@ -296,6 +330,10 @@ async def api_info():
296
  "finetuned": {
297
  "name": MODELS["finetuned"],
298
  "description": "Fine-tuned Mistral model specialized for PII detection"
 
 
 
 
299
  }
300
  },
301
  "endpoints": {
 
44
  @asynccontextmanager
45
  async def lifespan(app: FastAPI):
46
  """Manage application lifespan - startup and shutdown."""
47
+ global mistral_base_service, mistral_finetuned_service, bert_service
48
 
49
  # Startup
50
  logger.info("🚀 Starting PII Masking Demo application...")
 
61
  logger.info("✅ Fine-tuned Mistral service initialized successfully")
62
 
63
  except Exception as e:
64
+ logger.error(f"Failed to initialize Mistral services: {e}")
65
+ # Don't raise exception - let app start but handle gracefully in endpoints
66
+
67
+ try:
68
+ # Initialize BERT service
69
+ logger.info("Initializing BERT service...")
70
+ bert_service = await create_bert_service(model_path=BERT_MODEL_PATH)
71
+ logger.info("✅ BERT service initialized successfully")
72
+
73
+ except Exception as e:
74
+ logger.error(f"Failed to initialize BERT service: {e}")
75
  # Don't raise exception - let app start but handle gracefully in endpoints
76
 
77
  yield
 
90
  # Request/Response models
91
  class PredictionRequest(BaseModel):
92
  text: str = Field(..., description="Text to analyze for PII", min_length=1, max_length=5000)
93
+ method: str = Field(default="mistral", description="Method to use: 'mistral' or 'bert'")
94
+ model: str = Field(default="base", description="Model to use: 'base' for mistral-large-latest or 'finetuned' for fine-tuned model (ignored for BERT)")
95
 
96
  class PredictionResponse(BaseModel):
97
  masked_text: str = Field(description="Text with PII entities masked")
 
210
  """
211
  Predict PII entities and return masked text.
212
 
213
+ Supports Mistral models (base and fine-tuned) and BERT.
214
  """
215
  # Validate method
216
+ if request.method not in ["mistral", "bert"]:
217
  raise HTTPException(
218
  status_code=400,
219
+ detail=f"Method '{request.method}' not supported. Use 'mistral' or 'bert'."
220
  )
221
 
 
 
 
222
  start_time = time.time()
223
 
224
  try:
225
+ if request.method == "mistral":
226
+ # Get the appropriate Mistral service
227
+ service = get_mistral_service(request.model)
228
+ model_type = "Fine-tuned" if request.model == "finetuned" else "Base"
229
+ logger.info(f"🔍 Processing text with {model_type} Mistral model: {request.text[:100]}...")
230
+
231
+ # Call Mistral service
232
+ prediction = await service.predict(request.text)
233
+ method_used = f"{request.method}-{request.model}"
234
+
235
+ elif request.method == "bert":
236
+ # Check BERT service availability
237
+ if bert_service is None:
238
+ raise HTTPException(
239
+ status_code=503,
240
+ detail="BERT service not available. Please check model configuration."
241
+ )
242
+
243
+ logger.info(f"🔍 Processing text with BERT model: {request.text[:100]}...")
244
+
245
+ # Call BERT service
246
+ prediction = await bert_service.predict(request.text)
247
+ method_used = "bert"
248
 
249
  processing_time = time.time() - start_time
250
 
 
257
  masked_text=prediction.masked_text,
258
  entities=prediction.entities,
259
  processing_time=processing_time,
260
+ method_used=method_used,
261
  num_entities=num_entities
262
  )
263
 
 
271
  @app.get("/health", response_model=HealthResponse)
272
  async def health_check():
273
  """Health check endpoint."""
274
+ global mistral_base_service, mistral_finetuned_service, bert_service
275
 
276
  services_status = {
277
  "mistral_base": {
 
285
  "initialized": mistral_finetuned_service.is_initialized if mistral_finetuned_service else False,
286
  "model": MODELS["finetuned"],
287
  "info": mistral_finetuned_service.get_service_info() if mistral_finetuned_service else None
288
+ },
289
+ "bert": {
290
+ "available": bert_service is not None,
291
+ "initialized": bert_service.is_initialized if bert_service else False,
292
+ "model": BERT_MODEL_PATH,
293
+ "info": bert_service.get_service_info() if bert_service else None
294
  }
295
  }
296
 
297
  # Overall status
298
  base_healthy = mistral_base_service and mistral_base_service.is_initialized
299
  finetuned_healthy = mistral_finetuned_service and mistral_finetuned_service.is_initialized
300
+ bert_healthy = bert_service and bert_service.is_initialized
301
+
302
+ healthy_services = sum([base_healthy, finetuned_healthy, bert_healthy])
303
 
304
+ if healthy_services == 3:
305
  overall_status = "healthy"
306
+ elif healthy_services >= 1:
307
  overall_status = "partial"
308
  else:
309
  overall_status = "degraded"
 
321
  "name": "PII Masking Demo API",
322
  "version": "1.0.0",
323
  "description": "Personal Identifiable Information masking using Mistral AI",
324
+ "available_methods": ["mistral", "bert"],
325
  "available_models": {
326
  "base": {
327
  "name": MODELS["base"],
 
330
  "finetuned": {
331
  "name": MODELS["finetuned"],
332
  "description": "Fine-tuned Mistral model specialized for PII detection"
333
+ },
334
+ "bert": {
335
+ "name": BERT_MODEL_PATH,
336
+ "description": "BERT token classification model for fast PII detection"
337
  }
338
  },
339
  "endpoints": {
static/index.html CHANGED
@@ -280,16 +280,23 @@ Example: Hi, my name is John Smith and my email is [email protected]. Call
280
  <label>Select masking method:</label>
281
  <div class="method-selection">
282
  <div class="method-option">
283
- <input type="radio" id="mistral" name="method" value="mistral" class="method-radio" checked>
284
  <label for="mistral" class="method-label">
285
  <div class="method-title">🧠 Mistral AI</div>
286
  <div class="method-desc">High accuracy via API</div>
287
  </label>
288
  </div>
 
 
 
 
 
 
 
289
  </div>
290
  </div>
291
 
292
- <div class="form-group">
293
  <label>Select Mistral model:</label>
294
  <div class="method-selection">
295
  <div class="method-option">
@@ -435,6 +442,17 @@ Example: Hi, my name is John Smith and my email is [email protected]. Call
435
  function hideError() {
436
  document.getElementById('error').style.display = 'none';
437
  }
 
 
 
 
 
 
 
 
 
 
 
438
  </script>
439
  </body>
440
  </html>
 
280
  <label>Select masking method:</label>
281
  <div class="method-selection">
282
  <div class="method-option">
283
+ <input type="radio" id="mistral" name="method" value="mistral" class="method-radio" checked onchange="toggleMistralModelSelection()">
284
  <label for="mistral" class="method-label">
285
  <div class="method-title">🧠 Mistral AI</div>
286
  <div class="method-desc">High accuracy via API</div>
287
  </label>
288
  </div>
289
+ <div class="method-option">
290
+ <input type="radio" id="bert" name="method" value="bert" class="method-radio" onchange="toggleMistralModelSelection()">
291
+ <label for="bert" class="method-label">
292
+ <div class="method-title">🤖 BERT</div>
293
+ <div class="method-desc">Fast local processing</div>
294
+ </label>
295
+ </div>
296
  </div>
297
  </div>
298
 
299
+ <div class="form-group" id="mistralModelSelection">
300
  <label>Select Mistral model:</label>
301
  <div class="method-selection">
302
  <div class="method-option">
 
442
  function hideError() {
443
  document.getElementById('error').style.display = 'none';
444
  }
445
+
446
+ function toggleMistralModelSelection() {
447
+ const mistralSelected = document.getElementById('mistral').checked;
448
+ const mistralModelSelection = document.getElementById('mistralModelSelection');
449
+
450
+ if (mistralSelected) {
451
+ mistralModelSelection.style.display = 'block';
452
+ } else {
453
+ mistralModelSelection.style.display = 'none';
454
+ }
455
+ }
456
  </script>
457
  </body>
458
  </html>