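"""FastAPI entry point for the Crop Diagnosis Knowledge Graph API.

Loads the configured models (EfficientNet, Gemini, DataMapper, Knowledge Graph)
once at startup via a lifespan handler and releases them on shutdown.
"""
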
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.core.config import get_settings
from app.api.routes import router as api_router
from app.models.crop_clip import EfficientNetModule
from app.models.gemini_caller import GeminiGenerator
from app.utils.data_mapping import DataMapping
from app.models.knowledge_graph import KnowledgeGraphUtils, Neo4jConnection
import asyncio
from concurrent.futures import ThreadPoolExecutor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
settings = get_settings()

class ModelLoader:
    """Loads heavyweight models once at startup and holds them for the app lifetime."""

    def __init__(self):
        self.efficientnet_model = None
        self.gemini_model = None
        self.data_mapper = None
        self.knowledge_graph = None
        self.neo4j_connection = None

    def load_models(self):
        try:
            if settings.load_efficientnet_model:
                logger.info("Loading EfficientNet model...")
                self.efficientnet_model = EfficientNetModule()
                logger.info("EfficientNet model loaded successfully")

            if settings.load_gemini_model:
                logger.info("Loading Gemini model...")
                self.gemini_model = GeminiGenerator()
                logger.info("Gemini model loaded successfully")

            if settings.load_data_mapper:
                logger.info("Loading DataMapper model...")
                self.data_mapper = DataMapping()
                logger.info("DataMapper model loaded successfully")

            if settings.load_knowledge_graph:
                logger.info("Connecting to Knowledge Graph...")
                self.knowledge_graph = KnowledgeGraphUtils()
                logger.info("Knowledge Graph connection established")
        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            raise

    def close(self):
        # Close the Neo4j connection if one was opened, then drop model
        # references so they can be garbage-collected on shutdown.
        if self.neo4j_connection:
            logger.info("Closing Neo4j connection...")
            self.neo4j_connection.close()
        self.efficientnet_model = None
        self.gemini_model = None
        self.data_mapper = None
        self.knowledge_graph = None
        logger.info("Models released")

# Lifespan handler: load models in a background thread at startup and release them on shutdown.
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Run the blocking model loading in a worker thread so the event loop stays responsive.
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as pool:
        await loop.run_in_executor(pool, app.state.model_loader.load_models)
    logger.info("Application startup complete")
    yield
    app.state.model_loader.close()
    logger.info("Application shutdown complete")

app = FastAPI(
    title="Crop Diagnosis Knowledge Graph API",
    description="API for querying crop diagnosis knowledge graph using LangChain",
    version="1.0.0",
    debug=settings.debug,
    lifespan=lifespan
)

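# Attach the loader to application state so the lifespan handler above (and route
# handlers, via request.app.state) can reach the loaded models.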
app.state.model_loader = ModelLoader()

app.include_router(api_router, prefix="/api")

@app.get("/")
async def root():
    return {"message": "Welcome to Crop Diagnosis Knowledge Graph API"}
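
# Optional convenience entry point for local development. Assumes uvicorn is
# installed; the host/port values below are illustrative, not project settings.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)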