Spaces:
Sleeping
Sleeping
from typing import List | |
import torch.nn as nn | |
import torch | |
from torchvision import transforms | |
import clip | |
from PIL import Image | |
import os | |
from torchvision import models | |
from app.api.dto.kg_query import PredictedLabel | |
# Label order used by CLIPModule: output index i of the CLIP head is reported
# as CLASS_NAMES[i].
CLASS_NAMES = [
    'benhVerticilliumWiltCaChua', 'benhChayLaCaChua', 'benhXoanLaCaChua', 'benhDomLaCaChua',
    'benhNhenXanhSan', 'benhKhamLaSan', 'cassava healthy', 'benhDomNau',
    'boCanhCungHaiLaNgo', 'corn healthy', 'benhChayLaNgo', 'benhRiSatNgo', 'benhSocLaNgo',
    'benhDomLaNgo', 'benhBacLaLua', 'benhDaoOnLua', 'benhDomNauLuaNuoc',
]

# Training-set class names (as read from the EfficientNet checkpoint's
# 'class_names' entry) -> labels reported by the app.
CLASS_NAME_MAPPING = {
    'cassava_mosaic': 'benhKhamLaSan',
    'cassava_healthy': 'cassava healthy',
    'cassava_green mite': 'benhNhenXanhSan',
    'cassava_brown spot': 'benhDomNau',
    'tomato_leaf blight': 'benhChayLaCaChua',
    'tomato_verticulium wilt': 'benhVerticilliumWiltCaChua',
    'tomato_leaf curl': 'benhXoanLaCaChua',
    'tomato_leaf spot': 'benhDomLaCaChua',
    'corn_leaf beetle': 'boCanhCungHaiLaNgo',
    'corn_healthy': 'corn healthy',
    'corn_leaf blight': 'benhChayLaNgo',
    'corn_rust': 'benhRiSatNgo',
    'corn_streak virus': 'benhSocLaNgo',
    'corn_leaf spot': 'benhDomLaNgo',
    'rice_brownspot': 'benhDomNauLuaNuoc',
    'rice_rice blast': 'benhDaoOnLua',
    'rice_bacterial blight': 'benhBacLaLua'
}

# NOTE: despite its name this list is NOT sorted; it is CLASS_NAME_MAPPING's
# values in insertion order (cassava, tomato, corn, rice).
sorted_CLASS_NAMES = list(CLASS_NAME_MAPPING.values())

# Crop id for each entry of sorted_CLASS_NAMES, position by position
# (san = cassava, caChua = tomato, ngo = corn, luaNuoc = rice).
# This list is NOT aligned with CLASS_NAMES; use CROP_ID_BY_LABEL to look up
# the crop of a label.
CROP_IDS = ['san', 'san', 'san', 'san', 'caChua', 'caChua', 'caChua', 'caChua',
            'ngo', 'ngo', 'ngo', 'ngo', 'ngo', 'ngo', 'luaNuoc', 'luaNuoc', 'luaNuoc']

# Label -> crop id lookup that does not depend on the order of any list.
CROP_ID_BY_LABEL = dict(zip(sorted_CLASS_NAMES, CROP_IDS))

# Fail fast at import time if the tables above drift out of sync; otherwise
# crop ids would silently be attached to the wrong labels.
if len(CROP_IDS) != len(sorted_CLASS_NAMES) or set(CLASS_NAMES) != set(sorted_CLASS_NAMES):
    raise RuntimeError("CLASS_NAMES, CLASS_NAME_MAPPING and CROP_IDS are out of sync")

