# NOTE(review): `spaces` was the first import in the original file and is kept
# first here; ZeroGPU setups generally expect it to be imported before torch
# touches CUDA -- confirm before re-sorting this block.
import spaces

import argparse
import os

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageEnhance
from torch.amp import autocast
from tqdm import tqdm

from network.line_extractor import LineExtractor

# Longest image side (in pixels) the demo lets through, i.e. "4K". Shared by
# `resize` (its default target) and `process_image` (its trigger threshold) so
# the two can never disagree.
MAX_IMAGE_SIDE = 3840


def resize(image, max_size=MAX_IMAGE_SIDE):
    """Rescale ``image`` so that its longer side equals ``max_size``.

    The aspect ratio is preserved (the shorter side is truncated to an int).
    Note that this scales in *both* directions: an image smaller than
    ``max_size`` is enlarged, so callers must decide when to call it.

    Args:
        image: Array whose first two dimensions are (height, width).
        max_size: Target length in pixels of the longer side.

    Returns:
        The array returned by ``cv2.resize`` for the new (width, height).
    """
    h, w = image.shape[:2]
    if h > w:
        new_h, new_w = max_size, int(w * max_size / h)
    else:
        new_h, new_w = int(h * max_size / w), max_size
    # Extreme aspect ratios could truncate the short side to 0, which
    # cv2.resize rejects; keep at least one pixel.
    return cv2.resize(image, (max(new_w, 1), max(new_h, 1)))


def increase_sharpness(img, factor=6.0):
    """Sharpen ``img`` with PIL's ``ImageEnhance.Sharpness``.

    Args:
        img: Image array accepted by ``PIL.Image.fromarray``.
        factor: Enhancement factor handed to ``enhancer.enhance``.

    Returns:
        The sharpened image as a NumPy array.
    """
    enhancer = ImageEnhance.Sharpness(Image.fromarray(img))
    return np.array(enhancer.enhance(factor))


