#!/usr/bin/env python3
"""
FastAPI application for PII Masking Demo - HuggingFace Space.
Serves PII detection through Mistral prompting (base and fine-tuned models) and a BERT classifier.
"""
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

# Import our inference services
from inference.bert_classif import BERTInferenceService, create_bert_service
from inference.mistral_prompting import MistralPromptingService, create_mistral_service
# Setup logging
# basicConfig installs the record layout and INFO threshold on the root logger.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-scoped logger shared by everything defined in this file.
logger = logging.getLogger(__name__)
# Global service instances.
# Each one starts as None and is assigned by lifespan() during startup; it
# stays None when initialization fails, so readers must handle the None case.
mistral_base_service: Optional[MistralPromptingService] = None
mistral_finetuned_service: Optional[MistralPromptingService] = None
bert_service: Optional[BERTInferenceService] = None

# Model configurations: maps the API-facing model key to the Mistral model id.
MODELS = {
    "base": "mistral-large-latest",
    "finetuned": "ft:ministral-8b-latest:c6d4dfa8:20250831:pii-1e-4-200:57d93df9",
}

# BERT model path - HuggingFace Hub repository
BERT_MODEL_PATH = "SoelMgd/bert-pii-detection"
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan - startup and shutdown.

    On startup, tries to create the base Mistral, fine-tuned Mistral and BERT
    services and stores them in the module-level globals. Each service is
    initialized independently: a failure is logged and leaves that global as
    None, but never prevents the other services (or the app) from starting.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    global mistral_base_service, mistral_finetuned_service, bert_service

    # Startup
    logger.info("🚀 Starting PII Masking Demo application...")

    # Each service gets its own try block so that one failing model does not
    # skip the initialization of the ones that follow it.
    try:
        logger.info("Initializing base Mistral service...")
        mistral_base_service = await create_mistral_service(model_name=MODELS["base"])
        logger.info("✅ Base Mistral service initialized successfully")
    except Exception as e:
        # Don't raise - let the app start and report 503 from the endpoints.
        logger.error(f"Failed to initialize base Mistral service: {e}")

    try:
        logger.info("Initializing fine-tuned Mistral service...")
        mistral_finetuned_service = await create_mistral_service(model_name=MODELS["finetuned"])
        logger.info("✅ Fine-tuned Mistral service initialized successfully")
    except Exception as e:
        # Don't raise - let the app start and report 503 from the endpoints.
        logger.error(f"Failed to initialize fine-tuned Mistral service: {e}")

    try:
        logger.info("Initializing BERT service...")
        bert_service = await create_bert_service(model_path=BERT_MODEL_PATH)
        logger.info("✅ BERT service initialized successfully")
    except Exception as e:
        # Don't raise - let the app start and report 503 from the endpoints.
        logger.error(f"Failed to initialize BERT service: {e}")

    yield

    # Shutdown
    logger.info("🔄 Shutting down application...")
# Create FastAPI app
# The lifespan hook runs service startup before requests are served.
app = FastAPI(
    lifespan=lifespan,
    title="🔒 PII Masking Demo",
    version="1.0.0",
    description="Personal Identifiable Information masking using Mistral AI",
)
# Request/Response models
class PredictionRequest(BaseModel):
    # Body of POST /predict.

    # Input text; length is constrained to 1..5000 characters.
    text: str = Field(
        ...,
        min_length=1,
        max_length=5000,
        description="Text to analyze for PII",
    )
    # Inference backend selector; the accepted values are checked in predict().
    method: str = Field(
        description="Method to use: 'mistral' or 'bert'",
        default="mistral",
    )
    # Mistral variant selector; only consulted when method is 'mistral'.
    model: str = Field(
        description="Model to use: 'base' for mistral-large-latest or 'finetuned' for fine-tuned model (ignored for BERT)",
        default="base",
    )
