rename repo
- .gitattributes +1 -0
- .gitignore +7 -0
- README.md +6 -6
- app.py +151 -0
- network/line_extractor.py +107 -0
- requirements.txt +8 -0
- weights/.gitkeep +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+example.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,7 @@
+dev
+__pycache__
+.gradio
+
+*.pth
+*.png
+*.jpg
README.md
CHANGED
@@ -1,14 +1,14 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Anilines
+emoji: ⚡
+colorFrom: blue
+colorTo: purple
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.16.1
 app_file: app.py
 pinned: false
 license: mit
-short_description:
+short_description: Anime Line Extractor
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,151 @@
+import spaces
+
+import os
+import cv2
+import argparse
+import numpy as np
+import gradio as gr
+import numpy as np
+from tqdm import tqdm
+from PIL import Image, ImageEnhance
+
+import torch
+from torch.amp import autocast
+import torch.nn.functional as F
+
+from network.line_extractor import LineExtractor
+
+def resize(image, max_size=3840):
+    h, w = image.shape[:2]
+    if h > w:
+        h, w = (max_size, int(w * max_size / h))
+    else:
+        h, w = (int(h * max_size / w), max_size)
+    return cv2.resize(image, (w, h))
+
+def increase_sharpness(img, factor=6.0):
+    image = Image.fromarray(img)
+    enhancer = ImageEnhance.Sharpness(image)
+    return np.array(enhancer.enhance(factor))
+
+def load_model(mode):
+    if mode == 'basic':
+        model = LineExtractor(3, 1, True)
+    elif mode == 'detail':
+        model = LineExtractor(2, 1, True)
+
+    path_model = os.path.join('weights', f'{mode}.pth')
+    model.load_state_dict(torch.load(path_model, weights_only=True))
+
+    for param in model.parameters():
+        param.requires_grad = False
+    model.eval()
+
+    return model
+
+def process_image(image, mode, binarize, threshold, fp16=True):
+    if image is None:
+        return None
+
+    binarize_value = threshold if binarize else -1
+    args = argparse.Namespace(mode=mode, binarize=binarize_value, fp16=fp16, device="cuda:0")
+
+    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+    if image.shape[0] > 1920 or image.shape[1] > 1920:
+        image = resize(image)
+
+    return inference(image, args)
+
+def process_video(path_in, path_out, fourcc='mp4v', **kwargs):
+    video = cv2.VideoCapture(path_in)
+    fps = video.get(cv2.CAP_PROP_FPS)
+    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    fourcc = cv2.VideoWriter_fourcc(*fourcc)
+    video_out = cv2.VideoWriter(path_out, fourcc, fps, (width, height))
+
+    for _ in tqdm(range(total_frames), desc='Processing Video'):
+        ret, frame = video.read()
+        if not ret:
+            break
+
+        img = inference(frame, **kwargs)
+        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+        video_out.write(img)
+
+    video.release()
+    video_out.release()
+
+@spaces.GPU(duration=60)
+def inference(img: np.ndarray, args):
+    if args.mode == 'basic':
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = increase_sharpness(img)
+        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float().to(args.device) / 255.
+        x_in = img
+    else:
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+        sobelx = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=3)
+        sobely = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=3)
+        sobel = cv2.magnitude(sobelx, sobely)
+        sobel = 255 - cv2.normalize(sobel, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8UC1)
+
+        img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float().to(args.device) / 255.
+        sobel = torch.from_numpy(sobel).unsqueeze(0).unsqueeze(0).float().to(args.device) / 255.
+
+        x_in = torch.cat([img, sobel], dim=1)
+
+    B, C, H, W = x_in.shape
+    pad_h = 8 - (H % 8)
+    pad_w = 8 - (W % 8)
+    x_in = F.pad(x_in, (0, pad_w, 0, pad_h), mode='reflect')
+
+    with torch.no_grad(), autocast(enabled=args.fp16, device_type='cuda:0'):
+        if args.mode == 'basic':
+            pred = model_basic(x_in)
+        elif args.mode == 'detail':
+            pred = model_detail(x_in)
+        pred = pred[:, :, :H, :W]
+        if args.binarize != -1:
+            pred = (pred > args.binarize).float()
+
+    return np.clip((pred[0, 0].cpu().numpy() * 255) + 0.5, 0, 255).astype(np.uint8)
+
+
+
+model_basic = load_model("basic").to("cuda:0")
+model_detail = load_model("detail").to("cuda:0")
+
+with gr.Blocks() as demo:
+    gr.Markdown("# AniLines - Anime Line Extractor Demo")
+    gr.Markdown("For video and batch processing, please refer to the [project page](https://github.com/zhenglinpan/AniLines-Anime-Line-Extractor)")
+
+    with gr.Tabs():
+        with gr.Tab("Image Processing"):
+            gr.Markdown("## Process Images")
+            gr.Markdown("*Online demo resizes image to a max of 4K if larger.")
+            with gr.Row():
+                image_input = gr.Image(type="pil", label="Upload Image")
+                image_output = gr.Image(label="Processed Output")
+
+            mode_dropdown = gr.Radio(["basic", "detail"], value="detail", label="Processing Mode")
+            binarize_checkbox = gr.Checkbox(label="Binarize", value=False)
+            binarize_slider = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.75, label="Binarization Threshold (-1 for auto)", visible=False)
+            binarize_checkbox.change(lambda binarize: gr.update(visible=binarize), inputs=binarize_checkbox, outputs=binarize_slider)
+
+            process_button = gr.Button("Process")
+
+            gr.Examples(
+                examples=["example.png", "example2.jpg"],
+                inputs=image_input,
+                outputs=image_input
+            )
+
+            process_button.click(process_image,
+                                 inputs=[image_input, mode_dropdown, binarize_checkbox, binarize_slider],
+                                 outputs=image_output)
+
+demo.queue().launch()
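Note on the inference path in app.py: "basic" mode feeds a sharpened RGB image to a 3-channel model, while "detail" mode stacks the grayscale frame with an inverted Sobel edge map into a 2-channel input. In both cases the tensor is reflection-padded up to a multiple of 8 before the forward pass and the prediction is cropped back to the original size. The lines below are an illustrative sketch of that pad-then-crop step only, not part of the commit; the example resolution is arbitrary.

    import torch
    import torch.nn.functional as F

    x_in = torch.randn(1, 2, 1080, 1920)   # e.g. a "detail" mode input (grayscale + Sobel channel)
    B, C, H, W = x_in.shape
    pad_h = 8 - (H % 8)                    # note: adds a full 8 rows/cols when H or W is already a multiple of 8
    pad_w = 8 - (W % 8)
    x_pad = F.pad(x_in, (0, pad_w, 0, pad_h), mode='reflect')   # pad right and bottom only
    print(x_pad.shape)                     # torch.Size([1, 2, 1088, 1928])
    cropped = x_pad[:, :, :H, :W]          # the model output is cropped back the same way
    print(cropped.shape)                   # torch.Size([1, 2, 1080, 1920])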
network/line_extractor.py
ADDED
@@ -0,0 +1,107 @@
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class LineExtractor(nn.Module):
+    def __init__(self, chan_in, chan_out, bilinear=False):
+        super().__init__()
+        self.bilinear = bilinear
+
+        self.inc = (DoubleConv(chan_in, 64))
+        self.down1 = (Down(64, 128))
+        self.down2 = (Down(128, 256))
+        self.down3 = (Down(256, 512))
+        factor = 2 if bilinear else 1
+        self.down4 = (Down(512, 1024 // factor))
+        self.up1 = (Up(1024, 512 // factor, bilinear))
+        self.up2 = (Up(512, 256 // factor, bilinear))
+        self.up3 = (Up(256, 128 // factor, bilinear))
+        self.up4 = (Up(128, 64, bilinear))
+        self.outc = (OutConv(64, chan_out))
+
+    def forward(self, x):
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        x5 = self.down4(x4)
+        x = self.up1(x5, x4)
+        x = self.up2(x, x3)
+        x = self.up3(x, x2)
+        x = self.up4(x, x1)
+        logits = self.outc(x)
+        return logits
+
+    def use_checkpointing(self):
+        self.inc = torch.utils.checkpoint(self.inc)
+        self.down1 = torch.utils.checkpoint(self.down1)
+        self.down2 = torch.utils.checkpoint(self.down2)
+        self.down3 = torch.utils.checkpoint(self.down3)
+        self.down4 = torch.utils.checkpoint(self.down4)
+        self.up1 = torch.utils.checkpoint(self.up1)
+        self.up2 = torch.utils.checkpoint(self.up2)
+        self.up3 = torch.utils.checkpoint(self.up3)
+        self.up4 = torch.utils.checkpoint(self.up4)
+        self.outc = torch.utils.checkpoint(self.outc)
+
+
+
+class DoubleConv(nn.Module):
+    def __init__(self, in_channels, out_channels, mid_channels=None):
+        super().__init__()
+        if not mid_channels:
+            mid_channels = out_channels
+        self.double_conv = nn.Sequential(
+            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(mid_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.double_conv(x)
+
+
+class Down(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.maxpool_conv = nn.Sequential(
+            nn.MaxPool2d(2),
+            DoubleConv(in_channels, out_channels)
+        )
+
+    def forward(self, x):
+        return self.maxpool_conv(x)
+
+
+class Up(nn.Module):
+    def __init__(self, in_channels, out_channels, bilinear=True):
+        super().__init__()
+        if bilinear:
+            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+        else:
+            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
+            self.conv = DoubleConv(in_channels, out_channels)
+
+    def forward(self, x1, x2):
+        x1 = self.up(x1)
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+
+        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+                        diffY // 2, diffY - diffY // 2])
+        x = torch.cat([x2, x1], dim=1)
+        return self.conv(x)
+
+
+class OutConv(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(OutConv, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+    def forward(self, x):
+        return self.conv(x)
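network/line_extractor.py is a standard U-Net-style encoder/decoder: DoubleConv blocks, four MaxPool downsamplings, bilinear (or transposed-conv) upsampling with skip connections, and a 1x1 output convolution. A small shape sanity check for the two configurations app.py instantiates, as an illustrative sketch rather than part of the commit:

    import torch
    from network.line_extractor import LineExtractor

    basic = LineExtractor(3, 1, True).eval()    # "basic": sharpened RGB input, 1-channel line map out
    detail = LineExtractor(2, 1, True).eval()   # "detail": grayscale + inverted Sobel input

    with torch.no_grad():
        print(basic(torch.randn(1, 3, 256, 256)).shape)    # torch.Size([1, 1, 256, 256])
        print(detail(torch.randn(1, 2, 256, 256)).shape)   # torch.Size([1, 1, 256, 256])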
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+tqdm
+torch
+torchvision
+opencv-python
+pillow
+numpy
+gradio
+spaces
weights/.gitkeep
ADDED
File without changes