# Magic-Drawings / app.py
# fantos's picture
# Update app.py
# d0dbba0 verified
# raw
# history blame
# 5.88 kB
import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
import os
# CPU-only configuration
torch.set_num_threads(4)  # cap the number of CPU threads PyTorch uses
torch.set_grad_enabled(False)  # inference only: turn autograd off globally
# Normalization layer class shared by ResidualBlock and Generator below.
norm_layer = nn.InstanceNorm2d
class ResidualBlock(nn.Module):
    """Residual unit made of two reflection-padded 3x3 convolutions.

    The block keeps the channel count unchanged and adds its input back onto
    the output of the convolutional stack (skip connection).
    """

    def __init__(self, in_features):
        """Build the convolutional stack for ``in_features`` channels."""
        super().__init__()
        # The layer order is part of the checkpoint format: the two
        # convolutions must stay at indices 1 and 5 of ``conv_block``.
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional stack on ``x``."""
        residual = self.conv_block(x)
        return x + residual
class Generator(nn.Module):
    """Convolutional encoder / residual / decoder image-to-image network.

    Stages, applied in order by ``forward``:
      * ``model0`` - 7x7 stem convolution producing 64 channels
      * ``model1`` - two stride-2 convolutions (64 -> 128 -> 256 channels)
      * ``model2`` - ``n_residual_blocks`` ResidualBlock modules
      * ``model3`` - two stride-2 transposed convolutions (256 -> 128 -> 64)
      * ``model4`` - 7x7 output convolution, optionally followed by a sigmoid

    Args:
        input_nc: Number of channels of the input tensor.
        output_nc: Number of channels of the output tensor.
        n_residual_blocks: How many ResidualBlock modules to stack.
        sigmoid: When true, append ``nn.Sigmoid`` to the output stage.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super().__init__()

        # Stem: initial convolution block.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Downsampling: each step doubles the channel count.
        channels = 64
        down_layers = []
        for _ in range(2):
            doubled = channels * 2
            down_layers.extend([
                nn.Conv2d(channels, doubled, 3, stride=2, padding=1),
                norm_layer(doubled),
                nn.ReLU(inplace=True),
            ])
            channels = doubled
        self.model1 = nn.Sequential(*down_layers)

        # Residual blocks operate at the widest channel count.
        self.model2 = nn.Sequential(
            *[ResidualBlock(channels) for _ in range(n_residual_blocks)]
        )

        # Upsampling: each step halves the channel count.
        up_layers = []
        for _ in range(2):
            halved = channels // 2
            up_layers.extend([
                nn.ConvTranspose2d(
                    channels, halved, 3,
                    stride=2, padding=1, output_padding=1,
                ),
                norm_layer(halved),
                nn.ReLU(inplace=True),
            ])
            channels = halved
        self.model3 = nn.Sequential(*up_layers)

        # Output layer.
        head = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, 7),
        ]
        if sigmoid:
            head.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head)

    def forward(self, x):
        """Run ``x`` through the five stages in order and return the result."""
        stages = (self.model0, self.model1, self.model2, self.model3, self.model4)
        out = x
        for stage in stages:
            out = stage(out)
        return out
# Load the models for CPU-only inference
def _load_generator(weights_path):
    """Build one Generator(3, 1, 3), load its weights on CPU, set eval mode."""
    net = Generator(3, 1, 3)
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed PyTorch version/defaults; only load checkpoints you trust.
    net.load_state_dict(torch.load(weights_path, map_location='cpu'))
    net.eval()
    return net


def load_models():
    """Create both line-art generators and load their weights on the CPU.

    Returns:
        A ``(model1, model2)`` tuple of Generator instances in eval mode,
        loaded from 'model.pth' and 'model2.pth' respectively.

    Raises:
        gr.Error: If building a model or loading its weights fails. The
            underlying exception is printed and chained as the cause.
    """
    try:
        print("Initializing models in CPU mode...")
        model1 = _load_generator('model.pth')
        model2 = _load_generator('model2.pth')
        print("Models loaded successfully")
        return model1, model2
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        # Chain explicitly so the real cause stays attached to the error.
        raise gr.Error("Failed to initialize models. Please check model files.") from e
# Load both generators once at import time; the request handler uses the
# resulting module-level ``model1`` / ``model2``.
print("Starting model initialization...")
try:
    model1, model2 = load_models()
except Exception as e:
    print(f"Critical error: {str(e)}")
    raise gr.Error("Failed to start the application")
else:
    print("Model initialization completed")
def process_image(input_img, version, line_thickness=1.0):
    """Turn an image file into a line drawing using one of the two models.

    Args:
        input_img: Path (or file object) of the image to convert, as accepted
            by ``PIL.Image.open``. ``None`` means nothing was uploaded.
        version: ``'Simple Lines'`` selects ``model2``; any other value
            selects ``model1``.
        line_thickness: Scalar the model output is multiplied by before it is
            clamped to ``[0, 1]``.

    Returns:
        A PIL image resized back to the dimensions of the input image.

    Raises:
        gr.Error: If no image was provided or any processing step fails; the
            underlying exception is chained as the cause.
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")

    try:
        # Load and preprocess. Force 3-channel RGB so RGBA, palette and
        # grayscale uploads match the 3-channel Normalize and model input.
        # convert() returns a fully loaded image, so the file can be closed.
        with Image.open(input_img) as source:
            original_img = source.convert('RGB')
        original_size = original_img.size

        transform = transforms.Compose([
            transforms.Resize(256, Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        input_tensor = transform(original_img).unsqueeze(0)

        # Run the selected model.
        with torch.no_grad():
            if version == 'Simple Lines':
                output = model2(input_tensor)
            else:
                output = model1(input_tensor)
            output = output * line_thickness

        # Build the result image: drop only the batch dimension, clamp to the
        # valid range, then scale back to the original size.
        output_img = transforms.ToPILImage()(output.squeeze(0).clamp(0, 1))
        output_img = output_img.resize(original_size, Image.BICUBIC)
        return output_img
    except Exception as e:
        raise gr.Error(f"이미지 처리 에러: {str(e)}") from e
# Simple UI: two-column layout with the inputs on the left and the generated
# image on the right.
# NOTE(review): the nesting below was reconstructed; the original indentation
# was not visible in the reviewed source. Confirm where the button belongs.
with gr.Blocks() as iface:
    gr.Markdown("# ✨ Magic Drawings")
    gr.Markdown("Transform your photos into magical line art with AI")
    with gr.Row():
        with gr.Column():
            # type="filepath": the handler receives the upload as a file path.
            input_image = gr.Image(type="filepath", label="Upload Image")
            # The selected label is compared against 'Simple Lines' in
            # process_image to choose the model.
            version = gr.Radio(
                choices=['Complex Lines', 'Simple Lines'],
                value='Simple Lines',
                label="Art Style"
            )
            # Passed to process_image as the output multiplier.
            line_thickness = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Line Thickness"
            )
        with gr.Column():
            output_image = gr.Image(type="pil", label="Generated Art")
    generate_btn = gr.Button("Generate Magic", variant="primary")
    # Event handlers
    generate_btn.click(
        fn=process_image,
        inputs=[input_image, version, line_thickness],
        outputs=output_image
    )
# Launch the app: listen on all interfaces at port 7860, no share link.
iface.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=False
)