# NOTE: "Spaces: Sleeping" status-banner text was captured here during export
# (Hugging Face Spaces page artifact); kept as a comment so the module parses.
import json | |
import io | |
from string import Template | |
from fastapi import Depends, UploadFile | |
import asyncio | |
from PIL import Image | |
import sqlite3 | |
from app.api.dto.kg_query import KGQueryRequest, QueryContext, PredictedLabel | |
from app.core.dependencies import get_all_models | |
from app.core.type import Node | |
from app.models.crop_clip import EfficientNetModule | |
from app.models.gemini_caller import GeminiGenerator | |
from app.models.knowledge_graph import KnowledgeGraphUtils | |
from app.utils.constant import EXTRACTED_NODES | |
from app.utils.data_mapping import VECTOR_EMBEDDINGS_DB_PATH, DataMapping | |
from app.utils.extract_entity import clean_text, extract_entities | |
from app.utils.prompt import EXTRACT_NODES_FROM_IMAGE_PROMPT, EXTRACT_NODES_FROM_TEXT_PROMPT, GET_STATEMENT_FROM_DISEASE_KG, GET_STATEMENT_FROM_ENV_FACTORS_KG | |
class CustomJSONEncoder(json.JSONEncoder):
    """json.JSONEncoder that also serializes Pydantic models (v2 and v1)."""

    def default(self, obj):
        # Prefer the Pydantic v2 API (model_dump), then fall back to v1 (dict).
        for dumper_name in ('model_dump', 'dict'):
            if hasattr(obj, dumper_name):
                return getattr(obj, dumper_name)()
        # Sequences: dump any Pydantic items, pass everything else through.
        if isinstance(obj, (list, tuple)):
            converted = []
            for item in obj:
                is_model = hasattr(item, 'model_dump') or hasattr(item, 'dict')
                converted.append(self.default(item) if is_model else item)
            return converted
        # Anything else: let the base class raise its usual TypeError.
        return super().default(obj)
def convert_to_json_serializable(obj):
    """Recursively convert an object graph (including Pydantic models) into
    plain JSON-serializable structures; non-serializable leaves become strings."""
    try:
        # Pydantic models: v2 exposes model_dump(), v1 exposes dict().
        for dumper_name in ('model_dump', 'dict'):
            if hasattr(obj, dumper_name):
                return getattr(obj, dumper_name)()
        # Containers recurse; tuples flatten to lists, matching the original.
        if isinstance(obj, (list, tuple)):
            return [convert_to_json_serializable(item) for item in obj]
        if isinstance(obj, dict):
            return {key: convert_to_json_serializable(value) for key, value in obj.items()}
        if obj is None:
            return None
        # Leaf value: keep it if json can encode it, else stringify.
        try:
            json.dumps(obj)
        except (TypeError, ValueError):
            print(f"Warning: Converting non-serializable object {type(obj)} to string: {obj}")
            return str(obj)
        return obj
    except Exception as e:
        # Last-resort fallback so callers always get something serializable.
        print(f"Error in convert_to_json_serializable for object {type(obj)}: {e}")
        return str(obj)
