Spaces:
Runtime error
Runtime error
from typing import Tuple | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import cv2 | |
import torch | |
from icevision.models.checkpoint import model_from_checkpoint | |
from classifier import CustomViT | |
def plot_img_no_mask(image: np.ndarray, boxes: torch.Tensor, labels,
                     output_path: str = "img.png"):
    """
    Draw labelled bounding boxes on an image and save the figure to disk.

    Args:
        image (np.ndarray): Image to annotate. Anything ``np.array`` can
            convert is accepted. It is copied once, so the caller's image
            is not modified.
        boxes (torch.Tensor): One ``[x1, y1, x2, y2]`` row per box; values
            are truncated to integers before drawing. A NumPy array is also
            accepted.
        labels: Per-box class ids (keys 0-6 of the tables below), indexed in
            the same order as ``boxes``. Python ints, NumPy integers and 0-d
            tensors all work because each label is passed through ``int()``.
        output_path (str): File the annotated figure is written to.
            Defaults to ``"img.png"``.

    Raises:
        KeyError: If a label is not one of the known class ids.
        IndexError: If ``labels`` has fewer entries than ``boxes``.
    """
    colors = {
        0: (255, 255, 0),
        1: (255, 0, 0),
        2: (0, 0, 255),
        3: (0, 128, 0),
        4: (255, 165, 0),
        5: (230, 230, 250),
        6: (192, 192, 192),
    }
    texts = {
        0: 'plastic',
        1: 'dangerous',
        2: 'carton',
        3: 'glass',
        4: 'organic',
        5: 'rest',
        6: 'other',
    }

    if isinstance(boxes, torch.Tensor):
        boxes = boxes.detach().cpu().numpy()
    boxes = np.asarray(boxes).astype(np.int32)

    # Without this copy cv2.rectangle raises an error (it needs a writable
    # array). Copy once up front instead of once per box.
    canvas = np.array(image).copy()

    for i, box in enumerate(boxes):
        # int() makes dict lookup work for NumPy ints and 0-d tensors too;
        # a tensor would otherwise hash by identity and miss every key.
        label = int(labels[i])
        color = colors[label]
        # cv2 point arguments are passed as plain Python ints.
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, thickness=5)
        cv2.putText(canvas, texts[label], (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 4, thickness=5, color=color)

    # Show image
    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.axis('off')
    ax.imshow(canvas)
    fig.savefig(output_path, bbox_inches='tight')
    # Release the figure; otherwise every call leaves one open in pyplot.
    plt.close(fig)
def get_models(
    detection_ckpt: str,
    classifier_ckpt: str
) -> Tuple[torch.nn.Module, torch.nn.Module]:
    """
    Restore the detection and classification models from checkpoints.

    Both models are loaded with ``map_location='cpu'`` and switched to
    evaluation mode before being returned.

    Args:
        detection_ckpt (str): Checkpoint for the icevision
            ``ross.efficientdet`` model (``d0`` backbone, 512 px input,
            single ``'Waste'`` class).
        classifier_ckpt (str): Checkpoint holding a state dict for
            ``CustomViT`` with 7 target classes.

    Returns:
        tuple: ``(detection_model, classifier_model)``, both
        ``torch.nn.Module`` instances in eval mode.
    """
    device = 'cpu'

    print('Loading the detection model')
    detector_options = dict(
        model_name='ross.efficientdet',
        backbone_name='d0',
        img_size=512,
        classes=['Waste'],
        # NOTE(review): presumably strips a leading "model." prefix from the
        # checkpoint's parameter names so they match the bare model.
        revise_keys=[(r'^model\.', '')],
        map_location=device,
    )
    loaded = model_from_checkpoint(detection_ckpt, **detector_options)
    detector = loaded['model']
    detector.eval()

    print('Loading the classifier model')
    vit = CustomViT(target_size=7, pretrained=False)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    weights = torch.load(classifier_ckpt, map_location=device)
    vit.load_state_dict(weights)
    vit.eval()

    return detector, vit