crop-diag-module / app /models /crop_clip.py
Sontranwakumo
feat: apply new model
e57a125
raw
history blame
6.96 kB
from typing import List
import torch.nn as nn
import torch
from torchvision import transforms
import clip
from PIL import Image
import os
from torchvision import models
from app.api.dto.kg_query import PredictedLabel
# Internal disease labels in the index order used by CLIPModule's classifier
# head (CLIPModule sets ``self.classes = CLASS_NAMES``). Grouped by crop:
# tomato, cassava, corn, rice.
CLASS_NAMES = [
    # tomato
    'benhVerticilliumWiltCaChua',
    'benhChayLaCaChua',
    'benhXoanLaCaChua',
    'benhDomLaCaChua',
    # cassava
    'benhNhenXanhSan',
    'benhKhamLaSan',
    'cassava healthy',
    'benhDomNau',
    # corn
    'boCanhCungHaiLaNgo',
    'corn healthy',
    'benhChayLaNgo',
    'benhRiSatNgo',
    'benhSocLaNgo',
    'benhDomLaNgo',
    # rice
    'benhBacLaLua',
    'benhDaoOnLua',
    'benhDomNauLuaNuoc',
]

# Checkpoint class name ("<crop>_<disease>") -> internal label.
# EfficientNetModule.load_model translates the checkpoint's ``class_names``
# through this dict, so the keys presumably must match the checkpoint text
# exactly (including spellings such as "verticulium"); do not "fix" them.
# Insertion order: cassava, tomato, corn, rice.
CLASS_NAME_MAPPING = {
    # cassava
    'cassava_mosaic': 'benhKhamLaSan',
    'cassava_healthy': 'cassava healthy',
    'cassava_green mite': 'benhNhenXanhSan',
    'cassava_brown spot': 'benhDomNau',
    # tomato
    'tomato_leaf blight': 'benhChayLaCaChua',
    'tomato_verticulium wilt': 'benhVerticilliumWiltCaChua',
    'tomato_leaf curl': 'benhXoanLaCaChua',
    'tomato_leaf spot': 'benhDomLaCaChua',
    # corn
    'corn_leaf beetle': 'boCanhCungHaiLaNgo',
    'corn_healthy': 'corn healthy',
    'corn_leaf blight': 'benhChayLaNgo',
    'corn_rust': 'benhRiSatNgo',
    'corn_streak virus': 'benhSocLaNgo',
    'corn_leaf spot': 'benhDomLaNgo',
    # rice
    'rice_brownspot': 'benhDomNauLuaNuoc',
    'rice_rice blast': 'benhDaoOnLua',
    'rice_bacterial blight': 'benhBacLaLua',
}

# Despite its name this is NOT alphabetically sorted: it is the internal
# labels in CLASS_NAME_MAPPING insertion order (cassava, tomato, corn, rice).
sorted_CLASS_NAMES = [*CLASS_NAME_MAPPING.values()]

# Crop id per class position, grouped cassava(san) x4, tomato(caChua) x4,
# corn(ngo) x6, rice(luaNuoc) x3 -- i.e. aligned with CLASS_NAME_MAPPING /
# sorted_CLASS_NAMES order.
# NOTE(review): this is NOT aligned with CLASS_NAMES (which starts with the
# tomato block); pairing the two by index mislabels the crop.
CROP_IDS = ['san'] * 4 + ['caChua'] * 4 + ['ngo'] * 6 + ['luaNuoc'] * 3

