import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from transformers import AutoConfig, AutoModel

from focusondepth.model_config import FocusOnDepthConfig
from focusondepth.model_definition import FocusOnDepth

# Register the custom architecture so AutoModel can resolve the "focusondepth" model type.
AutoConfig.register("focusondepth", FocusOnDepthConfig)
AutoModel.register(FocusOnDepthConfig, FocusOnDepth)
# Preprocessing: 384x384 input, normalized to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Load the model skeleton from the Hub, then load the local checkpoint weights.
model = AutoModel.from_pretrained('ybelkada/focusondepth', trust_remote_code=True)
model.load_state_dict(
    torch.load(
        './focusondepth/FocusOnDepth_vit_base_patch16_384.p',
        map_location=torch.device('cpu'),
    )['model_state_dict']
)
model.eval()
def inference(input_image):
    # Gradio hands the image over as a numpy array; convert it back to PIL.
    input_image = Image.fromarray(input_image)
    original_size = input_image.size

    tensor_image = transform(input_image)
    with torch.no_grad():
        depth, segmentation = model(tensor_image.unsqueeze(0))

    # Invert the depth prediction for a more intuitive visualisation.
    depth = 1 - depth
    depth = transforms.ToPILImage()(depth[0, :])
    # Collapse per-class logits into a class-index map for display.
    segmentation = transforms.ToPILImage()(segmentation.argmax(dim=1).float())

    # Upsample both maps back to the original resolution: bicubic for the smooth
    # depth map, nearest-neighbour for the discrete segmentation labels.
    return [
        depth.resize(original_size, resample=Image.BICUBIC),
        segmentation.resize(original_size, resample=Image.NEAREST),
    ]
# gr.inputs / gr.outputs were removed in Gradio 3.x; use the top-level components instead.
iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(label="Input Image"),
    outputs=[
        gr.Image(label="Depth Map"),
        gr.Image(label="Segmentation Map"),
    ],
)
iface.launch()
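
# A minimal sketch of how the handler could be exercised without launching the
# Gradio UI, assuming a local test image at './example.jpg' (hypothetical path):
#
#   import numpy as np
#   from PIL import Image
#
#   array = np.array(Image.open('./example.jpg').convert('RGB'))
#   depth_map, segmentation_map = inference(array)
#   depth_map.save('depth.png')
#   segmentation_map.save('segmentation.png')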