class PredictionResponse(BaseModel):
    # Body returned by POST /predict.

    # Input text after masking.
    masked_text: str = Field(
        description="Text with PII entities masked",
    )
    # Detected values, grouped under a string key per entity group.
    entities: Dict[str, list[str]] = Field(
        description="Detected PII entities",
    )
    # Elapsed time measured inside predict().
    processing_time: float = Field(
        description="Processing time in seconds",
    )
    # "bert", or "mistral-<model>" for the Mistral backends.
    method_used: str = Field(
        description="Method used for prediction",
    )
    # Sum of the lengths of all lists in `entities`.
    num_entities: int = Field(
        description="Total number of entities found",
    )
class HealthResponse(BaseModel):
    # Body returned by GET /health.

    # Overall state set by health_check(): "healthy", "partial" or "degraded".
    status: str
    # Per-service details keyed by service name (e.g. "mistral_base", "bert").
    services: Dict[str, Any]
    # Time the response was built, as returned by time.time().
    timestamp: float
# Helper function to get the appropriate service
def get_mistral_service(model: str) -> MistralPromptingService:
    """Resolve the Mistral service instance for a model key.

    Args:
        model: Either "base" or "finetuned".

    Returns:
        The matching module-level service instance.

    Raises:
        HTTPException: 503 when the matching service is None, 400 when the
            model key is not recognized.
    """
    # Built per call so the current values of the module globals are read.
    known_models = (
        (
            "base",
            mistral_base_service,
            "Base Mistral service not available. Please check API key configuration.",
        ),
        (
            "finetuned",
            mistral_finetuned_service,
            "Fine-tuned Mistral service not available. Please check API key configuration.",
        ),
    )

    for key, service, unavailable_detail in known_models:
        if model == key:
            if service is None:
                raise HTTPException(status_code=503, detail=unavailable_detail)
            return service

    # No key matched: the caller asked for a model this API does not expose.
    raise HTTPException(
        status_code=400,
        detail=f"Model '{model}' not supported. Use 'base' or 'finetuned'.",
    )
# Mount static files for frontend
# Best effort: a failure here is only logged so the API still starts.
try:
    app.mount(
        "/static",
        StaticFiles(directory="static"),
        name="static",
    )
except Exception as mount_error:
    logger.warning(f"Could not mount static files: {mount_error}")
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main HTML page."""
    # Check for the file explicitly: constructing FileResponse does not raise
    # when the file is missing (the error only surfaces while the response is
    # being sent), so a try/except around the constructor never falls back.
    if os.path.isfile("static/index.html"):
        return FileResponse("static/index.html")

    # Fallback HTML if static files not available
    return HTMLResponse("""