def load_model(mode):
    """Build a frozen, eval-mode ``LineExtractor`` with weights from disk.

    Weights are read from ``weights/<mode>.pth`` onto the CPU; the caller
    moves the returned model to its target device.

    Args:
        mode: ``'basic'`` or ``'detail'``.

    Returns:
        The model with all parameters frozen and ``eval()`` applied.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    # The first constructor argument differs per mode (3 vs 2) -- presumably
    # the number of input channels (RGB vs gray + Sobel); confirm against
    # LineExtractor.
    if mode == 'basic':
        model = LineExtractor(3, 1, True)
    elif mode == 'detail':
        model = LineExtractor(2, 1, True)
    else:
        raise ValueError(f"Unknown mode {mode!r}; expected 'basic' or 'detail'")

    path_model = os.path.join('weights', f'{mode}.pth')
    # map_location='cpu' lets the checkpoint load regardless of the device it
    # was saved from; weights_only=True avoids unpickling arbitrary objects.
    state_dict = torch.load(path_model, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)

    # Inference only: no gradients needed.
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
    return model


def process_image(image, mode, binarize, threshold, fp16=True):
    """Gradio callback: extract lines from one uploaded image.

    Args:
        image: The uploaded image (RGB channel order is assumed by the
            RGB->BGR conversion below), or ``None`` when nothing was uploaded.
        mode: ``'basic'`` or ``'detail'``; forwarded to ``inference``.
        binarize: Whether to threshold the prediction.
        threshold: Threshold used when ``binarize`` is true.
        fp16: Whether ``inference`` should run under autocast.

    Returns:
        The array returned by ``inference``, or ``None`` if ``image`` is
        ``None``.
    """
    if image is None:
        return None

    # -1 is the sentinel `inference` interprets as "do not binarize".
    binarize_value = threshold if binarize else -1
    args = argparse.Namespace(
        mode=mode, binarize=binarize_value, fp16=fp16, device="cuda:0"
    )

    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # Only shrink images that exceed the cap. Comparing against a smaller
    # number than resize()'s target would *upscale* mid-sized images, because
    # resize() always maps the long side to exactly MAX_IMAGE_SIDE.
    if max(image.shape[:2]) > MAX_IMAGE_SIDE:
        image = resize(image)
    return inference(image, args)


def process_video(path_in, path_out, fourcc='mp4v', **kwargs):
    """Run ``inference`` on every frame of a video and write the result.

    Args:
        path_in: Path of the video to read.
        path_out: Path of the video to write (same fps and frame size as the
            input).
        fourcc: Four-character codec code for the writer.
        **kwargs: Forwarded verbatim to ``inference`` after the frame, so the
            caller has to supply its ``args`` namespace by keyword.

    Raises:
        OSError: If the input video cannot be opened.
    """
    video = cv2.VideoCapture(path_in)
    if not video.isOpened():
        video.release()
        raise OSError(f"Cannot open input video: {path_in}")

    video_out = None
    try:
        fps = video.get(cv2.CAP_PROP_FPS)
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

        codec = cv2.VideoWriter_fourcc(*fourcc)
        video_out = cv2.VideoWriter(path_out, codec, fps, (width, height))

        for _ in tqdm(range(total_frames), desc='Processing Video'):
            ret, frame = video.read()
            if not ret:
                break
            lines = inference(frame, **kwargs)
            # The writer was opened for colour frames; expand the
            # single-channel result back to 3 channels.
            video_out.write(cv2.cvtColor(lines, cv2.COLOR_GRAY2BGR))
    finally:
        # Release both handles even if a frame fails mid-way, so the output
        # file is finalised and the capture is not leaked.
        video.release()
        if video_out is not None:
            video_out.release()
@spaces.GPU(duration=60)
def inference(img: np.ndarray, args):
    """Run the line extractor on one BGR image.

    Args:
        img: Image array in BGR channel order (it is converted with
            ``cv2.COLOR_BGR2RGB`` / ``cv2.COLOR_BGR2GRAY`` below).
        args: Namespace-like object with the attributes
            ``mode`` (``'basic'`` or ``'detail'``),
            ``binarize`` (threshold applied to the prediction, or ``-1`` to
            skip binarization),
            ``fp16`` (enable autocast) and
            ``device`` (torch device string the input is moved to).

    Returns:
        A 2-D ``uint8`` array with values in [0, 255], cropped to the input's
        height and width.

    Raises:
        ValueError: If ``args.mode`` is not a supported mode.
    """
    if args.mode not in ('basic', 'detail'):
        raise ValueError(
            f"Unknown mode {args.mode!r}; expected 'basic' or 'detail'"
        )

    device = torch.device(args.device)

    if args.mode == 'basic':
        # Basic mode feeds a sharpened RGB image (3 channels).
        rgb = increase_sharpness(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        x_in = (
            torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).float().to(device)
            / 255.
        )
    else:
        # Detail mode feeds grayscale plus an inverted, min-max normalised
        # Sobel gradient magnitude (2 channels).
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        sobel = cv2.magnitude(sobelx, sobely)
        sobel = 255 - cv2.normalize(
            sobel, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8UC1
        )
        gray_t = (
            torch.from_numpy(gray).unsqueeze(0).unsqueeze(0).float().to(device)
            / 255.
        )
        sobel_t = (
            torch.from_numpy(sobel).unsqueeze(0).unsqueeze(0).float().to(device)
            / 255.
        )
        x_in = torch.cat([gray_t, sobel_t], dim=1)

    _, _, H, W = x_in.shape
    # Pad bottom/right up to the next multiple of 8. `(-n) % 8` is 0 when n is
    # already aligned, whereas `8 - n % 8` would add a needless full 8 pixels.
    pad_h = (-H) % 8
    pad_w = (-W) % 8
    if pad_h or pad_w:
        x_in = F.pad(x_in, (0, pad_w, 0, pad_h), mode='reflect')

    # autocast expects a bare device *type* ('cuda'), not an indexed device
    # string such as 'cuda:0'.
    with torch.no_grad(), autocast(device_type=device.type, enabled=args.fp16):
        model = model_basic if args.mode == 'basic' else model_detail
        pred = model(x_in)

    # Drop the padding again.
    pred = pred[:, :, :H, :W]

    if args.binarize != -1:
        pred = (pred > args.binarize).float()

    # Cast to float32 first so the 0-255 scaling is not done in half precision
    # when autocast produced an fp16 tensor; +0.5 rounds to nearest on the
    # uint8 truncation.
    out = pred[0, 0].float().cpu().numpy()
    return np.clip(out * 255 + 0.5, 0, 255).astype(np.uint8)


# Both models are loaded once at import time and kept resident on the GPU;
# `inference` picks one per request.
model_basic = load_model("basic").to("cuda:0")
model_detail = load_model("detail").to("cuda:0")

with gr.Blocks() as demo:
    gr.Markdown("# AniLines - Anime Line Extractor Demo")
    gr.Markdown("For video and batch processing, please refer to the [project page](https://github.com/zhenglinpan/AniLines-Anime-Line-Extractor)")

    with gr.Tabs():
        with gr.Tab("Image Processing"):
            gr.Markdown("## Process Images")
            gr.Markdown("*Online demo resizes image to a max of 4K if larger.")

            with gr.Row():
                image_input = gr.Image(type="pil", label="Upload Image")
                image_output = gr.Image(label="Processed Output")

            mode_dropdown = gr.Radio(
                ["basic", "detail"], value="detail", label="Processing Mode"
            )
            binarize_checkbox = gr.Checkbox(label="Binarize", value=False)
            # NOTE(review): the label mentions "-1 for auto", but the slider
            # range is 0..1 and -1 is only used internally as the
            # "binarization off" sentinel -- consider rewording the label.
            binarize_slider = gr.Slider(
                minimum=0,
                maximum=1,
                step=0.05,
                value=0.75,
                label="Binarization Threshold (-1 for auto)",
                visible=False,
            )
            # Show the threshold slider only while "Binarize" is ticked.
            binarize_checkbox.change(
                lambda binarize: gr.update(visible=binarize),
                inputs=binarize_checkbox,
                outputs=binarize_slider,
            )

            process_button = gr.Button("Process")

            gr.Examples(
                examples=["example.png", "example2.jpg"],
                inputs=image_input,
                outputs=image_input,
            )

            process_button.click(
                process_image,
                inputs=[image_input, mode_dropdown, binarize_checkbox, binarize_slider],
                outputs=image_output,
            )

demo.queue().launch()