Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import random | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from torchvision import transforms | |
| from transformers import SegformerForSemanticSegmentation | |
# examples
# Fetch the example images shown in the demo into the working directory.
# Each entry pairs a remote sub-folder with the file names stored under it;
# every file is saved locally under its own name, overwriting any old copy.
_EXAMPLES_URL = "https://huggingface.co/spaces/alkzar90/rock-glacier-segmentation/resolve/main/example_images"
_EXAMPLE_FILES = (
    ("buenos_resultados", ("073.png", "356.png", "599.png", "630.png", "673.png")),
    ("malos_resultados", ("019.png", "261.png", "524.png", "716.png", "898.png")),
)
for _folder, _names in _EXAMPLE_FILES:
    for _name in _names:
        os.system(f"wget -O {_name} {_EXAMPLES_URL}/{_folder}/{_name}")
# model-setting
# Directory holding the fine-tuned checkpoint loaded below.
MODEL_PATH = "./best_model_mixto/"
# Device the predicted labels are moved to in query_image.
device = torch.device("cpu")

# Input pipeline: Resize(128) followed by conversion to a tensor.
_preprocessing_steps = [
    transforms.Resize(128),
    transforms.ToTensor(),
]
preprocessor = transforms.Compose(_preprocessing_steps)

# Load the segmentation model and put it in evaluation mode.
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_PATH)
model.eval()
| # inference-functions | |
def upscale_logits(logit_outputs, size):
    """Resize logits to ``size`` using bilinear interpolation.

    Args:
        logit_outputs: Tensor forwarded to ``nn.functional.interpolate``
            (presumably (batch, classes, height, width) model logits --
            confirm against the caller).
        size: Target spatial size, forwarded unchanged as ``size=``.

    Returns:
        The interpolated tensor (``mode="bilinear"``, ``align_corners=False``).
    """
    resize = nn.functional.interpolate
    resized = resize(logit_outputs, size=size, mode="bilinear", align_corners=False)
    return resized
def visualize_instance_seg_mask(mask):
    """Colorize a label mask with one random RGB color per distinct label.

    Colors are drawn with the ``random`` module on every call, so the same
    mask gets different colors across calls unless the generator is seeded.

    Args:
        mask: Array of class labels; the demo passes a 2-D (H, W) array.

    Returns:
        Float array of shape ``mask.shape + (3,)`` with values scaled by
        1/255, where every pixel holds the color assigned to its label.
    """
    mask = np.asarray(mask)
    # ``inverse`` maps every pixel to the index of its label in ``labels``,
    # which lets a single fancy-indexing lookup replace the per-pixel loop.
    labels, inverse = np.unique(mask, return_inverse=True)

    # NOTE(review): the first channel is drawn from {0, 1} while the other
    # two use 0..255. Kept exactly as before; confirm whether 255 was meant.
    colors = [
        (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255))
        for _ in labels
    ]
    # Explicit shape so an empty mask (no labels) still yields a (0, 3) palette.
    palette = np.array(colors, dtype=np.float64).reshape(len(labels), 3)

    # The shape of ``inverse`` differs between NumPy versions; normalize it.
    image = palette[inverse.reshape(mask.shape)]
    return image / 255
def query_image(img):
    """Segment an image and return the colorized class mask.

    The image is preprocessed, passed through the model, and the logits are
    resized to the spatial size of the preprocessed input before taking the
    per-pixel argmax over the class dimension.

    Args:
        img: Image accepted by ``preprocessor`` (the demo supplies a PIL image).

    Returns:
        The array produced by ``visualize_instance_seg_mask`` for the
        predicted label map of the first (only) batch element.
    """
    inputs = preprocessor(img).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        preds = model(inputs)["logits"]
        # Resize to the input's (H, W). The earlier code passed
        # ``preds.shape[2]``, i.e. the logits' own height: that performed no
        # upscaling and, being a scalar, forced a square output that
        # distorted non-square images.
        target_size = tuple(inputs.shape[-2:])
        preds_upscale = upscale_logits(preds, target_size)
        predict_label = torch.argmax(preds_upscale, dim=1).to(device)
        result = predict_label[0, :, :].detach().cpu().numpy()
    return visualize_instance_seg_mask(result)
# demo
# Build and launch the Gradio interface around query_image.
# The display height is passed to the gr.Image constructor: the component
# ``.style()`` method used previously no longer exists in current Gradio
# releases and raised at startup.
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(type="pil", height=256)],
    outputs=[gr.Image(height=256)],
    title="Skyguard: segmentador de glaciares de roca 🛰️ +️ 🛡️ ️",
    description="Modelo de segmentación de imágenes para detectar glaciares de roca.<br> Se entrenó un modelo [nvidia/SegFormer](https://huggingface.co/nvidia/mit-b0) con _fine-tuning_ en el [rock-glacier-dataset](https://huggingface.co/datasets/alkzar90/rock-glacier-dataset)",
    examples=[["073.png"], ["356.png"], ["599.png"], ["630.png"], ["673.png"],
              ["019.png"], ["261.png"], ["524.png"], ["716.png"], ["898.png"]],
    cache_examples=False,
)
demo.launch()