PII Masking Demo
🔒 PII Masking Demo
Enter text below to detect and mask Personal Identifiable Information:
Results:
Masked Text:
Entities Found:
Processing Time: s
""")
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """
    Predict PII entities and return masked text.
    Supports Mistral models (base and fine-tuned) and BERT.

    Raises:
        HTTPException: 400 for an unsupported method or model, 503 when the
            selected service is not available, 500 when the prediction fails.
    """
    # Validate method
    if request.method not in ("mistral", "bert"):
        raise HTTPException(
            status_code=400,
            detail=f"Method '{request.method}' not supported. Use 'mistral' or 'bert'."
        )

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()

    # Resolve the service outside the try block so that the 400/503 errors
    # raised here reach the client unchanged instead of being rewrapped as 500.
    if request.method == "mistral":
        service = get_mistral_service(request.model)
        model_type = "Fine-tuned" if request.model == "finetuned" else "Base"
        model_label = f"{model_type} Mistral"
        method_used = f"{request.method}-{request.model}"
    else:
        if bert_service is None:
            raise HTTPException(
                status_code=503,
                detail="BERT service not available. Please check model configuration."
            )
        service = bert_service
        model_label = "BERT"
        method_used = "bert"

    # Log only the input size: the text is submitted for PII masking, so its
    # content must not be copied into the application logs.
    logger.info(f"🔍 Processing text with {model_label} model ({len(request.text)} characters)")

    try:
        prediction = await service.predict(request.text)
        processing_time = time.perf_counter() - start_time

        # Count total entities
        num_entities = sum(len(entities) for entities in prediction.entities.values())
        logger.info(f"✅ Prediction completed in {processing_time:.3f}s - found {num_entities} entities")

        return PredictionResponse(
            masked_text=prediction.masked_text,
            entities=prediction.entities,
            processing_time=processing_time,
            method_used=method_used,
            num_entities=num_entities
        )
    except HTTPException:
        # Preserve any deliberate HTTP status raised further down the stack.
        raise
    except Exception as e:
        logger.exception(f"❌ Prediction failed: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}"
        ) from e
def _describe_service(service, model_name: str) -> Dict[str, Any]:
    """Build one per-service entry of the /health payload."""
    if service is None:
        return {
            "available": False,
            "initialized": False,
            "model": model_name,
            "info": None,
        }
    return {
        "available": True,
        "initialized": service.is_initialized,
        "model": model_name,
        "info": service.get_service_info(),
    }


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Reports "healthy" when every service is initialized, "partial" when at
    least one is, and "degraded" when none are.
    """
    services_status = {
        "mistral_base": _describe_service(mistral_base_service, MODELS["base"]),
        "mistral_finetuned": _describe_service(mistral_finetuned_service, MODELS["finetuned"]),
        "bert": _describe_service(bert_service, BERT_MODEL_PATH),
    }

    # Overall status.
    # Coerce to bool before summing: a missing service must count as 0, whereas
    # summing a raw None raises TypeError and would crash the endpoint exactly
    # when a service is down.
    healthy_services = sum(
        bool(entry["initialized"]) for entry in services_status.values()
    )
    if healthy_services == len(services_status):
        overall_status = "healthy"
    elif healthy_services >= 1:
        overall_status = "partial"
    else:
        overall_status = "degraded"

    return HealthResponse(
        status=overall_status,
        services=services_status,
        timestamp=time.time()
    )
@app.get("/api/info")
async def api_info():
    """Get API information."""
    # Catalog of selectable models, keyed by the identifier used in this API.
    model_catalog = {
        "base": {
            "name": MODELS["base"],
            "description": "Base Mistral model with detailed prompting",
        },
        "finetuned": {
            "name": MODELS["finetuned"],
            "description": "Fine-tuned Mistral model specialized for PII detection",
        },
        "bert": {
            "name": BERT_MODEL_PATH,
            "description": "BERT token classification model for fast PII detection",
        },
    }

    # One-line summary for each public endpoint.
    endpoint_summary = {
        "predict": "POST /predict - Analyze text for PII (supports 'model' parameter: 'base' or 'finetuned')",
        "health": "GET /health - Health check",
        "info": "GET /api/info - API information",
    }

    return {
        "name": "PII Masking Demo API",
        "version": "1.0.0",
        "description": "Personal Identifiable Information masking using Mistral AI",
        "available_methods": ["mistral", "bert"],
        "available_models": model_catalog,
        "endpoints": endpoint_summary,
    }
# Error handlers
@app.exception_handler(404)
async def not_found_handler(request: Request, exc):
    """Answer 404 errors with a JSON body that names the requested path."""
    missing_path = request.url.path
    payload = {"detail": f"Endpoint {missing_path} not found"}
    return JSONResponse(content=payload, status_code=404)
@app.exception_handler(500)
async def internal_error_handler(request: Request, exc):
    """Log the error and answer with a generic JSON 500 body."""
    logger.error(f"Internal server error: {exc}")
    # The client only sees a generic message; the cause goes to the log above.
    payload = {"detail": "Internal server error"}
    return JSONResponse(content=payload, status_code=500)
# Add CORS middleware for development
from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify exact origins
    # Credentials stay disabled while origins/methods/headers are wildcards:
    # the combination is documented as unsupported and would let any site
    # issue credentialed requests. Enable this only together with an explicit
    # allow_origins list.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
if __name__ == "__main__":
    # For local development
    import uvicorn

    # Options for the development server: all interfaces, port 7860, with
    # auto-reload turned on.
    dev_server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("app:app", **dev_server_options)