"""Prediction service: image classification plus knowledge-graph (KG) retrieval.

Combines an EfficientNet image classifier, Gemini-based entity extraction and a
knowledge graph to rank candidate crop diseases.
"""
import asyncio
import io
import json
import sqlite3
from string import Template

from fastapi import Depends, UploadFile
from PIL import Image

from app.api.dto.kg_query import KGQueryRequest, QueryContext, PredictedLabel
from app.core.dependencies import get_all_models
from app.core.type import Node
from app.models.crop_clip import EfficientNetModule
from app.models.gemini_caller import GeminiGenerator
from app.models.knowledge_graph import KnowledgeGraphUtils
from app.utils.constant import EXTRACTED_NODES
from app.utils.data_mapping import VECTOR_EMBEDDINGS_DB_PATH, DataMapping
from app.utils.extract_entity import clean_text, extract_entities
from app.utils.prompt import (
    EXTRACT_NODES_FROM_IMAGE_PROMPT,
    EXTRACT_NODES_FROM_TEXT_PROMPT,
    GET_STATEMENT_FROM_DISEASE_KG,
    GET_STATEMENT_FROM_ENV_FACTORS_KG,
)


class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that serializes Pydantic models (v2 ``model_dump`` or v1 ``dict``)."""

    def default(self, obj):
        if hasattr(obj, 'model_dump'):  # Pydantic v2 BaseModel
            return obj.model_dump()
        if hasattr(obj, 'dict'):  # Pydantic v1 BaseModel
            return obj.dict()
        # Lists/tuples never reach default(): JSONEncoder encodes them natively and
        # only calls default() for their non-native elements.
        return super().default(obj)


def convert_to_json_serializable(obj):
    """Recursively convert ``obj`` (e.g. structures holding Node models) to JSON-safe data.

    Pydantic models are dumped to dicts, lists and tuples become lists, dicts are
    converted value by value, and values that ``json.dumps`` rejects fall back to
    ``str(obj)`` with a printed warning. Unexpected errors are printed and
    ``str(obj)`` is returned instead of propagating.
    """
    try:
        if hasattr(obj, 'model_dump'):  # Pydantic v2 BaseModel
            return obj.model_dump()
        if hasattr(obj, 'dict'):  # Pydantic v1 BaseModel
            return obj.dict()
        if isinstance(obj, (list, tuple)):
            return [convert_to_json_serializable(item) for item in obj]
        if isinstance(obj, dict):
            return {key: convert_to_json_serializable(value) for key, value in obj.items()}
        if obj is None:
            return None
        try:
            json.dumps(obj)  # Probe: is it already JSON serializable?
            return obj
        except (TypeError, ValueError):
            # Not serializable: fall back to its string form.
            print(f"Warning: Converting non-serializable object {type(obj)} to string: {obj}")
            return str(obj)
    except Exception as e:
        print(f"Error in convert_to_json_serializable for object {type(obj)}: {e}")
        return str(obj)


# Candidate nodes offered to Gemini by ``PredictService.get_nodes_from_text``.
extracted_nodes = [
    Node(
        id=node['id'],
        label=node['label'],
        name=node['name'],
        properties={'description': node['description']},
        score=None,
    )
    for node in EXTRACTED_NODES
]


class PredictService:
    """Facade over the loaded models (classifier, KG, data mapper) used by the API."""

    def __init__(self, models):
        # Mapping of model name -> model instance; keys read below are
        # "efficientnet_model", "knowledge_graph" and "data_mapper".
        self.models = models

    async def predict_image(self, image: UploadFile):
        """Classify an uploaded image with the EfficientNet model.

        Returns whatever ``EfficientNetModule.predict_image`` returns for the RGB image.
        """
        efficientnet_model: EfficientNetModule = self.models["efficientnet_model"]
        image_content = image.file.read()
        # ``Image.io`` is not part of PIL's public API; use the stdlib ``io`` module.
        pil_image = Image.open(io.BytesIO(image_content)).convert('RGB')
        return efficientnet_model.predict_image(pil_image)

    async def retrieve_kg(self, request: KGQueryRequest):
        """Enrich the query context with KG-derived diseases and fuse final labels.

        Ensures ``request.context`` exists and copies ``request.crop_id`` into it.
        Extends ``context.nodes`` with nodes matched from ``request.additional_info``
        and gives every node without a score a default of 0.9. It then queries the KG
        concurrently for diseases linked to those nodes, via environmental factors and
        via symptoms. The disease nodes are appended to ``context.nodes``, which is
        sorted by descending score. When the context carries ``predicted_labels``,
        ``context.final_labels`` is computed.

        ``request.context`` is mutated in place.

        Returns:
            dict with keys ``"context"``, ``"env_results"`` and ``"symptom_results"``.
        """
        try:
            kg: KnowledgeGraphUtils = self.models["knowledge_graph"]
            if not request.context:
                request.context = QueryContext()
            if request.crop_id:
                request.context.crop_id = request.crop_id

            # Normalize unconditionally: the scoring loop below iterates nodes even
            # when no additional_info was supplied.
            if request.context.nodes is None:
                request.context.nodes = []
            if request.additional_info:
                additional_nodes = await self.__get_nodes_from_additional_info_async(
                    request.additional_info, self.models["data_mapper"]
                )
                request.context.nodes = request.context.nodes + additional_nodes

            for node in request.context.nodes:
                if node.score is None:
                    node.score = 0.9

            env_task = asyncio.create_task(
                kg.get_disease_from_env_factors(request.context.crop_id, request.context.nodes)
            )
            symptom_task = asyncio.create_task(
                kg.get_disease_from_symptoms(request.context.crop_id, request.context.nodes)
            )
            env_results, symptom_results = await asyncio.gather(env_task, symptom_task)

            context = request.context
            context.nodes.extend(env_result["disease"] for env_result in env_results)
            context.nodes.extend(symptom_result["disease"] for symptom_result in symptom_results)
            print(context.nodes)
            # Treat a missing score as 0 so the sort cannot fail on None.
            context.nodes.sort(
                key=lambda node: node.score if node.score is not None else 0.0,
                reverse=True,
            )

            # Fuse classifier predictions with KG evidence into final_labels.
            if context.predicted_labels:
                print("Got predicted labels")
                context.final_labels = self.calculate_final_labels(
                    context.predicted_labels, env_results, symptom_results, context.crop_id
                )
            return {
                "context": context,
                "env_results": env_results,
                "symptom_results": symptom_results,
            }
        except Exception as e:
            print(e)
            raise

    def calculate_final_labels(
        self,
        predicted_labels,
        env_result,
        symptom_result,
        crop_id,
        *,
        env_weight=0.3,
        symptom_weight=0.2,
        min_confidence=0.1,
    ):
        """Fuse classifier predictions with KG evidence into ranked final labels.

        Scores are accumulated per ``"{crop_id}_{label}"`` key, in this order:

        1. ``predicted_labels`` (classifier output): each ``confidence`` is added as-is.
        2. ``symptom_result``: each disease's ``score`` is blended in as
           ``total += score * symptom_weight * (1 - total)``.
        3. ``env_result``: the same blend, using ``env_weight``.

        The blend moves a score toward 1.0 by a fraction of the remaining headroom.
        Headroom is clamped at 0, so KG evidence never lowers a score that is already
        at or above 1.0.

        Args:
            predicted_labels: items with ``crop_id``, ``label`` and ``confidence``.
            env_result: list of dicts whose ``"disease"`` entry has ``id`` and ``score``.
            symptom_result: same shape as ``env_result``.
            crop_id: crop id used to key the KG-derived diseases.
            env_weight: weight of environmental-factor evidence (default 0.3).
            symptom_weight: weight of symptom evidence (default 0.2).
            min_confidence: labels at or below this confidence are dropped (default 0.1).

        Returns:
            list[PredictedLabel] sorted by descending confidence (capped at 1.0),
            keeping only labels with confidence > ``min_confidence``.
        """
        # Accumulated score per disease/crop combination.
        label_scores = {}

        def _slot(key, slot_crop_id, label):
            """Return the accumulator for ``key``, creating it at 0 if missing."""
            return label_scores.setdefault(
                key, {"crop_id": slot_crop_id, "label": label, "total_score": 0}
            )

        def _blend(kg_results, weight, source):
            """Blend KG disease scores into the accumulators (see docstring formula)."""
            for item in kg_results or []:
                disease = item.get("disease")
                score = getattr(disease, "score", None) if disease else None
                if score is None:
                    continue  # Nothing to blend for a scoreless disease node.
                key = f"{crop_id}_{disease.id}"
                print(f"{source} key: {key} score: {score}")
                slot = _slot(key, crop_id, disease.id)
                headroom = max(0.0, 1 - slot["total_score"])
                slot["total_score"] += score * weight * headroom

        # 1. Scores from the classifier (CLIP/EfficientNet) model.
        for label in predicted_labels:
            key = f"{label.crop_id}_{label.label}"
            print(f"CLIP key: {key} score: {label.confidence}")
            _slot(key, label.crop_id, label.label)["total_score"] += label.confidence

        # 2. Scores from symptoms, then 3. from environmental factors (order matters:
        # each blend depends on the total accumulated so far).
        _blend(symptom_result, symptom_weight, "Symptom")
        _blend(env_result, env_weight, "Env")

        final_labels = [
            PredictedLabel(
                crop_id=data["crop_id"],
                label=data["label"],
                confidence=min(data["total_score"], 1.0),  # Never exceed 1.0.
            )
            for data in label_scores.values()
        ]
        # Sort by descending confidence, then drop low-confidence labels.
        final_labels.sort(key=lambda x: x.confidence, reverse=True)
        print(final_labels)
        return [label for label in final_labels if label.confidence > min_confidence]

    # TODO:
    async def get_nodes_from_image(self, image: UploadFile):
        """Ask Gemini which known symptoms are visible in ``image``.

        The prompt lists every "Symptom" node from the data mapper. The model is
        expected to reply with JSON containing an ``"ids"`` list. Ids that do not match
        a known symptom are dropped.

        Returns:
            list of symptom nodes, or ``[]`` on any error (the error is printed).
        """
        try:
            gemini = GeminiGenerator()
            symptoms = self.models["data_mapper"].get_embedding_by_label("Symptom")
            symptom_list = "\n".join(f"- id:{node.id} - name:{node.name}" for node in symptoms)
            prompt = Template(EXTRACT_NODES_FROM_IMAGE_PROMPT).substitute(symptom_list=symptom_list)

            image_content = image.file.read()
            pil_image = Image.open(io.BytesIO(image_content)).convert('RGB')
            response = gemini.generate(prompt, image=pil_image)
            ids = json.loads(clean_text(response.text))["ids"]
            print(ids)

            # First occurrence wins, matching the original linear search.
            symptoms_by_id = {}
            for symptom in symptoms:
                symptoms_by_id.setdefault(symptom.id, symptom)
            return [symptoms_by_id[node_id] for node_id in ids if node_id in symptoms_by_id]
        except Exception as e:
            print(f"Error while extract knowledge entities from image: {str(e)}")
            return []

    async def __get_nodes_from_additional_info_async(self, additional_info: str, data_mapper: DataMapping):
        """Map free-text ``additional_info`` to KG nodes.

        Entities are extracted from the text, and each is looked up concurrently via
        ``data_mapper.get_top_result_by_text_async`` (top 3 results). Lookups that fail
        are reported and skipped.

        Returns:
            list[Node]: all matched nodes, in entity order.
        """
        entities = extract_entities(additional_info)
        if not entities:
            return []

        tasks = [
            asyncio.create_task(
                data_mapper.get_top_result_by_text_async(entity.name, 3),
                name=f"query_entity_{entity.name}",
            )
            for entity in entities
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        top_results: list[Node] = []
        for task, result in zip(tasks, results):
            if isinstance(result, BaseException):
                # Best-effort: one failed lookup should not sink the whole query.
                print(f"Warning: {task.get_name()} failed: {result}")
                continue
            top_results.extend(result)
        return top_results

    def get_embedding_by_id_threadsafe(self, id):
        """Fetch the ``embeddings`` row whose ``e_index`` equals ``id``.

        Each call opens and closes its own SQLite connection, so no connection is
        shared between threads.

        Returns:
            The row tuple, or ``None`` if no row matches.
        """
        conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH, check_same_thread=False)
        try:
            cursor = conn.cursor()
            try:
                # Parameterized query: ``id`` is never interpolated into SQL.
                cursor.execute("SELECT * FROM embeddings WHERE e_index = ?", (id,))
                return cursor.fetchone()
            finally:
                cursor.close()
        finally:
            # Close the connection even if creating the cursor failed.
            conn.close()

    async def retrieve_kg_text(self, request: KGQueryRequest):
        """Query the KG from free text and let Gemini summarize the best match.

        Nodes are extracted from ``request.additional_info``, and the KG is queried
        concurrently for diseases via environmental factors and via symptoms. If the
        request context carries predicted labels, the KG results for the top label
        are turned into natural-language statements by Gemini.

        Returns:
            dict with ``"env_results"``, ``"symptom_results"``, ``"env_statement"``
            and ``"symptom_statement"``. The statements are ``None`` when there is
            no predicted label or no matching KG result.
        """
        try:
            nodes = await self.get_nodes_from_text(request.additional_info)
            kg: KnowledgeGraphUtils = self.models["knowledge_graph"]
            env_task = asyncio.create_task(kg.get_disease_from_env_factors(request.crop_id, nodes))
            symptom_task = asyncio.create_task(kg.get_disease_from_symptoms(request.crop_id, nodes))
            env_results, symptom_results = await asyncio.gather(env_task, symptom_task)

            # Guard against a missing context / empty predictions instead of raising.
            predicted_labels = request.context.predicted_labels if request.context else None
            best_env_result = None
            best_symptom_result = None
            if predicted_labels:
                best_label = predicted_labels[0].label
                best_env_result = next(
                    (result for result in env_results if result["disease"].id == best_label), None
                )
                best_symptom_result = next(
                    (result for result in symptom_results if result["disease"].id == best_label), None
                )

            prompt1 = None
            prompt2 = None
            result1 = None
            result2 = None
            if best_env_result:
                prompt1 = Template(GET_STATEMENT_FROM_ENV_FACTORS_KG).substitute(context=str(best_env_result))
            if best_symptom_result:
                prompt2 = Template(GET_STATEMENT_FROM_DISEASE_KG).substitute(context=str(best_symptom_result))

            gemini = GeminiGenerator()
            print(prompt1)
            # NOTE(review): gemini.generate appears to be synchronous and would block
            # the event loop; consider asyncio.to_thread once GeminiGenerator's
            # thread-safety is confirmed.
            if prompt1:
                result1 = gemini.generate(prompt1)
            if prompt2:
                result2 = gemini.generate(prompt2)
            return {
                "env_results": env_results,
                "symptom_results": symptom_results,
                "env_statement": result1.text if result1 else None,
                "symptom_statement": result2.text if result2 else None,
            }
        except Exception as e:
            print(e)
            raise

    async def get_nodes_from_text(self, text: str):
        """Ask Gemini which of the predefined ``extracted_nodes`` appear in ``text``.

        The candidate nodes are rendered one per line into the prompt. The model is
        expected to reply with JSON containing an ``"ids"`` list. Unknown ids are
        dropped.

        Returns:
            list of matching nodes, or ``[]`` on any error (the error is printed).
        """
        try:
            gemini = GeminiGenerator()
            # Join into a string: Template.substitute str()-s its values, so passing the
            # list would embed its Python repr in the prompt.
            node_list = "\n".join(
                f" + id:{node.id}, name:{node.name}, description:{node.properties.get('description', '')}"
                for node in extracted_nodes
            )
            prompt = Template(EXTRACT_NODES_FROM_TEXT_PROMPT).substitute(text=text, node_list=node_list)
            response = gemini.generate(prompt)
            print(response)
            ids = json.loads(clean_text(response.text))["ids"]
            print(ids)

            # First occurrence wins, matching the original linear search.
            nodes_by_id = {}
            for node in extracted_nodes:
                nodes_by_id.setdefault(node.id, node)
            return [nodes_by_id[node_id] for node_id in ids if node_id in nodes_by_id]
        except Exception as e:
            print(e)
            # Return an empty list rather than None so callers can pass it on safely.
            return []


def get_predict_service(models=Depends(get_all_models)):
    """FastAPI dependency that builds a ``PredictService`` over the loaded models."""
    return PredictService(models)