# Default checkpoint, stored next to this module in ``weights/``. The file
# name indicates an EfficientNet-B0 checkpoint.
WEIGHTS_PATH = os.path.join(
    os.path.dirname(__file__),
    'weights',
    'best_finetuned_efficientnet_b0.pth',
)
class CLIPFineTuner(nn.Module):
    """Linear classification head on top of a CLIP image encoder.

    Args:
        model: Backbone exposing ``encode_image(x)`` and ``visual.output_dim``
            (the width of the embeddings ``encode_image`` returns).
        num_classes: Number of logits produced by the head.
    """

    def __init__(self, model, num_classes):
        super().__init__()
        self.model = model
        embed_dim = model.visual.output_dim
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits for a batch of preprocessed images ``x``."""
        # The encoder runs without autograd, so no gradient ever reaches the
        # backbone; only ``classifier`` can be updated by training.
        with torch.no_grad():
            embeddings = self.model.encode_image(x)
        # encode_image may return a non-float32 dtype (e.g. half precision);
        # the nn.Linear head works in float32.
        return self.classifier(embeddings.float())
class CLIPModule:
    """Crop-disease classifier built on a fine-tuned CLIP ViT-B/32 encoder.

    Output index ``i`` of the classifier head is reported as
    ``CLASS_NAMES[i]`` (``self.classes``).
    """

    # Crop prefix of a CLASS_NAME_MAPPING key (e.g. "tomato" in
    # "tomato_leaf spot") -> crop id reported in predictions.
    _CROP_ID_BY_PREFIX = {
        'cassava': 'san',
        'tomato': 'caChua',
        'corn': 'ngo',
        'rice': 'luaNuoc',
    }

    def __init__(self, weights_path=None):
        """Load CLIP, attach the linear head and load the fine-tuned weights.

        Args:
            weights_path: Path to a saved ``CLIPFineTuner`` state dict.
                Defaults to the module-level ``WEIGHTS_PATH``.
        """
        model, preprocess = clip.load("ViT-B/32", jit=False)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if weights_path is None:
            weights_path = WEIGHTS_PATH
        self.model = CLIPFineTuner(model, len(CLASS_NAMES))
        self.model.load_state_dict(torch.load(weights_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        self.classes = CLASS_NAMES
        self.transform = preprocess
        # Derive each class's crop id from its label instead of indexing
        # CROP_IDS, whose order (cassava first) differs from CLASS_NAMES
        # (tomato first). Computed once here so a mismatch fails at startup.
        self._crop_ids = self._crop_ids_for(self.classes)

    @classmethod
    def _crop_ids_for(cls, labels):
        """Return the crop id of each label, looked up via CLASS_NAME_MAPPING.

        Raises:
            KeyError: If a label is not a value of CLASS_NAME_MAPPING.
        """
        crop_by_label = {
            label: cls._CROP_ID_BY_PREFIX[key.split('_', 1)[0]]
            for key, label in CLASS_NAME_MAPPING.items()
        }
        return [crop_by_label[label] for label in labels]

    def predict_image(self, image: Image.Image):
        """Classify ``image`` and return one prediction per class.

        Args:
            image: PIL image to classify (a file path is also accepted).

        Returns:
            List[PredictedLabel]: One entry per class with its crop id, label
            and softmax confidence, sorted by confidence, highest first.

        Raises:
            ValueError: If the image cannot be loaded or preprocessed.
        """
        logits = self.__predict(image)
        # Softmax over the class axis; [0] selects the single image in the batch.
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
        predictions: List[PredictedLabel] = [
            PredictedLabel(
                crop_id=self._crop_ids[idx],
                label=self.classes[idx],
                confidence=prob,
            )
            for idx, prob in enumerate(probabilities)
        ]
        # Sort by probability, highest first.
        predictions.sort(key=lambda p: p.confidence, reverse=True)
        return predictions

    def __predict(self, image_input):
        """Run the fine-tuned model on a single image.

        Args:
            image_input: Image file path (str) or a PIL.Image.

        Returns:
            torch.Tensor: Raw logits (no softmax), one row per image and one
            column per class.

        Raises:
            ValueError: If the input is invalid or cannot be preprocessed.
        """
        try:
            image = self.__handle_image(image_input)
            image_tensor = self.transform(image)
        except ValueError:
            raise
        except Exception as e:
            raise ValueError(f"Không thể xử lý ảnh đầu vào: {str(e)}") from e
        # Add a batch dimension when the transform returned a single CHW image.
        if image_tensor.dim() == 3:
            image_tensor = image_tensor.unsqueeze(0)
        image_tensor = image_tensor.to(self.device)
        with torch.no_grad():
            return self.model(image_tensor)

    def __handle_image(self, image_input):
        """Return ``image_input`` as an RGB PIL image.

        Raises:
            ValueError: If ``image_input`` is neither a path nor a PIL image.
        """
        if isinstance(image_input, str):
            # convert() returns a new, fully loaded image, so the file can be
            # closed immediately.
            with Image.open(image_input) as opened:
                return opened.convert('RGB')
        if isinstance(image_input, Image.Image):
            # Normalised to RGB for consistency with file inputs.
            if image_input.mode != 'RGB':
                return image_input.convert('RGB')
            return image_input
        raise ValueError("Invalid image input")
class EfficientNetModule:
    """Crop-disease classifier backed by a fine-tuned torchvision EfficientNet-B0.

    Output index ``i`` is reported as ``self.classes[i]``: the checkpoint's
    ``class_names`` translated through ``CLASS_NAME_MAPPING``.
    """

    _NUM_CLASSES = 17

    # Crop prefix of a CLASS_NAME_MAPPING key (e.g. "rice" in
    # "rice_rice blast") -> crop id reported in predictions.
    _CROP_ID_BY_PREFIX = {
        'cassava': 'san',
        'tomato': 'caChua',
        'corn': 'ngo',
        'rice': 'luaNuoc',
    }

    def load_model(self, model_path, num_classes, device="cpu"):
        """Build an EfficientNet-B0 and load a training checkpoint into it.

        Args:
            model_path: Path to a checkpoint dict holding ``model_state_dict``
                and, optionally, ``model_name`` and ``class_names``.
            num_classes: Size of the classifier head; must match the checkpoint.
            device: Device the model is built on and the checkpoint mapped to.

        Returns:
            tuple: ``(model, class_names)`` -- the model in eval mode and the
            checkpoint's class names mapped through ``CLASS_NAME_MAPPING``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            KeyError: If a checkpoint class name is missing from CLASS_NAME_MAPPING.
            ValueError: If the checkpoint does not list exactly ``num_classes``
                class names (otherwise predictions could not be labelled).
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        model = models.efficientnet_b0(num_classes=num_classes).to(device)
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model_name = checkpoint.get('model_name', 'efficientnet_b0')
        raw_class_names = checkpoint.get('class_names', [])
        print(f"✅ Model loaded from {model_path}")
        print(f"📊 Model name: {model_name}")
        print(f"📊 Class names: {raw_class_names}")
        class_names = [CLASS_NAME_MAPPING[name] for name in raw_class_names]
        print(f"📊 mapped Class names: {class_names}")
        # Fail at load time rather than with an IndexError on the first request.
        if len(class_names) != num_classes:
            raise ValueError(
                f"Checkpoint {model_path} lists {len(class_names)} class names, "
                f"expected {num_classes}"
            )
        model.eval()
        return model, class_names

    def __init__(self, weights_path=None, device="cpu"):
        """Load the checkpoint and set up ImageNet-style preprocessing.

        Args:
            weights_path: Checkpoint path; defaults to ``WEIGHTS_PATH``.
            device: Device for the model and input tensors (default CPU).
        """
        self.device = torch.device(device)
        if weights_path is None:
            weights_path = WEIGHTS_PATH
        self.model, self.classes = self.load_model(
            weights_path, self._NUM_CLASSES, self.device
        )
        # Derive crop ids from the labels themselves so they stay correct
        # whatever class order the checkpoint stores.
        self._crop_ids = self._crop_ids_for(self.classes)
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    @classmethod
    def _crop_ids_for(cls, labels):
        """Return the crop id of each label, looked up via CLASS_NAME_MAPPING.

        Raises:
            KeyError: If a label is not a value of CLASS_NAME_MAPPING.
        """
        crop_by_label = {
            label: cls._CROP_ID_BY_PREFIX[key.split('_', 1)[0]]
            for key, label in CLASS_NAME_MAPPING.items()
        }
        return [crop_by_label[label] for label in labels]

    def __handle_image(self, image_input):
        """Return ``image_input`` as an RGB PIL image.

        Raises:
            ValueError: If ``image_input`` is neither a path nor a PIL image.
        """
        if isinstance(image_input, str):
            with Image.open(image_input) as opened:
                return opened.convert('RGB')
        if isinstance(image_input, Image.Image):
            # RGBA / grayscale inputs would yield 4 or 1 channels, which the
            # 3-channel Normalize above (and the model) cannot accept.
            if image_input.mode != 'RGB':
                return image_input.convert('RGB')
            return image_input
        raise ValueError("Invalid image input")

    def __predict(self, image_input):
        """Return raw logits (no softmax) for a single image, batch size 1."""
        image = self.__handle_image(image_input)
        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            return self.model(batch)

    def predict_image(self, image: Image.Image):
        """Classify ``image`` and return one prediction per class.

        Args:
            image: PIL image to classify (a file path is also accepted).

        Returns:
            List[PredictedLabel]: One entry per class with its crop id, label
            and softmax confidence, sorted by confidence, highest first.
        """
        logits = self.__predict(image)
        # Softmax over the class axis; [0] selects the single image in the batch.
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
        predictions: List[PredictedLabel] = [
            PredictedLabel(
                crop_id=self._crop_ids[idx],
                label=self.classes[idx],
                confidence=prob,
            )
            for idx, prob in enumerate(probabilities)
        ]
        predictions.sort(key=lambda p: p.confidence, reverse=True)
        return predictions