Sod_Inpaint / app.py
wenpeng's picture
add sod output
a579154
import gradio as gr
import inpaint.infer_model as inpaint
import sod.infer_model as sod
import numpy as np
import torch
import glob
import cv2
# import os
# cmd = 'sh download.sh'
# os.system(cmd)
# All inference happens on the CPU.
device = torch.device("cpu")
print(device)

# Both networks are instantiated once at import time; sod_inpaint reuses the
# same instances on every call.
inpaint_model = inpaint.IVModel(device=device)
sod_model = sod.IVModel(device=device)

# Resizing policy used by sod_inpaint: cap the longest side at max_size and
# snap both sides down to a multiple of scale_factor.
max_size, scale_factor = 512, 8

# Running total of processed images, incremented by sod_inpaint.
count = 0
def _fit_to_grid(img, max_side, grid):
    """Resize ``img`` to fit within ``max_side`` with both sides multiples of ``grid``.

    If the longest side exceeds ``max_side``, the image is scaled down so the
    longest side equals ``max_side`` (aspect ratio preserved up to integer
    truncation). Both sides are then rounded down to a multiple of ``grid``.
    Each side is clamped to at least ``grid`` pixels, so that very thin images
    never collapse to a zero-sized dimension (which ``cv2.resize`` rejects).

    Args:
        img: Image array whose first two axes are (height, width).
        max_side: Maximum allowed length of the longest side, in pixels.
        grid: Both output sides are multiples of this value.

    Returns:
        The resized image as returned by ``cv2.resize``.
    """
    h, w = img.shape[:2]
    if max(h, w) > max_side:
        if h < w:
            h, w = int(max_side * h / w), max_side
        else:
            h, w = max_side, int(max_side * w / h)
    # Snap down to the grid, but never below a single grid cell.
    h = max(grid, h // grid * grid)
    w = max(grid, w // grid * grid)
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(img, (w, h))


def sod_inpaint(img):
    """Run salient-object detection on ``img``, then the inpainting model.

    Args:
        img: Input image as a NumPy array. The Gradio input is configured with
            ``image_mode="RGB"`` and ``type="numpy"``, so it is expected in RGB
            channel order.

    Returns:
        A pair of uint8 arrays with their channel axis reversed back for
        display:

        - The left half of the SOD output, with pixels zeroed wherever the
          right half is 0.
        - The inpainting model's output for the full SOD result.

    Side effects:
        Increments the module-level ``count`` and prints it.
    """
    global count
    img = _fit_to_grid(img, max_size, scale_factor)
    # Reverse the channel axis (RGB -> BGR). Presumably both models expect
    # OpenCV-style BGR input -- TODO confirm.
    bgr = img[:, :, ::-1]
    sod_res = np.uint8(sod_model.forward(bgr, None))
    # NOTE(review): sod_model.forward appears to return two images side by
    # side along the width: left half an image, right half a mask. Confirm
    # against the sod package.
    half = sod_res.shape[1] // 2
    mask = sod_res[:, half:, :] > 0
    so = np.uint8(sod_res[:, :half, :] * mask.astype(np.float32))
    # The inpainting model consumes the combined image+mask output unchanged.
    inpaint_res = np.uint8(inpaint_model.forward(sod_res, None))
    count += 1
    print(count, ' images have been processed')
    return so[:, :, ::-1], inpaint_res[:, :, ::-1]
# Example images offered under the interface. glob returns files in an
# arbitrary, filesystem-dependent order, so sort for a stable gallery order.
examples = sorted(glob.glob('examples/*.*'))

# Single image input, handed to sod_inpaint as an RGB NumPy array.
# NOTE(review): in Gradio 2.x, `shape` resizes/crops uploads to 512x512
# before sod_inpaint sees them -- confirm for the pinned Gradio version.
inputs = gr.inputs.Image(
    shape=(512, 512),
    image_mode="RGB",
    invert_colors=False,
    source="upload",
    tool="editor",
    type="numpy",
    label=None,
    optional=False,
)

# Two image outputs, matching the (foreground, inpainted) pair that
# sod_inpaint returns.
iface = gr.Interface(
    fn=sod_inpaint,
    inputs=inputs,
    outputs=["image", "image"],
    examples=examples,
    title='Salient Object Detection + Inpaint',
    description='Upload an image and you will see the fg and inpainted bg',
    theme='huggingface',
)
iface.launch()