crop-diag-module / app /models /knowledge_graph.py
Sontranwakumo
init: move from github
88cc76c
raw
history blame
5.04 kB
import os
import sys
from fastapi import Depends
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.core.config import Settings, get_settings
from utils.data_mapping import DataMapping
from utils.extract_entity import extract_entities
from core.type import Node
from neo4j import GraphDatabase
from utils.constant import NEO4J_LABELS, NEO4J_RELATIONS
# Neo4j connection settings, read once at import time from the process
# environment. The second argument is the fallback used when the variable
# is not set.
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "password")
NEO4J_DATABASE = os.environ.get("NEO4J_DATABASE", "neo4j")
class Neo4jConnection:
    """Wrapper around a Neo4j driver configured from the module-level NEO4J_* settings.

    The instance can be used as a context manager so the driver is always
    released::

        with Neo4jConnection() as conn:
            rows = conn.execute_query("MATCH (n) RETURN n LIMIT 1")

    Attributes:
        uri, user, password, database: Copies of the NEO4J_* settings.
        driver: The underlying driver returned by ``GraphDatabase.driver``.
        database_info: Data of the single record returned by ``CALL db.info()``.
        entity_types: ``NEO4J_LABELS``.
        relations: ``NEO4J_RELATIONS``.
    """

    def __init__(self):
        """Initialize the connection to Neo4j and read the database info.

        Raises:
            Exception: Any error raised while running ``CALL db.info()`` is
                re-raised after the freshly created driver has been closed.
        """
        self.uri = NEO4J_URI
        self.user = NEO4J_USER
        self.password = NEO4J_PASSWORD
        self.database = NEO4J_DATABASE
        self.entity_types = NEO4J_LABELS
        self.relations = NEO4J_RELATIONS
        self.driver = GraphDatabase.driver(
            self.uri,
            auth=(self.user, self.password),
            database=self.database,
        )
        try:
            with self.driver.session() as session:
                self.database_info = session.run("CALL db.info()").single().data()
        except Exception:
            # The caller never receives this instance, so nobody else could
            # close the driver; release it here before propagating the error.
            self.driver.close()
            raise

    def __enter__(self):
        """Return the connection itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection on leaving the ``with`` block; never suppress errors."""
        self.close()

    def get_database_info(self):
        """Return the information about the connected database read at construction."""
        return self.database_info

    def close(self):
        """Close the connection to Neo4j; a no-op when no driver is attached."""
        # getattr: tolerate instances on which ``driver`` was never assigned.
        driver = getattr(self, "driver", None)
        if driver is not None:
            driver.close()

    def execute_query(self, query, parameters=None):
        """Execute an arbitrary Cypher query.

        Args:
            query: Cypher text, optionally containing ``$name`` placeholders.
            parameters: Optional mapping of placeholder names to values.

        Returns:
            A list with every record of the result, materialized before the
            session is closed.
        """
        with self.driver.session() as session:
            return list(session.run(query, parameters))
