Sontranwakumo
feat: apply new model
e57a125
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from app.core.config import get_settings
from app.api.routes import router as api_router
from app.models.crop_clip import EfficientNetModule
from app.models.gemini_caller import GeminiGenerator
from app.utils.data_mapping import DataMapping, SingletonModel
from app.models.knowledge_graph import KnowledgeGraphUtils, Neo4jConnection
import asyncio
from concurrent.futures import ThreadPoolExecutor
# Configure the root logger at INFO level, then create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Settings are fetched once at import time; the load_* flags and `debug`
# value read later in this module come from this object.
settings = get_settings()
class ModelLoader:
    """Holder for the optional components the application loads at startup.

    Each component is created by :meth:`load_models` only when the matching
    ``settings.load_*`` flag is truthy; otherwise its attribute stays ``None``.
    """

    def __init__(self):
        # Attributes assigned by load_models(). They are declared here so that
        # reading a component whose flag is disabled yields None instead of
        # raising AttributeError.
        self.efficientnet_model = None
        self.gemini_model = None
        self.data_mapper = None
        self.knowledge_graph = None
        # Legacy attributes kept for backward compatibility; nothing in this
        # class assigns them.
        self.clip_model = None
        self.sentence_transformer = None
        self.neo4j_connection = None

    def load_models(self):
        """Instantiate every component enabled in ``settings``.

        Raises:
            Exception: any error raised while constructing a component is
                logged and re-raised unchanged.
        """
        try:
            if settings.load_efficientnet_model:
                logger.info("Loading EfficientNet model...")
                self.efficientnet_model = EfficientNetModule()
                logger.info("EfficientNet model loaded successfully")
            if settings.load_gemini_model:
                logger.info("Loading Gemini model...")
                self.gemini_model = GeminiGenerator()
                logger.info("Gemini model loaded successfully")
            if settings.load_data_mapper:
                logger.info("Loading DataMapper model...")
                self.data_mapper = DataMapping()
                logger.info("DataMapper model loaded successfully")
            if settings.load_knowledge_graph:
                logger.info("Connecting to Knowledge Graph...")
                self.knowledge_graph = KnowledgeGraphUtils()
                logger.info("Knowledge Graph connection established")
        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            raise

    def close(self):
        """Close the Neo4j connection (if any) and drop every component reference."""
        if self.neo4j_connection:
            logger.info("Closing Neo4j connection...")
            self.neo4j_connection.close()
        # Clear the handle so a second close() does not close it again.
        self.neo4j_connection = None
        # NOTE(review): KnowledgeGraphUtils may own a connection of its own that
        # needs an explicit close -- confirm against its implementation.
        self.knowledge_graph = None
        self.efficientnet_model = None
        self.gemini_model = None
        self.data_mapper = None
        self.clip_model = None
        self.sentence_transformer = None
        logger.info("Models released")
# Lifespan event handler
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models before the app starts serving and release them on shutdown.

    Expects ``app.state.model_loader`` to be set before startup. Any exception
    raised by ``load_models`` propagates out of this context manager.
    """
    # Inside a coroutine, get_running_loop() is the recommended way to obtain
    # the current event loop.
    loop = asyncio.get_running_loop()
    # load_models() is a synchronous call; run it in a worker thread so the
    # event loop is not blocked while it executes.
    with ThreadPoolExecutor() as pool:
        await loop.run_in_executor(pool, app.state.model_loader.load_models)
    logger.info("Application startup complete")
    try:
        yield
    finally:
        # Release resources even when an exception is thrown into the context
        # manager at the yield point.
        app.state.model_loader.close()
        logger.info("Application shutdown complete")
# Build the FastAPI application; `lifespan` handles startup and shutdown work.
app = FastAPI(
title="Crop Diagnosis Knowledge Graph API",
description="API for querying crop diagnosis knowledge graph using LangChain",
version="1.0.0",
debug=settings.debug,
lifespan=lifespan
)
# The lifespan handler reads app.state.model_loader, so attach the loader here
# at import time, before the application starts.
app.state.model_loader = ModelLoader()
# Mount the routes from app.api.routes under the /api prefix.
app.include_router(api_router, prefix="/api")
@app.get("/")
async def root():
    """Return the static welcome message served at the API root path."""
    greeting = "Welcome to Crop Diagnosis Knowledge Graph API"
    return {"message": greeting}