File size: 3,656 Bytes
d617811
 
 
 
 
 
1fdfa56
a5c51bb
d617811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fdfa56
d617811
 
 
 
 
 
 
 
 
 
 
 
 
1fdfa56
 
d617811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fdfa56
 
d617811
 
 
aff8d56
1fdfa56
 
d617811
 
 
 
 
 
 
aff8d56
d617811
 
 
 
e20de5f
aff8d56
d617811
 
 
 
 
 
 
 
 
e20de5f
 
d617811
 
 
aff8d56
d617811
36e8d62
aff8d56
 
 
14dfb70
aff8d56
 
d617811
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
import argparse
import glob
import multiprocessing as mp
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Install detectron2 straight from GitHub every time the script starts. This
# has to run before the `detectron2` imports further down in this file.
# The exit status returned by os.system is not checked here.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')

# fmt: off
import sys
# Add the parent of sys.path[0] to the module search path, right after entry 0.
# NOTE(review): presumably needed so the project-local imports below
# (`cat_seg`, `demo.predictor`) resolve - confirm against the repo layout.
sys.path.insert(1, os.path.join(sys.path[0], '..'))
# fmt: on

import tempfile
import time
import warnings

import cv2
import numpy as np
import tqdm

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.logger import setup_logger

from cat_seg import add_cat_seg_config
from demo.predictor import VisualizationDemo
import gradio as gr
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg as fc

# constants
# NOTE(review): WINDOW_NAME is not referenced anywhere in the visible part of
# this file; presumably a leftover from the upstream detectron2 demo script
# this file was adapted from - confirm before removing.
WINDOW_NAME = "MaskFormer demo"


def setup_cfg(args):
    """Build the frozen config object used by the demo.

    Starts from ``get_cfg()``, registers the DeepLab and CAT-Seg config
    extensions, then layers ``args.config_file`` and the ``args.opts``
    overrides on top. When CUDA is available the device is set to ``"cuda"``
    last, so it takes precedence over any ``MODEL.DEVICE`` value supplied by
    the config file or by ``args.opts``.

    Args:
        args: Parsed command-line arguments exposing ``config_file`` and
            ``opts``.

    Returns:
        The config object, frozen against further modification.
    """
    config = get_cfg()

    # Register the project-specific config keys before merging user values.
    for register_keys in (add_deeplab_config, add_cat_seg_config):
        register_keys(config)

    # File first, then the KEY VALUE overrides.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)

    gpu_present = torch.cuda.is_available()
    if gpu_present:
        config.MODEL.DEVICE = "cuda"

    config.freeze()
    return config


def get_parser():
    """Create the command-line parser for the demo.

    Options:
        --config-file: path to the config file
            (default ``configs/vitl_swinb_384.yaml``).
        --input: one or more image paths, or a single glob pattern.
        --opts: everything after this flag is collected as 'KEY VALUE' config
            overrides; when the flag is absent a built-in list of overrides
            is used.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    # Overrides applied when --opts is not given on the command line.
    fallback_opts = [
        "MODEL.WEIGHTS", "model_final.pth",
        "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
        "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
        "TEST.SLIDING_WINDOW", "True",
        "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
        "MODEL.PROMPT_ENSEMBLE_TYPE", "single",
        "MODEL.DEVICE", "cpu",
    ]
    input_help = (
        "A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'"
    )

    cli = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    cli.add_argument(
        "--config-file",
        metavar="FILE",
        default="configs/vitl_swinb_384.yaml",
        help="path to config file",
    )
    cli.add_argument("--input", nargs="+", help=input_help)
    cli.add_argument(
        "--opts",
        nargs=argparse.REMAINDER,
        default=fallback_opts,
        help="Modify config options using the command-line 'KEY VALUE' pairs",
    )
    return cli

def save_masks(preds, text):
    """Write one binary mask image per class name to the working directory.

    Args:
        preds: Mapping whose ``'sem_seg'`` entry is a tensor of per-class
            scores with the class axis first (``C H W`` per the original
            note); ``argmax(dim=0)`` turns it into a per-pixel class index.
        text: Iterable of class names; the i-th name labels class index ``i``.

    Each mask is written to ``mask_<name>.png`` with value 255 where the
    class wins the argmax and 0 elsewhere.
    """
    class_map = preds['sem_seg'].argmax(dim=0).cpu().numpy()  # H W
    for class_idx, label in enumerate(text):
        out_path = f"mask_{label}.png"
        # Multiplying a bool array by a Python int promotes to NumPy's default
        # integer type (typically int64), which cv2.imwrite cannot encode.
        # Build an explicit 8-bit image instead.
        mask = (class_map == class_idx).astype(np.uint8) * 255
        cv2.imwrite(out_path, mask)

def predict(image, text):
    """Run the demo model on one image and return the rendered visualization.

    Relies on the module-level ``demo`` object created in the ``__main__``
    block of this script.

    Args:
        image: Input image as passed in by the Gradio image component.
        text: Category string as passed in by the Gradio textbox; forwarded
            unchanged to ``demo.run_on_image``.

    Returns:
        An ``H x W x 3`` uint8 array holding the rendered figure with its
        channel axis reversed relative to the canvas RGB order.
    """
    predictions, visualized_output = demo.run_on_image(image, text)
    # Debug aid: save_masks(predictions, text.split(',')) dumps per-class masks.
    canvas = fc(visualized_output.fig)
    canvas.draw()
    # tostring_rgb() was deprecated in Matplotlib 3.8 and later removed.
    # buffer_rgba() exposes the same rendering as an (H, W, 4) uint8 buffer,
    # so no manual reshape from get_width_height() is needed.
    rgba = np.asarray(canvas.buffer_rgba())
    rgb = rgba[..., :3]  # drop the alpha channel

    # NOTE(review): the channel flip is kept from the original; presumably it
    # compensates for run_on_image treating its input as BGR - confirm
    # against demo.predictor.
    return rgb[..., ::-1]

if __name__ == "__main__":
    # Parse CLI arguments and build the model wrapper. `demo` is assigned at
    # module scope here, which is how predict() reaches it.
    args = get_parser().parse_args()
    cfg = setup_cfg(args)
    demo = VisualizationDemo(cfg)

    # Markdown text shown on the Gradio page.
    page_description = """## CAT-Seg Demo
Welcome to the CAT-Seg Demo! Here, we present the CAT-Seg with ViT-L model for open-vocabulary semantic segmentation.

Please note that this is an optimized version of the full model, and as such, its performance may be limited compared to the full model.

To get started, simply upload an image and a comma-separated list of categories, and let the model work its magic!"""

    # One image input plus a free-text box for the category list.
    ui_inputs = [gr.Image(), gr.Textbox(placeholder='cat, person, background')]
    ui_examples = [['assets/nyancat.png', 'cat, pop tart, rainbow, background']]

    iface = gr.Interface(
        fn=predict,
        inputs=ui_inputs,
        outputs="image",
        examples=ui_examples,
        description=page_description,
    )
    iface.launch()