aidenpan committed on
Commit 9a514db · 1 Parent(s): 4557021

rename repo
Files changed (7)
  1. .gitattributes +1 -0
  2. .gitignore +7 -0
  3. README.md +6 -6
  4. app.py +151 -0
  5. network/line_extractor.py +107 -0
  6. requirements.txt +8 -0
  7. weights/.gitkeep +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+example.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
+dev
+__pycache__
+.gradio
+
+*.pth
+*.png
+*.jpg
README.md CHANGED
@@ -1,14 +1,14 @@
 ---
-title: AniLines Anime Lineart Extractor
-emoji: 👁
-colorFrom: gray
-colorTo: gray
+title: Anilines
+emoji:
+colorFrom: blue
+colorTo: purple
 sdk: gradio
-sdk_version: 5.18.0
+sdk_version: 5.16.1
 app_file: app.py
 pinned: false
 license: mit
-short_description: Extracting lineart, sketch from anime images and videos
+short_description: Anime Line Extractor
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,151 @@
+import spaces
+
+import os
+import cv2
+import argparse
+import numpy as np
+import gradio as gr
+import numpy as np
+from tqdm import tqdm
+from PIL import Image, ImageEnhance
+
+import torch
+from torch.amp import autocast
+import torch.nn.functional as F
+
+from network.line_extractor import LineExtractor
+
+def resize(image, max_size=3840):
+    h, w = image.shape[:2]
+    if h > w:
+        h, w = (max_size, int(w * max_size / h))
+    else:
+        h, w = (int(h * max_size / w), max_size)
+    return cv2.resize(image, (w, h))
+
+def increase_sharpness(img, factor=6.0):
+    image = Image.fromarray(img)
+    enhancer = ImageEnhance.Sharpness(image)
+    return np.array(enhancer.enhance(factor))
+
+def load_model(mode):
+    if mode == 'basic':
+        model = LineExtractor(3, 1, True)
+    elif mode == 'detail':
+        model = LineExtractor(2, 1, True)
+
+    path_model = os.path.join('weights', f'{mode}.pth')
+    model.load_state_dict(torch.load(path_model, weights_only=True))
+
+    for param in model.parameters():
+        param.requires_grad = False
+    model.eval()
+
+    return model
+
+def process_image(image, mode, binarize, threshold, fp16=True):
+    if image is None:
+        return None
+
+    binarize_value = threshold if binarize else -1
+    args = argparse.Namespace(mode=mode, binarize=binarize_value, fp16=fp16, device="cuda:0")
+
+    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+    if image.shape[0] > 1920 or image.shape[1] > 1920:
+        image = resize(image)
+
+    return inference(image, args)
+
+def process_video(path_in, path_out, fourcc='mp4v', **kwargs):
+    video = cv2.VideoCapture(path_in)
+    fps = video.get(cv2.CAP_PROP_FPS)
+    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    fourcc = cv2.VideoWriter_fourcc(*fourcc)
+    video_out = cv2.VideoWriter(path_out, fourcc, fps, (width, height))
+
+    for _ in tqdm(range(total_frames), desc='Processing Video'):
+        ret, frame = video.read()
+        if not ret:
+            break
+
+        img = inference(frame, **kwargs)
+        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+        video_out.write(img)
+
+    video.release()
+    video_out.release()
+
+@spaces.GPU(duration=60)
+def inference(img: np.ndarray, args):
+    if args.mode == 'basic':
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = increase_sharpness(img)
+        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float().to(args.device) / 255.
+        x_in = img
+    else:
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+        sobelx = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=3)
+        sobely = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=3)
+        sobel = cv2.magnitude(sobelx, sobely)
+        sobel = 255 - cv2.normalize(sobel, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8UC1)
+
+        img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float().to(args.device) / 255.
+        sobel = torch.from_numpy(sobel).unsqueeze(0).unsqueeze(0).float().to(args.device) / 255.
+
+        x_in = torch.cat([img, sobel], dim=1)
+
+    B, C, H, W = x_in.shape
+    pad_h = 8 - (H % 8)
+    pad_w = 8 - (W % 8)
+    x_in = F.pad(x_in, (0, pad_w, 0, pad_h), mode='reflect')
+
+    with torch.no_grad(), autocast(device_type='cuda', enabled=args.fp16):
+        if args.mode == 'basic':
+            pred = model_basic(x_in)
+        elif args.mode == 'detail':
+            pred = model_detail(x_in)
+    pred = pred[:, :, :H, :W]
+    if args.binarize != -1:
+        pred = (pred > args.binarize).float()
+
+    return np.clip((pred[0, 0].cpu().numpy() * 255) + 0.5, 0, 255).astype(np.uint8)
+
+
+
+model_basic = load_model("basic").to("cuda:0")
+model_detail = load_model("detail").to("cuda:0")
+
+with gr.Blocks() as demo:
+    gr.Markdown("# AniLines - Anime Line Extractor Demo")
+    gr.Markdown("For video and batch processing, please refer to the [project page](https://github.com/zhenglinpan/AniLines-Anime-Line-Extractor)")
+
+    with gr.Tabs():
+        with gr.Tab("Image Processing"):
+            gr.Markdown("## Process Images")
+            gr.Markdown("*Online demo resizes image to a max of 4K if larger.")
+            with gr.Row():
+                image_input = gr.Image(type="pil", label="Upload Image")
+                image_output = gr.Image(label="Processed Output")
+
+            mode_dropdown = gr.Radio(["basic", "detail"], value="detail", label="Processing Mode")
+            binarize_checkbox = gr.Checkbox(label="Binarize", value=False)
+            binarize_slider = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.75, label="Binarization Threshold (-1 for auto)", visible=False)
+            binarize_checkbox.change(lambda binarize: gr.update(visible=binarize), inputs=binarize_checkbox, outputs=binarize_slider)
+
+            process_button = gr.Button("Process")
+
+            gr.Examples(
+                examples=["example.png", "example2.jpg"],
+                inputs=image_input,
+                outputs=image_input
+            )
+
+            process_button.click(process_image,
+                                 inputs=[image_input, mode_dropdown, binarize_checkbox, binarize_slider],
+                                 outputs=image_output)
+
+demo.queue().launch()
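
Note: `inference` reflect-pads the input so height and width are multiples of 8 before the UNet forward pass, then crops the prediction back to the original size and optionally thresholds it. For reference, a minimal standalone sketch of the same 'detail'-mode path, without the Gradio UI, the `spaces` decorator, or binarization (assumes the repository layout above, a local `weights/detail.pth`, and a CUDA device; `input.png` and `lines.png` are placeholder paths):

```python
# Standalone sketch of the 'detail' path: grayscale + inverted Sobel magnitude as a
# 2-channel input, reflect-padded to multiples of 8, cropped back after inference.
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch.amp import autocast

from network.line_extractor import LineExtractor

device = 'cuda:0'  # assumption: a CUDA device is available
model = LineExtractor(2, 1, True)
model.load_state_dict(torch.load('weights/detail.pth', weights_only=True))
model.eval().to(device)

img = cv2.imread('input.png', cv2.IMREAD_GRAYSCALE)          # placeholder input path
sobelx = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=3)
sobely = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=3)
sobel = 255 - cv2.normalize(cv2.magnitude(sobelx, sobely), None, 0, 255,
                            cv2.NORM_MINMAX, cv2.CV_8UC1)

x = torch.from_numpy(img).float()[None, None].to(device) / 255.
s = torch.from_numpy(sobel).float()[None, None].to(device) / 255.
x_in = torch.cat([x, s], dim=1)

_, _, H, W = x_in.shape
x_in = F.pad(x_in, (0, 8 - W % 8, 0, 8 - H % 8), mode='reflect')  # pad to multiples of 8

with torch.no_grad(), autocast(device_type='cuda'):
    pred = model(x_in)[:, :, :H, :W]                              # crop the padding back off

cv2.imwrite('lines.png',
            (pred[0, 0].float().cpu().numpy() * 255 + 0.5).clip(0, 255).astype(np.uint8))
```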
network/line_extractor.py ADDED
@@ -0,0 +1,107 @@
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class LineExtractor(nn.Module):
+    def __init__(self, chan_in, chan_out, bilinear=False):
+        super().__init__()
+        self.bilinear = bilinear
+
+        self.inc = (DoubleConv(chan_in, 64))
+        self.down1 = (Down(64, 128))
+        self.down2 = (Down(128, 256))
+        self.down3 = (Down(256, 512))
+        factor = 2 if bilinear else 1
+        self.down4 = (Down(512, 1024 // factor))
+        self.up1 = (Up(1024, 512 // factor, bilinear))
+        self.up2 = (Up(512, 256 // factor, bilinear))
+        self.up3 = (Up(256, 128 // factor, bilinear))
+        self.up4 = (Up(128, 64, bilinear))
+        self.outc = (OutConv(64, chan_out))
+
+    def forward(self, x):
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        x5 = self.down4(x4)
+        x = self.up1(x5, x4)
+        x = self.up2(x, x3)
+        x = self.up3(x, x2)
+        x = self.up4(x, x1)
+        logits = self.outc(x)
+        return logits
+
+    def use_checkpointing(self):
+        self.inc = torch.utils.checkpoint(self.inc)
+        self.down1 = torch.utils.checkpoint(self.down1)
+        self.down2 = torch.utils.checkpoint(self.down2)
+        self.down3 = torch.utils.checkpoint(self.down3)
+        self.down4 = torch.utils.checkpoint(self.down4)
+        self.up1 = torch.utils.checkpoint(self.up1)
+        self.up2 = torch.utils.checkpoint(self.up2)
+        self.up3 = torch.utils.checkpoint(self.up3)
+        self.up4 = torch.utils.checkpoint(self.up4)
+        self.outc = torch.utils.checkpoint(self.outc)
+
+
+
+class DoubleConv(nn.Module):
+    def __init__(self, in_channels, out_channels, mid_channels=None):
+        super().__init__()
+        if not mid_channels:
+            mid_channels = out_channels
+        self.double_conv = nn.Sequential(
+            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(mid_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.double_conv(x)
+
+
+class Down(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.maxpool_conv = nn.Sequential(
+            nn.MaxPool2d(2),
+            DoubleConv(in_channels, out_channels)
+        )
+
+    def forward(self, x):
+        return self.maxpool_conv(x)
+
+
+class Up(nn.Module):
+    def __init__(self, in_channels, out_channels, bilinear=True):
+        super().__init__()
+        if bilinear:
+            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+        else:
+            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
+            self.conv = DoubleConv(in_channels, out_channels)
+
+    def forward(self, x1, x2):
+        x1 = self.up(x1)
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+
+        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+                        diffY // 2, diffY - diffY // 2])
+        x = torch.cat([x2, x1], dim=1)
+        return self.conv(x)
+
+
+class OutConv(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(OutConv, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+    def forward(self, x):
+        return self.conv(x)
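
Note: `LineExtractor` is a standard UNet: four 2× max-pool downsampling stages with skip connections, bilinear upsampling in the decoder when `bilinear=True`, and a 1×1 output convolution. `app.py` builds it with 3 input channels for 'basic' mode (RGB) and 2 for 'detail' mode (grayscale plus inverted Sobel magnitude). A quick shape check with randomly initialized weights, as a minimal sketch (CPU only, no checkpoint required):

```python
import torch
from network.line_extractor import LineExtractor

# 'basic' mode uses a 3-channel input, 'detail' mode a 2-channel input; both emit 1 channel.
for chan_in in (3, 2):
    net = LineExtractor(chan_in, 1, bilinear=True).eval()
    x = torch.rand(1, chan_in, 256, 384)      # dummy image, spatial dims divisible by 8
    with torch.no_grad():
        y = net(x)
    print(chan_in, tuple(y.shape))            # -> (1, 1, 256, 384) in both cases
```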
requirements.txt ADDED
@@ -0,0 +1,8 @@
+tqdm
+torch
+torchvision
+opencv-python
+pillow
+numpy
+gradio
+spaces
weights/.gitkeep ADDED
File without changes