WEIGHTS_PATH = os.path.join(os.path.dirname(__file__), 'weights', 'best_finetuned_efficientnet_b0.pth')
class CLIPFineTuner(nn.Module):
    """Linear classification head on top of a frozen CLIP image encoder.

    The wrapped CLIP model is used purely as a feature extractor: its
    ``encode_image`` call runs under ``torch.no_grad()``, so during training
    only ``self.classifier`` receives gradients.

    Args:
        model: A CLIP model exposing ``visual.output_dim`` and ``encode_image``.
        num_classes: Number of output classes of the linear head.
    """

    def __init__(self, model, num_classes):
        super().__init__()
        self.model = model
        embedding_dim = model.visual.output_dim
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Return unnormalized class scores (logits) for the image batch *x*."""
        with torch.no_grad():
            # Cast the encoder features to float32 before they reach the head.
            embeddings = self.model.encode_image(x).float()
        logits = self.classifier(embeddings)
        return logits
class CLIPModule:
    """Plant-disease classifier built on a fine-tuned CLIP ViT-B/32 backbone.

    Wraps a ``CLIPFineTuner`` with a 17-way head and exposes
    :meth:`predict_image`, which returns one ``PredictedLabel`` per class,
    sorted by descending confidence.
    """

    def __init__(self):
        model, preprocess = clip.load("ViT-B/32", jit=False)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = CLIPFineTuner(model, 17)
        # NOTE(review): WEIGHTS_PATH points at 'best_finetuned_efficientnet_b0.pth',
        # which EfficientNetModule reads as a checkpoint dict. Confirm this file
        # really holds a CLIPFineTuner state dict before using this class.
        self.model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        self.classes = CLASS_NAMES
        # Preprocessing transform returned by clip.load for this backbone.
        self.transform = preprocess

    def predict_image(self, image: Image.Image):
        """Classify *image* and return every label ranked by confidence.

        Args:
            image: A PIL image (a file-path string is also accepted).

        Returns:
            List[PredictedLabel]: one entry per class, sorted by descending
            softmax probability.

        Raises:
            ValueError: if the image cannot be read or preprocessed, or if the
                number of model outputs differs from the number of classes.
        """
        output = self.__predict(image)
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]
        if len(probabilities) != len(self.classes):
            raise ValueError(
                f"Model produced {len(probabilities)} scores for {len(self.classes)} classes"
            )
        # Look the crop up by label. CROP_IDS is ordered like sorted_CLASS_NAMES,
        # not like CLASS_NAMES, so indexing it by class position mislabels crops.
        crop_by_label = dict(zip(sorted_CLASS_NAMES, CROP_IDS))
        predictions: List[PredictedLabel] = [
            PredictedLabel(
                crop_id=crop_by_label[label],
                label=label,
                confidence=float(prob),
            )
            for label, prob in zip(self.classes, probabilities)
        ]
        # Sort by probability, highest first.
        predictions.sort(key=lambda p: p.confidence, reverse=True)
        return predictions

    def __predict(self, image_input):
        """Run the model on one image and return raw logits.

        Args:
            image_input: Path to an image file (str) or a ``PIL.Image.Image``.

        Returns:
            torch.Tensor: logits with a leading batch dimension (batch of 1 for
            a single image); no softmax applied.

        Raises:
            ValueError: if the input type is unsupported or the image cannot be
                opened/preprocessed.
        """
        try:
            image = self.__handle_image(image_input)
            image_tensor = self.transform(image)
        except ValueError:
            raise
        except Exception as e:
            raise ValueError(f"Không thể xử lý ảnh đầu vào: {str(e)}") from e
        if image_tensor.dim() == 3:
            # Add the batch dimension the model expects.
            image_tensor = image_tensor.unsqueeze(0)
        image_tensor = image_tensor.to(self.device)
        with torch.no_grad():
            output = self.model(image_tensor)
        return output

    def __handle_image(self, image_input):
        """Return *image_input* as an RGB ``PIL.Image.Image``.

        Raises:
            ValueError: if *image_input* is neither a str path nor a PIL image.
        """
        if isinstance(image_input, str):
            # Decode inside a context manager so the file handle is closed promptly.
            with Image.open(image_input) as img:
                return img.convert('RGB')
        if isinstance(image_input, Image.Image):
            return image_input if image_input.mode == 'RGB' else image_input.convert('RGB')
        raise ValueError("Invalid image input")
class EfficientNetModule:
    """Plant-disease classifier backed by a fine-tuned torchvision EfficientNet-B0.

    The checkpoint at ``WEIGHTS_PATH`` supplies both the weights and the class
    order. Checkpoint class names are translated to app labels through
    ``CLASS_NAME_MAPPING``. Inference runs on the CPU.
    """

    def load_model(self, model_path, num_classes, device= "cpu"):
        """Build EfficientNet-B0 and load a training checkpoint into it.

        Args:
            model_path: Path to a checkpoint dict holding ``'model_state_dict'``
                and optionally ``'model_name'`` and ``'class_names'``.
            num_classes: Number of outputs of the classifier head.
            device: Device the model is built on and the checkpoint is mapped to.

        Returns:
            tuple: ``(model, class_names)``. *model* is in eval mode, and
            *class_names* are the mapped app labels in checkpoint order
            (output index i corresponds to class_names[i]).

        Raises:
            FileNotFoundError: if *model_path* does not exist.
            KeyError: if a checkpoint class name is not in CLASS_NAME_MAPPING.
            ValueError: if the number of class names differs from *num_classes*.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        model = models.efficientnet_b0(num_classes=num_classes).to(device)
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model_name = checkpoint.get('model_name', 'efficientnet_b0')
        class_names = checkpoint.get('class_names', [])
        print(f"✅ Model loaded from {model_path}")
        print(f"📊 Model name: {model_name}")
        print(f"📊 Class names: {class_names}")
        class_names = [CLASS_NAME_MAPPING[name] for name in class_names]
        print(f"📊 mapped Class names: {class_names}")
        # Without one name per output, predictions cannot be labelled. Fail at
        # load time rather than with an IndexError on the first request.
        if len(class_names) != num_classes:
            raise ValueError(
                f"Checkpoint lists {len(class_names)} class names but the model has {num_classes} outputs"
            )
        model.eval()
        return model, class_names

    def __init__(self):
        # Inference is pinned to the CPU; the same device is used for the
        # checkpoint and for input tensors.
        self.device = torch.device("cpu")
        self.model, self.classes = self.load_model(WEIGHTS_PATH, 17, device=self.device)
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # ImageNet channel statistics; presumably the same normalization
            # used at training time. TODO confirm against the training script.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __handle_image(self, image_input):
        """Return *image_input* as an RGB ``PIL.Image.Image``.

        Non-RGB PIL images (e.g. RGBA or grayscale uploads) are converted,
        because the 3-channel Normalize step above fails on other channel counts.

        Raises:
            ValueError: if *image_input* is neither a str path nor a PIL image.
        """
        if isinstance(image_input, str):
            # Decode inside a context manager so the file handle is closed promptly.
            with Image.open(image_input) as img:
                return img.convert('RGB')
        if isinstance(image_input, Image.Image):
            return image_input if image_input.mode == 'RGB' else image_input.convert('RGB')
        raise ValueError("Invalid image input")

    def __predict(self, image_input):
        """Run the model on one image and return raw logits of shape (1, num_classes)."""
        image = self.__handle_image(image_input)
        # Add the batch dimension and move to the inference device.
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            output = self.model(image_tensor)
        return output

    def predict_image(self, image: Image.Image):
        """Classify *image* and return every label ranked by confidence.

        Args:
            image: A PIL image (a file-path string is also accepted).

        Returns:
            List[PredictedLabel]: one entry per class, sorted by descending
            softmax probability.

        Raises:
            ValueError: if the input is invalid, or if the number of model
                outputs differs from the number of classes.
        """
        output = self.__predict(image)
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]
        if len(probabilities) != len(self.classes):
            raise ValueError(
                f"Model produced {len(probabilities)} scores for {len(self.classes)} classes"
            )
        # Look the crop up by label. self.classes follows the checkpoint's order,
        # while CROP_IDS is ordered like sorted_CLASS_NAMES, so indexing CROP_IDS
        # by output position can attach the wrong crop.
        crop_by_label = dict(zip(sorted_CLASS_NAMES, CROP_IDS))
        predictions: List[PredictedLabel] = [
            PredictedLabel(
                crop_id=crop_by_label[label],
                label=label,
                confidence=float(prob),
            )
            for label, prob in zip(self.classes, probabilities)
        ]
        # Sort by probability, highest first.
        predictions.sort(key=lambda p: p.confidence, reverse=True)
        return predictions