# Kolors ControlNet (Canny) text-rendering demo — Hugging Face Space app
import spaces | |
import random | |
import torch | |
import cv2 | |
import gradio as gr | |
import numpy as np | |
from huggingface_hub import snapshot_download | |
from transformers import pipeline | |
from diffusers.utils import load_image | |
from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline | |
from kolors.models.modeling_chatglm import ChatGLMModel | |
from kolors.models.tokenization_chatglm import ChatGLMTokenizer | |
from kolors.models.controlnet import ControlNetModel | |
from diffusers import AutoencoderKL | |
from kolors.models.unet_2d_condition import UNet2DConditionModel | |
from diffusers import EulerDiscreteScheduler | |
from PIL import Image, ImageDraw, ImageFont | |
import os | |
# ---- One-time model setup (runs at import time; requires a CUDA GPU) ----
device = "cuda"

# Download the base Kolors checkpoint and the Canny ControlNet weights.
ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
ckpt_dir_canny = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Canny")

# Korean -> English translation pipeline (prompts may be written in Korean).
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

# Load every pipeline component in fp16 and move it onto the GPU.
text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
controlnet_canny = ControlNetModel.from_pretrained(f"{ckpt_dir_canny}", revision=None).half().to(device)

# Assemble the ControlNet img2img pipeline from the components above.
pipe_canny = StableDiffusionXLControlNetImg2ImgPipeline(
    vae=vae,
    controlnet=controlnet_canny,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    force_zeros_for_empty_prompt=False
)
def translate_korean_to_english(text):
    """Return an English translation of *text* if it contains any Hangul
    syllable (U+AC00–U+D7A3); otherwise return *text* unchanged."""
    contains_hangul = any('\uac00' <= ch <= '\ud7a3' for ch in text)
    if not contains_hangul:
        return text
    return translator(text, max_length=512)[0]['translation_text']