class KnowledgeGraphUtils:
    """Look up and rank the diseases of a crop in the Neo4j knowledge graph."""

    def get_disease_from_env_factors(self, crop_id: str, params: list[Node]):
        """Rank the diseases of a crop by the environmental factors that favor them.

        Args:
            crop_id: Value matched against the ``id`` property of ``Crop`` nodes.
            params: Candidate nodes; those labelled ``"EnvironmentalFactor"``
                are used to filter the graph, and every param's ``score`` may
                contribute to the ranking when its ``id`` is matched.

        Returns:
            A list of ``{"disease": Node, "env_ids": list}`` dicts sorted by
            ``disease.score`` in descending order. ``env_ids`` holds the
            directly matched factor ids followed by those matched through a
            ``Cause`` node. Empty when ``params`` has no environmental factor.
        """
        env_factor_ids = [
            param.id for param in params if param.label == "EnvironmentalFactor"
        ]
        if not env_factor_ids:
            # With nothing to match, the query's final WHERE rejects every row.
            return []
        # Values are passed as query parameters rather than interpolated into
        # the Cypher text, so ids containing quotes cannot alter the query.
        query = """
        MATCH (c:Crop {id: $crop_id})
        WITH c
        MATCH (d:Disease)-[:AFFECTS]-(c)
        OPTIONAL MATCH (ef:EnvironmentalFactor)-[:FAVORS]-(d)
        WHERE ef.id IN $ids
        OPTIONAL MATCH (ef2:EnvironmentalFactor)-[:FAVORS]-(cause:Cause)-[:CAUSES|AFFECTS]-(d)
        WHERE ef2.id IN $ids
        WITH d, COLLECT(DISTINCT ef.id) AS direct_ids, COLLECT(DISTINCT ef2.id) AS indirect_ids
        WHERE SIZE(direct_ids) > 0 OR SIZE(indirect_ids) > 0
        RETURN DISTINCT d, direct_ids, indirect_ids
        """
        records = self._run_query(query, {"crop_id": crop_id, "ids": env_factor_ids})
        return self._rank_diseases(records, params, "env_ids")

    def get_disease_from_symptoms(self, crop_id: str, params: list[Node]) -> list:
        """Rank the diseases of a crop by the symptoms linked to them.

        Args:
            crop_id: Value matched against the ``id`` property of ``Crop`` nodes.
            params: Candidate nodes; those labelled ``"Symptom"`` are used to
                filter the graph, and every param's ``score`` may contribute
                to the ranking when its ``id`` is matched.

        Returns:
            A list of ``{"disease": Node, "symptom_ids": list}`` dicts sorted
            by ``disease.score`` in descending order. ``symptom_ids`` holds the
            directly matched symptom ids followed by those matched through a
            ``PlantPart`` node. Empty when ``params`` has no symptom.
        """
        symptom_ids = [param.id for param in params if param.label == "Symptom"]
        if not symptom_ids:
            # With nothing to match, the query's final WHERE rejects every row.
            return []
        # The ids are aggregated per disease and returned as lists; returning
        # the symptom nodes themselves would hand back graph objects (or null
        # for an unmatched OPTIONAL MATCH) instead of ids.
        query = """
        MATCH (c:Crop {id: $crop_id})
        WITH c
        MATCH (d:Disease)-[:AFFECTS]-(c)
        OPTIONAL MATCH (sym1:Symptom)-[:HAS_SYMPTOM]-(d)
        WHERE sym1.id IN $ids
        OPTIONAL MATCH (sym2:Symptom)-[:HAS_SYMPTOM|LOCATED_ON]-(p:PlantPart)-[:CONTAINS]-(d)
        WHERE sym2.id IN $ids
        WITH d, COLLECT(DISTINCT sym1.id) AS direct_ids, COLLECT(DISTINCT sym2.id) AS indirect_ids
        WHERE SIZE(direct_ids) > 0 OR SIZE(indirect_ids) > 0
        RETURN DISTINCT d, direct_ids, indirect_ids
        """
        records = self._run_query(query, {"crop_id": crop_id, "ids": symptom_ids})
        return self._rank_diseases(records, params, "symptom_ids")

    @staticmethod
    def _run_query(query, parameters):
        """Run one parameterized query on a short-lived connection and close it."""
        connection = Neo4jConnection()
        try:
            return connection.execute_query(query, parameters)
        finally:
            connection.close()

    @staticmethod
    def _rank_diseases(records, params, ids_key):
        """Turn ``d / direct_ids / indirect_ids`` records into a score-sorted list."""
        # Highest score seen per param id, floored at 0; params without a
        # score contribute nothing.
        best_score = {}
        for param in params:
            if param.score is None:
                continue
            best_score[param.id] = max(best_score.get(param.id, 0), param.score)

        ranked = []
        for record in records:
            disease = Node.map_json_to_node(dict(record["d"]), "Disease")
            matched_ids = list(record["direct_ids"]) + list(record["indirect_ids"])
            # A disease scores as its best-scoring matched id (0 if none match).
            disease.score = max(
                (best_score.get(matched_id, 0) for matched_id in matched_ids),
                default=0,
            )
            ranked.append({"disease": disease, ids_key: matched_ids})
        ranked.sort(key=lambda entry: entry["disease"].score, reverse=True)
        return ranked