import argparse
import requests
import logging
import os
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from config import get_config
from collections import OrderedDict
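
# Install the RegionCLIP codebase (this repo, a detectron2 fork) in editable mode plus the
# extra runtime dependencies at startup, so that the detectron2 imports below can resolve.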
os.system("python -m pip install -e .")
os.system("pip install opencv-python timm diffdist h5py sklearn ftfy")
os.system("pip install git+https://github.com/lvis-dataset/lvis-api.git")

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultTrainer as Trainer
from detectron2.engine import default_argument_parser, default_setup, hooks, launch
from detectron2.evaluation import (
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator,
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    verify_results,
    FLICKR30KEvaluator,
)
from detectron2.modeling import GeneralizedRCNNWithTTA


def parse_option():
    parser = argparse.ArgumentParser('RegionCLIP demo script', add_help=False)
    parser.add_argument('--config-file', type=str, default="configs/CLIP_fast_rcnn_R_50_C4.yaml", metavar="FILE", help='path to config file')
    args, unparsed = parser.parse_known_args()
    return args


def build_transforms(img_size, center_crop=True):
    t = []
    if center_crop:
        size = int((256 / 224) * img_size)
        t.append(transforms.Resize(size))
        t.append(transforms.CenterCrop(img_size))
    else:
        t.append(transforms.Resize(img_size))
    t.append(transforms.ToTensor())
    return transforms.Compose(t)
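
# Note: build_transforms(800, center_crop=False) resizes the shorter image side to 800 px and
# converts to a float tensor in [0, 1]; with center_crop=True it follows the standard ImageNet
# eval recipe (resize to 256/224 * img_size, then center-crop to img_size).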


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


'''
build model
'''
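# Build the detector from the config and load pretrained weights from cfg.MODEL.WEIGHTS.
# For the CLIP-based meta-architectures that crop regions with an RPN, a second checkpoint
# (cfg.MODEL.CLIP.BB_RPN_WEIGHTS) supplies the pretrained RPN.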
args = parse_option()
cfg = setup(args)
model = Trainer.build_model(cfg)
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
    cfg.MODEL.WEIGHTS, resume=False
)
if cfg.MODEL.META_ARCHITECTURE in ['CLIPRCNN', 'CLIPFastRCNN', 'PretrainFastRCNN'] \
        and cfg.MODEL.CLIP.BB_RPN_WEIGHTS is not None \
        and cfg.MODEL.CLIP.CROP_REGION_TYPE == 'RPN':  # load 2nd pretrained model
    DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, bb_rpn_weights=True).resume_or_load(
        cfg.MODEL.CLIP.BB_RPN_WEIGHTS, resume=False
    )

'''
build data transform
'''
eval_transforms = build_transforms(800, center_crop=False)
# display_transforms = build_transforms4display(960, center_crop=False)
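
# Gradio callback: zero-shot grounding of the query text on the uploaded image.
# The raw model output is returned unmodified; its exact format depends on the
# RegionCLIP meta-architecture selected in the config file.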
def localize_object(image, texts):
    img_t = eval_transforms(Image.fromarray(image).convert("RGB")) * 255
    model.eval()
    with torch.no_grad():
        res = model(texts, [{"image": img_t}])
    return res
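
# Minimal local sketch (assumes ./birds.png from the examples list below is present):
# res = localize_object(np.array(Image.open("./birds.png").convert("RGB")), "a goldfinch")
#
# Note: gr.inputs / gr.outputs are legacy component namespaces from older Gradio releases;
# newer releases expose the components directly (e.g. gr.Image).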
image = gr.inputs.Image()

gr.Interface(
    description="Zero-Shot Object Detection with RegionCLIP (https://github.com/microsoft/RegionCLIP)",
    fn=localize_object,
    inputs=["image", "text"],
    outputs=[
        gr.outputs.Image(type="pil", label="grounding results"),
    ],
    examples=[
        ["./birds.png", "a goldfinch"],
        ["./apples_six.jpg", "a yellow apple"],
        ["./wines.jpg", "milk shake"],
        ["./logos.jpg", "a microsoft logo"],
    ],
).launch()