# Module-level cache of graph Node objects built once from the static
# EXTRACTED_NODES constant; consumed by PredictService.get_nodes_from_text
# when building the entity-extraction prompt. Scores start as None and are
# assigned later by the retrieval pipeline.
extracted_nodes = [
    Node(
        id=node['id'],
        label=node['label'],
        name=node['name'],
        properties={'description': node['description']},
        score=None
    ) for node in EXTRACTED_NODES
]
class PredictService:
    """Disease-prediction orchestration service.

    Combines an EfficientNet image classifier, knowledge-graph queries
    (environmental factors + symptoms), Gemini-based entity extraction and a
    score-fusion step into ranked disease predictions for a crop.
    """

    def __init__(self, models):
        # Shared registry of preloaded components (efficientnet_model,
        # knowledge_graph, data_mapper, ...) injected via get_all_models.
        self.models = models

    async def predict_image(self, image: UploadFile):
        """Run the EfficientNet classifier on an uploaded image and return its prediction."""
        efficientnet_model: EfficientNetModule = self.models["efficientnet_model"]
        image_content = image.file.read()
        # BUGFIX: was Image.io.BytesIO, which only resolves because PIL.Image
        # happens to import io internally; use the stdlib io module directly
        # (matching get_nodes_from_image below).
        pil_image = Image.open(io.BytesIO(image_content)).convert('RGB')
        return efficientnet_model.predict_image(pil_image)

    async def retrieve_kg(self, request: KGQueryRequest):
        """Query the knowledge graph for diseases matching the request context.

        Enriches context.nodes with entities extracted from free-text
        additional_info, runs the env-factor and symptom queries concurrently,
        and — when CLIP predictions are present — fuses all evidence into
        context.final_labels. Exceptions are logged and re-raised.
        """
        try:
            kg: KnowledgeGraphUtils = self.models["knowledge_graph"]
            if not request.context:
                request.context = QueryContext()
            if request.crop_id:
                request.context.crop_id = request.crop_id
            # ROBUSTNESS: guard unconditionally (original only initialized
            # nodes inside the additional_info branch, so a request without
            # additional_info and with nodes=None crashed in the loop below).
            if request.context.nodes is None:
                request.context.nodes = []
            if request.additional_info:
                additional_nodes = await self.__get_nodes_from_additional_info_async(
                    request.additional_info, self.models["data_mapper"]
                )
                request.context.nodes = request.context.nodes + additional_nodes
            # Default score for nodes supplied without one (e.g. user-selected
            # symptoms) so downstream scoring has a value to work with.
            for node in request.context.nodes:
                if node.score is None:
                    node.score = 0.9
            # Run both graph queries concurrently.
            env_task = asyncio.create_task(
                kg.get_disease_from_env_factors(request.context.crop_id, request.context.nodes)
            )
            symptom_task = asyncio.create_task(
                kg.get_disease_from_symptoms(request.context.crop_id, request.context.nodes)
            )
            env_results, symptom_results = await asyncio.gather(env_task, symptom_task)
            context = request.context
            context.nodes.extend([env_result["disease"] for env_result in env_results])
            context.nodes.extend([symptom_result["disease"] for symptom_result in symptom_results])
            print(context.nodes)
            # NOTE(review): assumes every disease node carries a non-None
            # score; a None score would make this sort raise TypeError.
            context.nodes.sort(key=lambda x: x.score, reverse=True)
            # Fuse CLIP + knowledge-graph evidence into final labels.
            if context.predicted_labels:
                print("Got predicted labels")
                context.final_labels = self.calculate_final_labels(
                    context.predicted_labels,
                    env_results,
                    symptom_results,
                    context.crop_id
                )
            return {
                "context": context,
                "env_results": env_results,
                "symptom_results": symptom_results
            }
        except Exception as e:
            print(e)
            raise e

    def calculate_final_labels(self, predicted_labels, env_result, symptom_result, crop_id):
        """Fuse evidence sources into a ranked list of PredictedLabel.

        Per (crop_id, label) key:
        - CLIP predictions contribute their raw confidence (summed).
        - Symptom matches add score * 0.2 * (1 - current_total) — a noisy-OR
          style update so the total saturates toward 1.
        - Environmental matches add score * 0.3 * (1 - current_total).
        Results are capped at 1.0, sorted descending, and labels with
        confidence <= 0.1 are dropped.

        (The original docstring claimed a 0.4/0.3/0.3 weighted average, which
        did not match the implementation.)
        """
        ENV_WEIGHT = 0.3
        SYMPTOM_WEIGHT = 0.2
        # Accumulated evidence per "{crop_id}_{label}" key.
        label_scores = {}
        # 1. Scores from the CLIP model.
        for label in predicted_labels:
            key = f"{label.crop_id}_{label.label}"
            print(f"CLIP key: {key} score: {label.confidence}")
            if key not in label_scores:
                label_scores[key] = {
                    "crop_id": label.crop_id,
                    "label": label.label,
                    "total_score": 0,
                    "count": 0
                }
            label_scores[key]["total_score"] += label.confidence
            label_scores[key]["count"] += 1
        # 2. Scores from symptom matches.
        for symptom in symptom_result:
            disease = symptom.get("disease")
            if disease and hasattr(disease, 'score'):
                key = f"{crop_id}_{disease.id}"
                print(f"Symptom key: {key} score: {disease.score}")
                if key not in label_scores:
                    label_scores[key] = {
                        "crop_id": crop_id,
                        "label": disease.id,
                        "total_score": 0,
                        "count": 0
                    }
                # Noisy-OR update: the closer the total is to 1, the smaller the increment.
                label_scores[key]["total_score"] += disease.score * SYMPTOM_WEIGHT * (1 - label_scores[key]["total_score"])
        # 3. Scores from environmental factors.
        for env in env_result:
            disease = env.get("disease")
            if disease and hasattr(disease, 'score'):
                key = f"{crop_id}_{disease.id}"
                print(f"Env key: {key} score: {disease.score}")
                if key not in label_scores:
                    label_scores[key] = {
                        "crop_id": crop_id,
                        "label": disease.id,
                        "total_score": 0,
                        "count": 0
                    }
                label_scores[key]["total_score"] += disease.score * ENV_WEIGHT * (1 - label_scores[key]["total_score"])
        # Build, cap, sort and threshold the final labels.
        final_labels = []
        for key, data in label_scores.items():
            final_labels.append(PredictedLabel(
                crop_id=data["crop_id"],
                label=data["label"],
                confidence=min(data["total_score"], 1.0)  # never exceed 1.0
            ))
        final_labels.sort(key=lambda x: x.confidence, reverse=True)
        print(final_labels)
        return [label for label in final_labels if label.confidence > 0.1]  # drop low-confidence labels

    async def get_nodes_from_image(self, image: UploadFile):
        """Ask Gemini to map visible symptoms in an image to known Symptom nodes.

        Returns the matched Node objects, or [] on any failure (best-effort).
        """
        try:
            gemini = GeminiGenerator()
            symptoms = self.models["data_mapper"].get_embedding_by_label("Symptom")
            symptom_list = "\n".join(
                f"- id:{node.id} - name:{node.name}" for node in symptoms
            )
            prompt = Template(EXTRACT_NODES_FROM_IMAGE_PROMPT).substitute(symptom_list=symptom_list)
            image_content = image.file.read()
            pil_image = Image.open(io.BytesIO(image_content)).convert('RGB')
            ids = gemini.generate(prompt, image=pil_image)
            ids = (json.loads(clean_text(ids.text)))["ids"]
            print(ids)
            nodes = []
            for node_id in ids:
                node = next((symptom for symptom in symptoms if symptom.id == node_id), None)
                # BUGFIX: skip ids Gemini hallucinated instead of appending None.
                if node is not None:
                    nodes.append(node)
            return nodes
        except Exception as e:
            print(f"Error while extract knowledge entities from image: {str(e)}")
            return []

    async def __get_nodes_from_additional_info_async(self, additional_info: str, data_mapper: DataMapping):
        """Extract entities from free text and look up matching nodes concurrently.

        Individual lookup failures are skipped rather than aborting the batch.
        """
        entities = extract_entities(additional_info)
        if not entities:
            return []
        tasks = [
            asyncio.create_task(
                data_mapper.get_top_result_by_text_async(entity.name, 3),
                name=f"query_entity_{entity.name}"
            )
            for entity in entities
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        top_results: list[Node] = []
        for result in results:
            if isinstance(result, Exception):
                continue  # best-effort: ignore this entity's failed lookup
            top_results.extend(result)
        return top_results

    def get_embedding_by_id_threadsafe(self, id):
        """Fetch one embedding row by index using a per-call SQLite connection.

        sqlite3 connections are not shared across threads, so each call opens
        (and always closes) its own connection.
        """
        conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH, check_same_thread=False)
        try:
            cursor = conn.cursor()
            try:
                cursor.execute("SELECT * FROM embeddings WHERE e_index = ?", (id,))
                return cursor.fetchone()
            finally:
                cursor.close()
        finally:
            # ROBUSTNESS: nested finally guarantees the connection is closed
            # even if cursor.close() raises (original closed both in one block).
            conn.close()

    async def retrieve_kg_text(self, request: KGQueryRequest):
        """Retrieve KG evidence for the top predicted label and generate
        Gemini explanation statements for the env-factor and symptom matches."""
        try:
            nodes = await self.get_nodes_from_text(request.additional_info)
            kg: KnowledgeGraphUtils = self.models["knowledge_graph"]
            env_task = asyncio.create_task(
                kg.get_disease_from_env_factors(request.crop_id, nodes)
            )
            symptom_task = asyncio.create_task(
                kg.get_disease_from_symptoms(request.crop_id, nodes)
            )
            env_results, symptom_results = await asyncio.gather(env_task, symptom_task)
            # NOTE(review): assumes request.context.predicted_labels is non-empty;
            # an empty list raises IndexError (re-raised by the except below).
            best_label = request.context.predicted_labels[0].label
            best_env_result = next((result for result in env_results if result["disease"].id == best_label), None)
            best_symptom_result = next((result for result in symptom_results if result["disease"].id == best_label), None)
            result1 = None
            result2 = None
            gemini = GeminiGenerator()
            if best_env_result:
                prompt1 = Template(GET_STATEMENT_FROM_ENV_FACTORS_KG).substitute(context=str(best_env_result))
                print(prompt1)
                result1 = gemini.generate(prompt1)
            if best_symptom_result:
                prompt2 = Template(GET_STATEMENT_FROM_DISEASE_KG).substitute(context=str(best_symptom_result))
                result2 = gemini.generate(prompt2)
            return {
                "env_results": env_results,
                "symptom_results": symptom_results,
                "env_statement": result1.text if result1 else None,
                "symptom_statement": result2.text if result2 else None
            }
        except Exception as e:
            print(e)
            raise e

    async def get_nodes_from_text(self, text: str):
        """Ask Gemini to map free text onto the static extracted_nodes list.

        Returns matched Node objects, or [] on any failure.
        """
        try:
            gemini = GeminiGenerator()
            # BUGFIX: join into a newline-separated string (the original
            # substituted the raw Python list repr into the prompt; compare
            # get_nodes_from_image, which joins correctly).
            node_list = "\n".join(
                f" + id:{node.id}, name:{node.name}, description:{node.properties.get('description', '')}"
                for node in extracted_nodes
            )
            prompt = Template(EXTRACT_NODES_FROM_TEXT_PROMPT).substitute(text=text, node_list=node_list)
            ids = gemini.generate(prompt)
            print(ids)
            ids = (json.loads(clean_text(ids.text)))["ids"]
            print(ids)
            nodes = []
            for node_id in ids:
                match = next((node for node in extracted_nodes if node.id == node_id), None)
                # BUGFIX: skip unknown ids instead of returning None entries.
                if match is not None:
                    nodes.append(match)
            return nodes
        except Exception as e:
            print(e)
            # BUGFIX: original fell through and implicitly returned None.
            return []

    # async def get_all_nodes(self):
    #     try:
    #         kg: KnowledgeGraphUtils = self.models["knowledge_graph"]
    #         list_nodes = await kg.get_all_nodes()
    #         return [dict(node[0], **{"label": "Symptom"}) for node in list_nodes]
    #     except Exception as e:
    #         print(e)
    #         return []
def get_predict_service(models=Depends(get_all_models)):
    """FastAPI dependency provider: build a PredictService over the shared models."""
    service = PredictService(models)
    return service