def HWC3(x):
    """Coerce a uint8 image array into 3-channel HWC layout.

    Grayscale input (2-D, or 3-D with one channel) is replicated across
    three channels; RGBA input is alpha-composited over a white
    background; RGB input is returned as-is.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x] * 3, axis=2)
    # channels == 4: blend RGB over white according to the alpha channel.
    rgb = x[:, :, 0:3].astype(np.float32)
    alpha = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)
def process_canny_condition(image, canny_threods=(100, 200)):
    """Build the Canny-edge conditioning image for the ControlNet pipeline.

    Args:
        image: input PIL image (the rendered-text canvas).
        canny_threods: (low, high) hysteresis thresholds for cv2.Canny.
            Default changed from a list to a tuple — a mutable default
            argument is shared across calls and a known pitfall.

    Returns:
        PIL RGB image: white edges on a black background.
    """
    edges = cv2.Canny(np.array(image), canny_threods[0], canny_threods[1])
    # cv2.Canny returns a single-channel 2-D map; HWC3 already replicates
    # it to 3 channels, so the manual [:, :, None]/concatenate stacking in
    # the original was redundant and is dropped.
    return Image.fromarray(HWC3(edges))
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider and random seeds
MAX_IMAGE_SIZE = 1024  # longest-side cap (px) for images fed to the pipeline
def resize_image(image, resolution):
    """Rescale *image* so its longer side equals *resolution*, keeping the
    aspect ratio, using LANCZOS resampling."""
    width, height = image.size
    scale = resolution / max(width, height)
    target_size = (int(width * scale), int(height * scale))
    return image.resize(target_size, Image.LANCZOS)
def text_to_image(text, size=72, position="middle-center"):
    """Render *text* in black on a white 1024x576 canvas.

    Args:
        text: string to draw; '\\n' splits it into multiple lines.
        size: font size in points (ignored when falling back to the
            default font — see NOTE below).
        position: one of the 25 anchor keys in position_mapping below;
            unknown values fall back to middle-center.

    Returns:
        A PIL RGB image with the rendered text.
    """
    width, height = 1024, 576
    image = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(image)
    # Look for a bundled font next to this script (needed for CJK glyphs).
    font_files = ["Arial_Unicode.ttf"]
    font = None
    for font_file in font_files:
        font_path = os.path.join(os.path.dirname(__file__), font_file)
        if os.path.exists(font_path):
            try:
                font = ImageFont.truetype(font_path, size=size)
                print(f"Using font: {font_file}")
                break
            except IOError:
                print(f"Error loading font: {font_file}")
    if font is None:
        # NOTE(review): load_default() does not honor *size*, so fallback
        # text renders small regardless of the slider value.
        print("No suitable font found. Using default font.")
        font = ImageFont.load_default()
    # Measure every line to get the bounding box of the whole text block.
    lines = text.split('\n')
    max_line_width = 0
    total_height = 0
    line_heights = []
    for line in lines:
        left, top, right, bottom = draw.textbbox((0, 0), line, font=font)
        line_width = right - left
        line_height = bottom - top
        line_heights.append(line_height)
        max_line_width = max(max_line_width, line_width)
        total_height += line_height
    # Top-left anchor of the text block for each of the 25 grid positions
    # (5 columns x 5 rows); keys match the module-level position_list.
    position_mapping = {
        "top-left": (10, 10),
        "top-left-center": (width // 4 - max_line_width // 2, 10),
        "top-center": ((width - max_line_width) / 2, 10),
        "top-right-center": (3 * width // 4 - max_line_width // 2, 10),
        "top-right": (width - max_line_width - 10, 10),
        "upper-left": (10, height // 4 - total_height // 2),
        "upper-left-center": (width // 4 - max_line_width // 2, height // 4 - total_height // 2),
        "upper-center": ((width - max_line_width) / 2, height // 4 - total_height // 2),
        "upper-right-center": (3 * width // 4 - max_line_width // 2, height // 4 - total_height // 2),
        "upper-right": (width - max_line_width - 10, height // 4 - total_height // 2),
        "middle-left": (10, (height - total_height) / 2),
        "middle-left-center": (width // 4 - max_line_width // 2, (height - total_height) / 2),
        "middle-center": ((width - max_line_width) / 2, (height - total_height) / 2),
        "middle-right-center": (3 * width // 4 - max_line_width // 2, (height - total_height) / 2),
        "middle-right": (width - max_line_width - 10, (height - total_height) / 2),
        "lower-left": (10, 3 * height // 4 - total_height // 2),
        "lower-left-center": (width // 4 - max_line_width // 2, 3 * height // 4 - total_height // 2),
        "lower-center": ((width - max_line_width) / 2, 3 * height // 4 - total_height // 2),
        "lower-right-center": (3 * width // 4 - max_line_width // 2, 3 * height // 4 - total_height // 2),
        "lower-right": (width - max_line_width - 10, 3 * height // 4 - total_height // 2),
        "bottom-left": (10, height - total_height - 10),
        "bottom-left-center": (width // 4 - max_line_width // 2, height - total_height - 10),
        "bottom-center": ((width - max_line_width) / 2, height - total_height - 10),
        "bottom-right-center": (3 * width // 4 - max_line_width // 2, height - total_height - 10),
        "bottom-right": (width - max_line_width - 10, height - total_height - 10),
    }
    # Unknown position keys fall back to the centered placement.
    x, y = position_mapping.get(position, ((width - max_line_width) / 2, (height - total_height) / 2))
    # Draw line by line, advancing y by each line's measured height.
    for i, line in enumerate(lines):
        draw.text((x, y), line, fill="black", font=font)
        y += line_heights[i]
    return image
def infer_canny(prompt, text_for_image, text_position, font_size,
                negative_prompt = "nsfw, facial shadows, low resolution, jpeg artifacts, blurry, bad quality, dark face, neon lights",
                seed = 397886929,
                randomize_seed = False,
                guidance_scale = 8.0,
                num_inference_steps = 50,
                controlnet_conditioning_scale = 0.8,
                control_guidance_end = 0.9,
                strength = 1.0
                ):
    """Render *text_for_image* onto a canvas, then stylize it with the
    Kolors Canny-ControlNet img2img pipeline guided by *prompt*.

    Returns:
        (generated PIL image, int seed actually used).
    """
    # Prompts may be written in Korean; the pipeline expects English.
    prompt = translate_korean_to_english(prompt)
    negative_prompt = translate_korean_to_english(negative_prompt)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    # Build the init image containing the rendered text, capped at MAX_IMAGE_SIZE.
    init_image = text_to_image(text_for_image, size=font_size, position=text_position)
    init_image = resize_image(init_image, MAX_IMAGE_SIZE)
    pipe = pipe_canny.to("cuda")
    # The Canny edge map of the rendered text is the ControlNet condition.
    condi_img = process_canny_condition(init_image)
    image = pipe(
        prompt=prompt,
        image=init_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_end=control_guidance_end,
        strength=strength,
        control_image=condi_img,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return image, seed  # Canny preview image removed from the return value
def update_button_states(selected_position):
    """Return one gr.update per position button: 'primary' for the button
    matching *selected_position*, 'secondary' for all others."""
    updates = []
    for pos in position_list:
        variant = "primary" if pos == selected_position else "secondary"
        updates.append(gr.update(variant=variant))
    return updates
# The 25 anchor names in row-major order (top row first) — matches both the
# 5x5 CSS button grid and the position_mapping keys in text_to_image.
position_list = [
    "top-left", "top-left-center", "top-center", "top-right-center", "top-right",
    "upper-left", "upper-left-center", "upper-center", "upper-right-center", "upper-right",
    "middle-left", "middle-left-center", "middle-center", "middle-right-center", "middle-right",
    "lower-left", "lower-left-center", "lower-center", "lower-right-center", "lower-right",
    "bottom-left", "bottom-left-center", "bottom-center", "bottom-right-center", "bottom-right"
]
# Custom CSS: hides the Gradio footer and lays the 25 position buttons out
# as a 5x5 grid with hover/selected styling.
css = """
footer {
    visibility: hidden;
}
.text-position-grid {
    display: grid;
    grid-template-columns: repeat(5, 1fr);
    gap: 2px;
    margin-bottom: 10px;
    width: 150px;
}
.text-position-grid button {
    aspect-ratio: 1;
    padding: 0;
    border: 1px solid #ccc;
    background-color: #f0f0f0;
    cursor: pointer;
    font-size: 10px;
    transition: all 0.3s ease;
}
.text-position-grid button:hover {
    background-color: #e0e0e0;
}
.text-position-grid button.selected {
    background-color: #007bff;
    color: white;
    transform: scale(1.1);
}
"""
# ---- Gradio UI ----
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as Kolors:
    # Currently selected text anchor; default matches text_to_image's fallback.
    text_position = gr.State("middle-center")
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            with gr.Row():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt",
                    lines=2,
                    value="coffee in a cup bokeh --ar 85:128 --v 6.0 --style raw5, 4K, 리얼리티 사진"  # Default value added here
                )
            with gr.Row():
                text_for_image = gr.Textbox(
                    label="Text for Image Generation",
                    placeholder="Enter text to be converted into an image",
                    lines=3,
                    value="대한 萬世 GO"  # Default value added here
                )
            with gr.Row():
                with gr.Column():
                    gr.Markdown("Text Position")
                    with gr.Row(elem_classes="text-position-grid"):
                        # 25 dot buttons, arranged 5x5 by the CSS grid.
                        position_buttons = [gr.Button("•") for _ in range(25)]
                        for btn, pos in zip(position_buttons, position_list):
                            # pos is bound as a default argument so each lambda
                            # captures its own position (late-binding fix).
                            btn.click(lambda p=pos: p, outputs=text_position)
                            # Second listener restyles all buttons from the state.
                            btn.click(update_button_states, inputs=[text_position], outputs=position_buttons)
                with gr.Column():
                    font_size = gr.Slider(
                        label="Text Size",
                        minimum=12,
                        maximum=144,
                        step=1,
                        value=72
                    )
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    placeholder="Enter a negative prompt",
                    visible=True,
                    value="nsfw, facial shadows, low resolution, jpeg artifacts, blurry, bad quality, dark face, neon lights"
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=8.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=50,
                    )
                with gr.Row():
                    controlnet_conditioning_scale = gr.Slider(
                        label="Controlnet Conditioning Scale",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=0.8,
                    )
                    control_guidance_end = gr.Slider(
                        label="Control Guidance End",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=0.9,
                    )
                with gr.Row():
                    strength = gr.Slider(
                        label="Strength",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=1.0,
                    )
            with gr.Row():
                canny_button = gr.Button("Start", elem_id="button")
        with gr.Column(elem_id="col-right"):
            result = gr.Image(label="Result", show_label=False)  # changed from Gallery to Image
            seed_used = gr.Number(label="Seed Used")
    # Input order must match infer_canny's positional parameters.
    canny_button.click(
        fn = infer_canny,
        inputs = [prompt, text_for_image, text_position, font_size, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength],
        outputs = [result, seed_used]
    )
    # Set initial button states
    Kolors.load(update_button_states, inputs=[text_position], outputs=position_buttons)

Kolors.launch()