import os
import random

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import AutoencoderKL, EulerDiscreteScheduler
from diffusers.utils import load_image
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from transformers import pipeline

from kolors.models.controlnet import ControlNetModel
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import UNet2DConditionModel
from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline

device = "cuda"
ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
ckpt_dir_canny = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Canny")

# Korean-to-English translation pipeline for prompts
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

text_encoder = ChatGLMModel.from_pretrained(f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16).half().to(device)
tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
controlnet_canny = ControlNetModel.from_pretrained(f"{ckpt_dir_canny}", revision=None).half().to(device)

pipe_canny = StableDiffusionXLControlNetImg2ImgPipeline(
    vae=vae,
    controlnet=controlnet_canny,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    force_zeros_for_empty_prompt=False,
)


@spaces.GPU
def translate_korean_to_english(text):
    # Translate only if the text contains Hangul syllables (U+AC00 to U+D7A3).
    if any(0xAC00 <= ord(char) <= 0xD7A3 for char in text):
        return translator(text, max_length=512)[0]["translation_text"]
    return text


def HWC3(x):
    # Normalize an image array to 3-channel uint8 HWC.
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        # Composite RGBA onto a white background.
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


@spaces.GPU
def process_canny_condition(image, canny_thresholds=(100, 200)):
    # Extract Canny edges and replicate them into a 3-channel condition image.
    np_image = np.array(image)
    np_image = cv2.Canny(np_image, canny_thresholds[0], canny_thresholds[1])
    np_image = np_image[:, :, None]
    np_image = np.concatenate([np_image, np_image, np_image], axis=2)
    np_image = HWC3(np_image)
    return Image.fromarray(np_image)


MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


def resize_image(image, resolution):
    # Scale so the longer side equals `resolution`, preserving aspect ratio.
    w, h = image.size
    ratio = resolution / max(w, h)
    new_w = int(w * ratio)
    new_h = int(h * ratio)
    return image.resize((new_w, new_h), Image.LANCZOS)
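
# A minimal usage sketch of the conditioning step above, kept as a comment so
# it does not run inside the app; "sample.png" is a hypothetical local file:
#
#   img = load_image("sample.png")            # any RGB image
#   img = resize_image(img, MAX_IMAGE_SIZE)   # longest side -> 1024
#   edges = process_canny_condition(img)      # white-on-black, 3-channel edge map
#   edges.save("sample_canny.png")
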
def text_to_image(text, size=72, position="middle-center"):
    width, height = 1024, 576
    image = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(image)

    font_files = ["Arial_Unicode.ttf"]
    font = None
    for font_file in font_files:
        font_path = os.path.join(os.path.dirname(__file__), font_file)
        if os.path.exists(font_path):
            try:
                font = ImageFont.truetype(font_path, size=size)
                print(f"Using font: {font_file}")
                break
            except IOError:
                print(f"Error loading font: {font_file}")
    if font is None:
        # NOTE: Pillow's default bitmap font may not honor the requested size.
        print("No suitable font found. Using default font.")
        font = ImageFont.load_default()

    # Measure every line to size the overall text block.
    lines = text.split('\n')
    max_line_width = 0
    total_height = 0
    line_heights = []
    for line in lines:
        left, top, right, bottom = draw.textbbox((0, 0), line, font=font)
        line_width = right - left
        line_height = bottom - top
        line_heights.append(line_height)
        max_line_width = max(max_line_width, line_width)
        total_height += line_height

    # Anchor coordinates for a 5x5 grid of positions on the canvas.
    position_mapping = {
        "top-left": (10, 10),
        "top-left-center": (width // 4 - max_line_width // 2, 10),
        "top-center": ((width - max_line_width) / 2, 10),
        "top-right-center": (3 * width // 4 - max_line_width // 2, 10),
        "top-right": (width - max_line_width - 10, 10),
        "upper-left": (10, height // 4 - total_height // 2),
        "upper-left-center": (width // 4 - max_line_width // 2, height // 4 - total_height // 2),
        "upper-center": ((width - max_line_width) / 2, height // 4 - total_height // 2),
        "upper-right-center": (3 * width // 4 - max_line_width // 2, height // 4 - total_height // 2),
        "upper-right": (width - max_line_width - 10, height // 4 - total_height // 2),
        "middle-left": (10, (height - total_height) / 2),
        "middle-left-center": (width // 4 - max_line_width // 2, (height - total_height) / 2),
        "middle-center": ((width - max_line_width) / 2, (height - total_height) / 2),
        "middle-right-center": (3 * width // 4 - max_line_width // 2, (height - total_height) / 2),
        "middle-right": (width - max_line_width - 10, (height - total_height) / 2),
        "lower-left": (10, 3 * height // 4 - total_height // 2),
        "lower-left-center": (width // 4 - max_line_width // 2, 3 * height // 4 - total_height // 2),
        "lower-center": ((width - max_line_width) / 2, 3 * height // 4 - total_height // 2),
        "lower-right-center": (3 * width // 4 - max_line_width // 2, 3 * height // 4 - total_height // 2),
        "lower-right": (width - max_line_width - 10, 3 * height // 4 - total_height // 2),
        "bottom-left": (10, height - total_height - 10),
        "bottom-left-center": (width // 4 - max_line_width // 2, height - total_height - 10),
        "bottom-center": ((width - max_line_width) / 2, height - total_height - 10),
        "bottom-right-center": (3 * width // 4 - max_line_width // 2, height - total_height - 10),
        "bottom-right": (width - max_line_width - 10, height - total_height - 10),
    }

    x, y = position_mapping.get(position, ((width - max_line_width) / 2, (height - total_height) / 2))

    for i, line in enumerate(lines):
        draw.text((x, y), line, fill="black", font=font)
        y += line_heights[i]

    return image
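
# Sketch of the renderer in isolation (hypothetical values, comment only):
#
#   canvas = text_to_image("HELLO\nWORLD", size=96, position="bottom-center")
#   canvas.save("text_canvas.png")   # black text on a white 1024x576 canvas
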
@spaces.GPU
def infer_canny(prompt,
                text_for_image,
                text_position,
                font_size,
                negative_prompt="nsfw, facial shadows, low resolution, jpeg artifacts, blurry, bad quality, dark face, neon lights",
                seed=397886929,
                randomize_seed=False,
                guidance_scale=8.0,
                num_inference_steps=50,
                controlnet_conditioning_scale=0.8,
                control_guidance_end=0.9,
                strength=1.0):
    prompt = translate_korean_to_english(prompt)
    negative_prompt = translate_korean_to_english(negative_prompt)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    # Render the text image and derive its Canny condition.
    init_image = text_to_image(text_for_image, size=font_size, position=text_position)
    init_image = resize_image(init_image, MAX_IMAGE_SIZE)

    pipe = pipe_canny.to("cuda")
    condi_img = process_canny_condition(init_image)
    image = pipe(
        prompt=prompt,
        image=init_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_end=control_guidance_end,
        strength=strength,
        control_image=condi_img,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return image, seed  # Canny image removed from the return value


def update_button_states(selected_position):
    # Highlight the selected position button; reset all others.
    return [
        gr.update(variant="primary") if pos == selected_position else gr.update(variant="secondary")
        for pos in position_list
    ]


position_list = [
    "top-left", "top-left-center", "top-center", "top-right-center", "top-right",
    "upper-left", "upper-left-center", "upper-center", "upper-right-center", "upper-right",
    "middle-left", "middle-left-center", "middle-center", "middle-right-center", "middle-right",
    "lower-left", "lower-left-center", "lower-center", "lower-right-center", "lower-right",
    "bottom-left", "bottom-left-center", "bottom-center", "bottom-right-center", "bottom-right"
]

css = """
footer {
    visibility: hidden;
}
.text-position-grid {
    display: grid;
    grid-template-columns: repeat(5, 1fr);
    gap: 2px;
    margin-bottom: 10px;
    width: 150px;
}
.text-position-grid button {
    aspect-ratio: 1;
    padding: 0;
    border: 1px solid #ccc;
    background-color: #f0f0f0;
    cursor: pointer;
    font-size: 10px;
    transition: all 0.3s ease;
}
.text-position-grid button:hover {
    background-color: #e0e0e0;
}
.text-position-grid button.selected {
    background-color: #007bff;
    color: white;
    transform: scale(1.1);
}
"""

with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as Kolors:
    text_position = gr.State("middle-center")

    with gr.Row():
        with gr.Column(elem_id="col-left"):
            with gr.Row():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt",
                    lines=2,
                    # Default value; the Korean tail exercises the translation step
                    value="coffee in a cup bokeh --ar 85:128 --v 6.0 --style raw5, 4K, 리얼리티 사진"
                )
            with gr.Row():
                text_for_image = gr.Textbox(
                    label="Text for Image Generation",
                    placeholder="Enter text to be converted into an image",
                    lines=3,
                    value="대한 萬世 GO"  # Default sample text (Korean/Hanja/Latin mix)
                )
            with gr.Row():
                with gr.Column():
                    gr.Markdown("Text Position")
                    with gr.Row(elem_classes="text-position-grid"):
                        position_buttons = [gr.Button("•") for _ in range(25)]
                        for btn, pos in zip(position_buttons, position_list):
                            btn.click(lambda p=pos: p, outputs=text_position)
                            btn.click(update_button_states, inputs=[text_position], outputs=position_buttons)
                with gr.Column():
                    font_size = gr.Slider(
                        label="Text Size",
                        minimum=12,
                        maximum=144,
                        step=1,
                        value=72
                    )
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    placeholder="Enter a negative prompt",
                    visible=True,
                    value="nsfw, facial shadows, low resolution, jpeg artifacts, blurry, bad quality, dark face, neon lights"
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=8.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=50,
                    )
                with gr.Row():
                    controlnet_conditioning_scale = gr.Slider(
                        label="Controlnet Conditioning Scale",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=0.8,
                    )
                    control_guidance_end = gr.Slider(
                        label="Control Guidance End",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=0.9,
                    )
                with gr.Row():
                    strength = gr.Slider(
                        label="Strength",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=1.0,
                    )
            with gr.Row():
                canny_button = gr.Button("Start", elem_id="button")

        with gr.Column(elem_id="col-right"):
            result = gr.Image(label="Result", show_label=False)  # Changed from Gallery to Image
            seed_used = gr.Number(label="Seed Used")
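
    # Wire the Start button: translate the prompt, render the text canvas,
    # extract Canny edges, run the img2img ControlNet pass, and report the seed used.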
    canny_button.click(
        fn=infer_canny,
        inputs=[prompt, text_for_image, text_position, font_size, negative_prompt,
                seed, randomize_seed, guidance_scale, num_inference_steps,
                controlnet_conditioning_scale, control_guidance_end, strength],
        outputs=[result, seed_used]
    )

    # Set initial button states
    Kolors.load(update_button_states, inputs=[text_position], outputs=position_buttons)

Kolors.launch()