Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- cam.py +80 -0
- glaucoma.py +51 -0
cam.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import cv2
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from typing import List, Callable, Optional
|
| 8 |
+
from functools import partial
|
| 9 |
+
|
| 10 |
+
from pytorch_grad_cam import GradCAM
|
| 11 |
+
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
| 12 |
+
from pytorch_grad_cam.utils.image import show_cam_on_image
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
""" Model wrapper to return a tensor"""
|
| 16 |
+
class HuggingfaceToTensorModelWrapper(torch.nn.Module):
    """Adapt a Hugging Face model so that its forward pass returns a tensor.

    The wrapped model is expected to return an object with a ``logits``
    attribute; this wrapper returns that attribute directly, so callers
    that need a plain tensor output can use the model unchanged.
    """

    def __init__(self, model):
        """
        Args:
            model: callable module whose output exposes a ``logits`` attribute.
        """
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return only its ``logits``."""
        return self.model(x).logits
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ClassActivationMap(object):
    """Produce Grad-CAM overlays for a Hugging Face SwinV2 image classifier."""

    def __init__(self, model, processor):
        """
        Args:
            model: classification model exposing ``swinv2.layernorm`` and
                returning an output object with a ``logits`` attribute.
            processor: image processor; must expose ``size['height']`` /
                ``size['width']`` and be callable as
                ``processor(images=..., return_tensors="pt")``.
        """
        # The CAM library needs a model that returns a tensor, not an output object.
        self.model = HuggingfaceToTensorModelWrapper(model)
        # Activations and gradients are taken at the backbone's final layer norm.
        self.target_layer = [model.swinv2.layernorm]
        self.processor = processor

    def swinT_reshape_transform_huggingface(self, tensor, width, height):
        """Turn a token sequence into a channels-first 2D feature map.

        Args:
            tensor: 3-D tensor indexed as (batch, tokens, channels), where
                ``tokens == height * width``.
            width: feature-map width in tokens.
            height: feature-map height in tokens.

        Returns:
            Tensor of shape (batch, channels, height, width).
        """
        result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
        # (batch, height, width, channels) -> (batch, channels, height, width)
        return result.permute(0, 3, 1, 2)

    def run_grad_cam_on_image(self,
                              targets_for_gradcam: List[Callable],
                              reshape_transform: Optional[Callable],
                              input_tensor: torch.Tensor,
                              input_image: Image.Image,
                              method: Callable = GradCAM):
        """Compute one CAM overlay per target and stack them side by side.

        Args:
            targets_for_gradcam: one CAM target per overlay to produce.
            reshape_transform: callable forwarded to the CAM method to
                reshape the target layer's output, or ``None``.
            input_tensor: a single preprocessed image (no batch dimension).
            input_image: the image to draw the heat-map on; its spatial
                size must match the CAM output, and its values are treated
                as being in the 0-255 range.
            method: CAM class to instantiate (defaults to ``GradCAM``).

        Returns:
            The overlays concatenated horizontally with ``np.hstack``.
        """
        with method(model=self.model,
                    target_layers=self.target_layer,
                    reshape_transform=reshape_transform) as cam:
            # Replicate the tensor for each of the categories we want to
            # create Grad-CAM for.
            repeated_tensor = input_tensor[None, :].repeat(
                len(targets_for_gradcam), 1, 1, 1)

            # Keep the input on the same device as the model weights so a
            # model loaded on an accelerator does not hit a device mismatch.
            first_param = next(self.model.parameters(), None)
            if first_param is not None:
                repeated_tensor = repeated_tensor.to(first_param.device)

            batch_results = cam(input_tensor=repeated_tensor,
                                targets=targets_for_gradcam)

            # The image is scaled to [0, 1] before it is passed to show_cam_on_image.
            rgb_image = np.float32(input_image) / 255
            results = [
                show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True)
                for grayscale_cam in batch_results
            ]
            return np.hstack(results)

    def get_cam(self, image, category_id):
        """Build the Grad-CAM overlay of ``image`` for class ``category_id``.

        Args:
            image: image array accepted by ``PIL.Image.fromarray``.
            category_id: index of the class whose evidence is visualized.

        Returns:
            The overlay image at the processor's input resolution.
        """
        size = self.processor.size
        # PIL's resize takes (width, height); passing them swapped would
        # make the overlay image and the CAM disagree for non-square sizes.
        image = Image.fromarray(image).resize((size['width'], size['height']))

        pixel_values = self.processor(images=image, return_tensors="pt")['pixel_values']
        # Drop only the batch dimension (squeeze() would drop any size-1 dim).
        img_tensor = pixel_values[0]

        targets_for_gradcam = [ClassifierOutputTarget(category_id)]
        # NOTE(review): 32 is presumably the backbone's total spatial
        # downsampling factor -- confirm against the model configuration.
        reshape_transform = partial(self.swinT_reshape_transform_huggingface,
                                    width=img_tensor.shape[2] // 32,
                                    height=img_tensor.shape[1] // 32)

        return self.run_grad_cam_on_image(input_tensor=img_tensor,
                                          input_image=image,
                                          targets_for_gradcam=targets_for_gradcam,
                                          reshape_transform=reshape_transform)
|
glaucoma.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from transformers import AutoImageProcessor, Swinv2ForImageClassification
|
| 5 |
+
|
| 6 |
+
from lib.cam import ClassActivationMap
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class GlaucomaModel(object):
    """Glaucoma classifier bundled with a class-activation-map visualizer."""

    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 device=torch.device('cpu')):
        """
        Args:
            cls_model_path: model identifier or path passed to
                ``from_pretrained`` for both the processor and the model.
            device: torch device the classification model is moved to.
        """
        # Device on which the model runs (CPU or GPU).
        self.device = device
        # Image processor and classification model for glaucoma.
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = (
            Swinv2ForImageClassification.from_pretrained(cls_model_path)
            .to(device)
            .eval()
        )
        # Class activation map generator sharing the same model and processor.
        self.cam = ClassActivationMap(self.cls_model, self.cls_extractor)

        # Mapping from classification id to label, taken from the model config.
        self.id2label = self.cls_model.config.id2label

        # Number of classes the classifier distinguishes.
        self.num_diseases = len(self.id2label)

    def glaucoma_pred(self, image):
        """Classify a single image.

        Args:
            image: image array in RGB order.

        Returns:
            int: index of the highest-scoring class.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        # Use the value returned by .to() instead of relying on in-place mutation.
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.cls_model(**inputs).logits
        # Only one image is processed, so read the first row of the batch.
        return int(logits[0].argmax().item())

    def process(self, image):
        """Classify an image and render its class activation map.

        Args:
            image: image array in RGB order.

        Returns:
            tuple: ``(disease_idx, cam)`` where ``cam`` has been resized to
            the height and width of ``image``.
        """
        image_shape = image.shape[:2]
        disease_idx = self.glaucoma_pred(image)
        cam = self.cam.get_cam(image, disease_idx)
        # image_shape is (height, width); cv2.resize wants (width, height).
        cam = cv2.resize(cam, image_shape[::-1])

        return disease_idx, cam
|
| 51 |
+
|