import spaces  # must be imported before torch on ZeroGPU Spaces

import os
import argparse

import cv2
import numpy as np
import gradio as gr
from tqdm import tqdm
from PIL import Image, ImageEnhance

import torch
from torch.amp import autocast
import torch.nn.functional as F

from network.line_extractor import LineExtractor

def resize(image, max_size=3840):
    """Downscale so the longer side equals max_size, keeping the aspect ratio."""
    h, w = image.shape[:2]
    if h > w:
        h, w = max_size, int(w * max_size / h)
    else:
        h, w = int(h * max_size / w), max_size
    return cv2.resize(image, (w, h))
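
# A quick, self-contained check of the resize math above (illustrative sizes,
# not part of the original app): a 4096x2304 frame comes back at 3840x2160.
def _example_resize():
    big = np.zeros((4096, 2304, 3), dtype=np.uint8)
    small = resize(big)
    assert small.shape[:2] == (3840, 2160)  # longer side capped, ratio kept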

def increase_sharpness(img, factor=6.0):
    """Boost edge contrast with PIL's sharpness enhancer (used by the basic mode)."""
    image = Image.fromarray(img)
    enhancer = ImageEnhance.Sharpness(image)
    return np.array(enhancer.enhance(factor))
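
# Sketch of the enhancer's ndarray round trip on a synthetic gradient; the
# factor and sizes here are illustrative, not app defaults.
def _example_sharpness():
    ramp = np.tile(np.arange(256, dtype=np.uint8), (64, 1))
    rgb = np.stack([ramp] * 3, axis=-1)  # (64, 256, 3) uint8 gradient
    out = increase_sharpness(rgb, factor=2.0)
    assert out.shape == rgb.shape and out.dtype == np.uint8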

def load_model(mode):
    """Load a frozen checkpoint: 'basic' takes 3 input channels, 'detail' takes 2."""
    if mode == 'basic':
        model = LineExtractor(3, 1, True)
    elif mode == 'detail':
        model = LineExtractor(2, 1, True)
    else:
        raise ValueError(f'unknown mode: {mode}')
    path_model = os.path.join('weights', f'{mode}.pth')
    model.load_state_dict(torch.load(path_model, weights_only=True))
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
    return model
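
# Hedged sketch of the expected I/O: judging from the channel counts above and
# the padding in inference() below, 'basic' takes (N, 3, H, W) and 'detail'
# (N, 2, H, W) with H, W divisible by 8, and both emit a one-channel line map.
# Requires the weights/ directory from the project page.
def _example_load():
    model = load_model('basic')
    x = torch.zeros(1, 3, 64, 64)  # dummy RGB input
    with torch.no_grad():
        y = model(x)
    assert y.shape[:2] == (1, 1)  # one line-map channel per image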

@spaces.GPU  # request a ZeroGPU slot for the CUDA work below
def process_image(image, mode, binarize, threshold, fp16=True):
    if image is None:
        return None
    binarize_value = threshold if binarize else -1  # -1 disables binarization
    args = argparse.Namespace(mode=mode, binarize=binarize_value, fp16=fp16, device="cuda:0")
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    if image.shape[0] > 3840 or image.shape[1] > 3840:  # cap at 4K, as stated in the UI
        image = resize(image)
    return inference(image, args)
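
# Hedged usage sketch, mirroring what the Process button does (needs CUDA and
# the weights; 'cel.png' is a placeholder path):
#   out = process_image(Image.open('cel.png'), 'detail', True, 0.75)
#   Image.fromarray(out).save('lines.png')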

def process_video(path_in, path_out, fourcc='mp4v', **kwargs):
    """Run inference frame by frame; expects args=... in kwargs (not used by the UI)."""
    video = cv2.VideoCapture(path_in)
    fps = video.get(cv2.CAP_PROP_FPS)
    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    fourcc = cv2.VideoWriter_fourcc(*fourcc)
    video_out = cv2.VideoWriter(path_out, fourcc, fps, (width, height))
    for _ in tqdm(range(total_frames), desc='Processing Video'):
        ret, frame = video.read()
        if not ret:
            break
        img = inference(frame, **kwargs)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)  # the writer expects 3 channels
        video_out.write(img)
    video.release()
    video_out.release()
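
# Hedged sketch of driving process_video directly; the UI never calls it, and
# inference() reads its options from the args namespace passed via kwargs.
# File names are placeholders.
#   args = argparse.Namespace(mode='detail', binarize=-1, fp16=True, device='cuda:0')
#   process_video('in.mp4', 'out.mp4', args=args)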

def inference(img: np.ndarray, args):
    if args.mode == 'basic':
        # Basic mode: sharpened RGB, 3 channels.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = increase_sharpness(img)
        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float().to(args.device) / 255.
        x_in = img
    else:
        # Detail mode: grayscale plus an inverted Sobel edge map, 2 channels.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        sobelx = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=3)
        sobel = cv2.magnitude(sobelx, sobely)
        sobel = 255 - cv2.normalize(sobel, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8UC1)
        img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float().to(args.device) / 255.
        sobel = torch.from_numpy(sobel).unsqueeze(0).unsqueeze(0).float().to(args.device) / 255.
        x_in = torch.cat([img, sobel], dim=1)
    # Reflect-pad H and W up to multiples of 8, then crop the prediction back.
    B, C, H, W = x_in.shape
    pad_h = -H % 8
    pad_w = -W % 8
    x_in = F.pad(x_in, (0, pad_w, 0, pad_h), mode='reflect')
    with torch.no_grad(), autocast(device_type='cuda', enabled=args.fp16):
        if args.mode == 'basic':
            pred = model_basic(x_in)
        elif args.mode == 'detail':
            pred = model_detail(x_in)
    pred = pred[:, :, :H, :W]
    if args.binarize != -1:
        pred = (pred > args.binarize).float()
    # Scale to uint8; the +0.5 before clipping acts as round-half-up.
    return np.clip(pred[0, 0].cpu().numpy() * 255 + 0.5, 0, 255).astype(np.uint8)
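
# Self-contained check of the pad-then-crop arithmetic used in inference()
# (CPU-only, no model; the 30x50 size is illustrative).
def _example_padding():
    x = torch.zeros(1, 2, 30, 50)
    pad_h, pad_w = -30 % 8, -50 % 8  # 2 and 6: round up to multiples of 8
    y = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
    assert y.shape[-2:] == (32, 56)
    assert y[:, :, :30, :50].shape == x.shape  # crop restores the original size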
model_basic = load_model("basic").to("cuda:0") | |
model_detail = load_model("detail").to("cuda:0") | |

with gr.Blocks() as demo:
    gr.Markdown("# AniLines - Anime Line Extractor Demo")
    gr.Markdown("For video and batch processing, please refer to the [project page](https://github.com/zhenglinpan/AniLines-Anime-Line-Extractor).")
    with gr.Tabs():
        with gr.Tab("Image Processing"):
            gr.Markdown("## Process Images")
            gr.Markdown("*The online demo resizes images larger than 4K down to 4K.*")
            with gr.Row():
                image_input = gr.Image(type="pil", label="Upload Image")
                image_output = gr.Image(label="Processed Output")
            mode_dropdown = gr.Radio(["basic", "detail"], value="detail", label="Processing Mode")
            binarize_checkbox = gr.Checkbox(label="Binarize", value=False)
            binarize_slider = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.75, label="Binarization Threshold", visible=False)
            binarize_checkbox.change(lambda binarize: gr.update(visible=binarize), inputs=binarize_checkbox, outputs=binarize_slider)
            process_button = gr.Button("Process")
            gr.Examples(
                examples=["example.png", "example2.jpg"],
                inputs=image_input,
                outputs=image_input,
            )
            process_button.click(
                process_image,
                inputs=[image_input, mode_dropdown, binarize_checkbox, binarize_slider],
                outputs=image_output,
            )

demo.queue().launch()