# Scraped page header (not part of the program): "ming5468's picture" / commit "Update app.py" / 8387ae4 (verified)
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageFilter, ImageEnhance
import os
from typing import List, Tuple, Optional
import json
import spaces
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
import random
import time
# LoRA categories and their selectable options.
#
# Layout: category id -> {"name", "icon", "options"}; each option id maps to a
# dict with at least "name", "prompt" and "image".  An option that should load
# real LoRA weights additionally carries "repo", "weights", "trigger_word" and
# "trigger_position".
LORA_CATEGORIES = {
    "custom": {
        "name": "커스텀",
        "icon": "🎭",
        "options": {
            "mgtoco": {
                "name": "MGTOCO",
                "prompt": "MGTOCO style, high quality, detailed",
                "image": "https://huggingface.co/ming5468/mgtonecoollora/resolve/main/preview.jpg",
                "repo": "ming5468/mgtonecoollora",
                "weights": "lora.safetensors",
                "trigger_word": "MGTOCO",
                "trigger_position": "prepend",
            },
            # Disabled example entry, kept for reference:
            # "polaroid": {
            #     "name": "Polaroid Style",
            #     "prompt": "polaroid style, vintage photo",
            #     "image": "https://huggingface.co/alvdansen/pola-photo-flux/resolve/main/images/out-2%20(83).webp",
            #     "repo": "alvdansen/pola-photo-flux",
            #     "trigger_word": ", polaroid style",
            #     "trigger_position": "append"
            # },
        },
    },
    "brightness": {
        "name": "밝기",
        "icon": "💡",
        "options": {
            "bright": {"name": "밝은", "prompt": "bright, high key lighting", "image": "https://via.placeholder.com/300x300/FFD700/000000?text=Bright"},
            "dark": {"name": "어두운", "prompt": "dark, low key lighting, moody", "image": "https://via.placeholder.com/300x300/2C3E50/FFFFFF?text=Dark"},
            "mbrightness": {"name": "중간밝기", "prompt": "balanced lighting, natural brightness", "image": "https://via.placeholder.com/300x300/95A5A6/FFFFFF?text=Medium"},
        },
    },
    "saturation": {
        "name": "채도",
        "icon": "🎨",
        "options": {
            "vivid": {"name": "고채도", "prompt": "vivid colors, high saturation, vibrant", "image": "https://via.placeholder.com/300x300/E74C3C/FFFFFF?text=Vivid"},
            "muted": {"name": "저채도", "prompt": "muted colors, desaturated, pastel tones", "image": "https://via.placeholder.com/300x300/BDC3C7/000000?text=Muted"},
            "nsaturation": {"name": "중채도", "prompt": "natural color saturation, balanced tones", "image": "https://via.placeholder.com/300x300/3498DB/FFFFFF?text=Natural"},
        },
    },
    "position": {
        "name": "구도",
        "icon": "📐",
        "options": {
            "center": {"name": "중앙집중", "prompt": "centered composition, rule of thirds", "image": "https://via.placeholder.com/300x300/9B59B6/FFFFFF?text=Center"},
            "sym": {"name": "대칭구도", "prompt": "symmetrical composition, balanced frame", "image": "https://via.placeholder.com/300x300/8E44AD/FFFFFF?text=Symmetric"},
            "asym": {"name": "비대칭구도", "prompt": "asymmetrical composition, dynamic balance", "image": "https://via.placeholder.com/300x300/A569BD/FFFFFF?text=Asymmetric"},
            "dynamic": {"name": "역동적구도", "prompt": "dynamic composition, diagonal lines, movement", "image": "https://via.placeholder.com/300x300/7D3C98/FFFFFF?text=Dynamic"},
        },
    },
    "angle": {
        "name": "앵글",
        "icon": "📷",
        "options": {
            "front": {"name": "정면", "prompt": "frontal view, straight on angle", "image": "https://via.placeholder.com/300x300/1ABC9C/FFFFFF?text=Front"},
            "high": {"name": "하이앵글", "prompt": "high angle shot, bird's eye view", "image": "https://via.placeholder.com/300x300/16A085/FFFFFF?text=High+Angle"},
            "low": {"name": "로우앵글", "prompt": "low angle shot, worm's eye view", "image": "https://via.placeholder.com/300x300/48C9B0/000000?text=Low+Angle"},
            "side": {"name": "측면", "prompt": "side view, profile angle", "image": "https://via.placeholder.com/300x300/52D6C4/000000?text=Side+View"},
        },
    },
    "effect": {
        "name": "효과",
        "icon": "✨",
        "options": {
            "clear": {"name": "선명한", "prompt": "sharp, crystal clear, high detail", "image": "https://via.placeholder.com/300x300/F39C12/FFFFFF?text=Clear"},
            "blur": {"name": "블러", "prompt": "soft blur, dreamy effect, bokeh", "image": "https://via.placeholder.com/300x300/E67E22/FFFFFF?text=Blur"},
            "shiny": {"name": "빛반사", "prompt": "glossy, reflective surface, rim lighting", "image": "https://via.placeholder.com/300x300/F7DC6F/000000?text=Shiny"},
            "strong": {"name": "고대비", "prompt": "high contrast, dramatic lighting", "image": "https://via.placeholder.com/300x300/D35400/FFFFFF?text=High+Contrast"},
            "soft": {"name": "저대비", "prompt": "low contrast, soft lighting, gentle tones", "image": "https://via.placeholder.com/300x300/F8C471/000000?text=Soft"},
        },
    },
    "mood": {
        "name": "무드",
        "icon": "💫",
        "options": {
            "bmood": {"name": "밝은무드", "prompt": "cheerful mood, uplifting atmosphere", "image": "https://via.placeholder.com/300x300/F1C40F/000000?text=Bright+Mood"},
            "dmood": {"name": "어두운무드", "prompt": "dark mood, mysterious atmosphere", "image": "https://via.placeholder.com/300x300/34495E/FFFFFF?text=Dark+Mood"},
            "dreamy": {"name": "몽환적", "prompt": "dreamy, ethereal, surreal atmosphere", "image": "https://via.placeholder.com/300x300/BB8FCE/000000?text=Dreamy"},
            "natural": {"name": "자연스러운", "prompt": "natural mood, organic feel, authentic", "image": "https://via.placeholder.com/300x300/58D68D/000000?text=Natural"},
        },
    },
}
# Model initialization.
# Everything below runs at import time and loads model weights, so the
# statement order matters: `pipe` must exist before `pipe_i2i` borrows its
# components.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"
# Load the VAEs and the base pipeline.
# `taef1` (AutoencoderTiny) is the VAE handed to `pipe`; `good_vae`
# (AutoencoderKL from the base model's "vae" subfolder) is the one handed to
# `pipe_i2i` below.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
# Add the image-to-image pipeline.
# It is given `pipe`'s transformer, text encoders and tokenizers, plus
# `good_vae`, as its components.
from diffusers import AutoPipelineForImage2Image
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
base_model,
vae=good_vae,
transformer=pipe.transformer,
text_encoder=pipe.text_encoder,
tokenizer=pipe.tokenizer,
text_encoder_2=pipe.text_encoder_2,
tokenizer_2=pipe.tokenizer_2,
torch_dtype=dtype
)
# Largest seed value drawn when the caller does not supply one.
MAX_SEED = 2**32-1
# Bind the imported helper function to `pipe` so it can be invoked as a method
# of that pipeline instance.
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
class LoRAStudioApp:
    """Session state and image operations behind the LoRA Studio UI.

    The instance remembers which LoRA options are selected, the prompt built
    from them, the uploaded source image and the last generated image.  One
    instance is shared by every UI callback.

    NOTE(review): `generate_image` stores its result on `self` while running
    under `@spaces.GPU`; confirm that state written there is visible to later
    callbacks (e.g. `apply_post_processing`) in the deployment environment.
    """

    def __init__(self):
        self.selected_loras = []
        self.current_prompt = ""
        self.current_category = "custom"  # the custom category is the default
        self.uploaded_image = None
        self.generated_image = None
        # Initialise every loaded_lora_* field up front so no caller needs
        # hasattr() checks.
        self._reset_loaded_lora()

    def _reset_loaded_lora(self):
        """Forget which LoRA repo (and trigger word) the next generation uses."""
        self.loaded_lora_repo = None
        self.loaded_lora_weights = None
        self.loaded_lora_trigger = ""
        self.loaded_lora_trigger_pos = "append"

    @staticmethod
    def _find_option(lora_id):
        """Return the option dict for a ``"<category>-<option>"`` id, or None."""
        category_id, _, option_id = lora_id.partition("-")
        category = LORA_CATEGORIES.get(category_id)
        if category is None:
            return None
        return category["options"].get(option_id)

    def get_models_by_category(self, category_id):
        """Return the models of one category as a list of dicts.

        Each dict has the keys ``id`` (``"<category>-<option>"``), ``name``,
        ``prompt``, ``image`` and ``category`` (the category display name).
        An unknown ``category_id`` yields an empty list.
        """
        category = LORA_CATEGORIES.get(category_id)
        if category is None:
            return []
        return [
            {
                "id": f"{category_id}-{option_id}",
                "name": option["name"],
                "prompt": option["prompt"],
                "image": option["image"],
                "category": category["name"],
            }
            for option_id, option in category["options"].items()
        ]

    def get_gallery_data(self, category_id):
        """Return ``(image, caption)`` pairs for the Gallery component.

        Captions of currently selected models are prefixed with a check mark.
        """
        return [
            (model["image"], f"{'✅ ' if model['id'] in self.selected_loras else ''}{model['name']}")
            for model in self.get_models_by_category(category_id)
        ]

    def select_lora_model(self, model_id):
        """Toggle the selection of ``model_id`` and return the prompt display text."""
        if model_id in self.selected_loras:
            self.selected_loras.remove(model_id)
        else:
            self.selected_loras.append(model_id)
        self.update_prompt()
        return self.get_prompt_display()

    def update_prompt(self):
        """Rebuild ``current_prompt`` and the LoRA-to-load from the selection.

        The LoRA fields are reset first, so deselecting a LoRA-backed option
        really stops it from being loaded and stops its trigger word from
        being added.  If several selected options carry a ``repo``, the last
        one in selection order wins.
        """
        self._reset_loaded_lora()
        prompts = []
        for lora_id in self.selected_loras:
            option = self._find_option(lora_id)
            if option is None:
                continue
            prompts.append(option["prompt"])
            # Remember the repo details for the actual weight loading.
            if "repo" in option:
                self.loaded_lora_repo = option["repo"]
                self.loaded_lora_weights = option.get("weights", None)
                self.loaded_lora_trigger = option.get("trigger_word", "")
                self.loaded_lora_trigger_pos = option.get("trigger_position", "append")
        self.current_prompt = ", ".join(prompts)

    def get_prompt_display(self):
        """Return the text shown in the prompt preview box."""
        if not self.current_prompt:
            return "LoRA를 선택하면 자동으로 프롬프트가 생성됩니다"
        return self.current_prompt

    def upload_image(self, image):
        """Store the uploaded image and return ``(preview, status_message)``.

        Clearing the upload (``image is None``) also clears the stored image,
        so a removed picture can no longer be used for generation.
        """
        self.uploaded_image = image
        if image is not None:
            return image, "이미지가 업로드되었습니다. Generate 버튼을 클릭하세요."
        return None, "이미지를 업로드해주세요."

    def _build_final_prompt(self, custom_prompt):
        """Combine the LoRA prompt, the user's extra prompt and the trigger word."""
        extra = (custom_prompt or "").strip()
        final_prompt = self.current_prompt
        if extra:
            final_prompt = f"{self.current_prompt}, {extra}" if self.current_prompt else extra
        if self.loaded_lora_trigger:
            if self.loaded_lora_trigger_pos == "prepend":
                final_prompt = f"{self.loaded_lora_trigger} {final_prompt}"
            else:
                final_prompt = f"{final_prompt} {self.loaded_lora_trigger}"
            # Avoid a dangling space when there is no other prompt text.
            final_prompt = final_prompt.strip()
        return final_prompt

    def _sync_lora_weights(self):
        """Make the image-to-image pipeline's LoRA match the current selection.

        Loading is best effort: on failure the error is logged and generation
        continues with whatever the pipeline currently holds.
        """
        if not self.loaded_lora_repo:
            print("로드할 LoRA repo가 없습니다")
            # Drop weights left over from an earlier generation.
            try:
                pipe_i2i.unload_lora_weights()
            except Exception as e:
                print(f"LoRA 언로드 실패: {e}")
            return
        weight_name = self.loaded_lora_weights
        try:
            pipe_i2i.unload_lora_weights()
            # A fixed adapter_name is used for compatibility with
            # Replicate-trained LoRAs.
            pipe_i2i.load_lora_weights(
                self.loaded_lora_repo,
                weight_name=weight_name,
                low_cpu_mem_usage=True,
                adapter_name="flux_lora",
            )
            # Set the LoRA scale.
            pipe_i2i.set_adapters("flux_lora", adapter_weights=[0.95])
            print(f"LoRA 로드 성공: {self.loaded_lora_repo} (weights: {weight_name})")
        except Exception as e:
            # Deliberately non-fatal: fall back to the base model.
            print(f"LoRA 로드 실패: {e}")

    @spaces.GPU(duration=70)
    def generate_image(self, custom_prompt="", steps=28, cfg_scale=3.5, seed=None, image_strength=0.75):
        """Generate an image from the uploaded one with the FLUX img2img pipeline.

        Args:
            custom_prompt: Extra prompt text appended to the LoRA prompt.
            steps: Number of inference steps; falsy or non-positive values
                fall back to 28.
            cfg_scale: Guidance scale passed to the pipeline.
            seed: RNG seed; a random one in ``[0, MAX_SEED]`` is drawn when None.
            image_strength: img2img strength passed to the pipeline.

        Returns:
            ``(image, message)``; ``image`` is None when no image was uploaded
            or generation failed, and ``message`` explains the outcome.
        """
        if self.uploaded_image is None:
            return None, "먼저 이미지를 업로드해주세요."

        # Explicitly coerce and validate the step count.
        steps = int(steps) if steps else 28
        if steps <= 0:
            steps = 28
        print(f"Steps 값: {steps} (타입: {type(steps)})")

        # Keep the uploaded image's size.  The dimensions are passed to the
        # pipeline explicitly (otherwise it uses its own default resolution)
        # and rounded down to a multiple of 16.
        # NOTE(review): the multiple-of-16 requirement and the default
        # resolution come from the diffusers FLUX img2img pipeline; confirm
        # against the installed version.  Very large uploads may exceed GPU
        # memory or the time budget.
        src_width, src_height = self.uploaded_image.size
        width = max(16, src_width // 16 * 16)
        height = max(16, src_height // 16 * 16)
        print(f"이미지 크기: {width}x{height} (원본 크기 유지)")

        final_prompt = self._build_final_prompt(custom_prompt)

        seed = random.randint(0, MAX_SEED) if seed is None else int(seed)

        self._sync_lora_weights()

        # Use the module-level device instead of a hard-coded "cuda" so the
        # generator matches wherever the models were placed.
        generator = torch.Generator(device=device).manual_seed(seed)

        try:
            self.generated_image = pipe_i2i(
                prompt=final_prompt,
                image=self.uploaded_image,
                width=width,
                height=height,
                strength=float(image_strength),
                num_inference_steps=steps,
                guidance_scale=float(cfg_scale),
                generator=generator,
                output_type="pil",
            ).images[0]
        except Exception as e:
            # Top-level UI boundary: report the failure instead of crashing.
            print(f"이미지 생성 실패: {e}")
            import traceback
            traceback.print_exc()
            return None, f"이미지 생성 실패: {str(e)}"
        return self.generated_image, f"생성 완료!\n사용된 프롬프트: {final_prompt}\nSeed: {seed}\nStrength: {image_strength}\n크기: {width}x{height}\nSteps: {steps}"

    def apply_post_processing(self, skin_detail, eye_detail, blur_strength, glitch_strength, distortion_strength):
        """Return a post-processed copy of the generated image, or None.

        Only ``blur_strength`` has an effect: it applies a Gaussian blur with
        radius ``blur_strength / 20``.  The other parameters are accepted but
        currently unused.
        """
        if self.generated_image is None:
            return None
        result = self.generated_image.copy()
        if blur_strength > 0:
            result = result.filter(ImageFilter.GaussianBlur(radius=blur_strength / 20))
        return result
# Application instance: the single LoRAStudioApp object referenced as `app`
# by the UI callbacks.
app = LoRAStudioApp()
def create_main_interface():
    """Build the LoRA Studio Gradio UI and wire up its events.

    The UI has a header, a sidebar, and four mutually exclusive pages
    (Models, Upload, Generate, Post-process) that are switched by toggling
    their visibility.  All state lives in the module-level ``app`` object.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(
        theme=gr.themes.Soft(
            primary_hue="slate",
            secondary_hue="gray",
            neutral_hue="stone",
        ),
        css_paths=["style.css"],
        title="LoRA Studio",
    ) as demo:
        # Top header
        with gr.Row(elem_classes="main-header"):
            with gr.Column(scale=1):
                gr.Markdown("# LoRA Studio")
            with gr.Column(scale=3):
                pass
            with gr.Column(scale=1):
                upload_nav_btn = gr.Button("📤 Upload", variant="primary", elem_classes="upload-btn")

        with gr.Row(elem_classes="main-content"):
            # Left sidebar (1/5)
            with gr.Column(scale=1, elem_classes="sidebar"):
                models_btn = gr.Button("📚 Models", elem_classes="nav-button active")
                # Gallery / Help / Feedback have no page yet, so they stay unwired.
                gallery_btn = gr.Button("🖼️ Gallery", elem_classes="nav-button")
                generate_btn = gr.Button("⚡ Generate", elem_classes="nav-button")
                help_btn = gr.Button("❓ Help", elem_classes="nav-button")
                feedback_btn = gr.Button("💭 Feedback", elem_classes="nav-button")

            # Main content area (4/5)
            with gr.Column(scale=4, elem_classes="content-area"):
                # Models page
                with gr.Group(visible=True) as models_page:
                    # Category buttons
                    with gr.Row(elem_classes="category-filters"):
                        category_buttons = {}
                        for category_id, category in LORA_CATEGORIES.items():
                            # Mark the app's real default category as selected
                            # (there is no category with the id "tone").
                            btn = gr.Button(
                                f"{category['icon']} {category['name']}",
                                elem_classes=f"category-btn {'selected' if category_id == app.current_category else ''}",
                                elem_id=f"cat-{category_id}",
                            )
                            category_buttons[category_id] = btn

                    # Prompt preview
                    prompt_display = gr.Textbox(
                        label="Generated Prompt Preview",
                        value="LoRA를 선택하면 자동으로 프롬프트가 생성됩니다",
                        lines=3,
                        interactive=False,
                        elem_classes="prompt-display",
                    )

                    # Model gallery: start on the category that gallery clicks
                    # are resolved against, so the first click hits a real model.
                    model_gallery = gr.Gallery(
                        value=app.get_gallery_data(app.current_category),
                        label="LoRA Models",
                        show_label=False,
                        elem_id="model-gallery",
                        columns=4,
                        rows=2,
                        height="auto",
                        allow_preview=False,
                        show_share_button=False,
                    )

                # Upload page
                with gr.Group(visible=False) as upload_page:
                    gr.Markdown("## 📤 이미지 업로드")
                    upload_image_comp = gr.Image(
                        label="이미지 업로드",
                        type="pil",
                        elem_classes="upload-area",
                    )
                    upload_info = gr.Textbox(
                        label="업로드 상태",
                        value="이미지를 업로드해주세요",
                        interactive=False,
                    )
                    with gr.Row(elem_classes="button-row"):
                        back_to_models_btn = gr.Button("← Models로 돌아가기", variant="secondary")
                        proceed_to_generate_btn = gr.Button("Generate로 이동 →", variant="primary")

                # Generate page
                with gr.Group(visible=False) as generate_page:
                    gr.Markdown("## ⚡ 이미지 생성")
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("### 원본 이미지")
                            original_preview = gr.Image(label="Original", interactive=False)
                        with gr.Column():
                            gr.Markdown("### 설정")
                            selected_loras_display = gr.Textbox(
                                label="선택된 LoRA",
                                value="선택된 LoRA가 없습니다",
                                lines=3,
                                interactive=False,
                            )
                            custom_prompt = gr.Textbox(
                                label="추가 프롬프트 (선택사항)",
                                placeholder="추가 프롬프트를 입력하세요",
                                lines=2,
                            )
                            steps_slider = gr.Slider(
                                label="생성 스텝 수",
                                minimum=10,
                                maximum=50,
                                step=1,
                                value=28,
                            )
                            image_strength = gr.Slider(
                                label="이미지 변형 강도 (낮을수록 원본 유지)",
                                minimum=0.1,
                                maximum=1.0,
                                step=0.05,
                                value=0.75,
                            )
                            generate_image_btn = gr.Button("🪄 Generate Image", variant="primary", size="lg")
                    generated_result = gr.Image(label="Generated Result", interactive=False)
                    generation_info = gr.Textbox(label="Generation Info", interactive=False, lines=2)
                    with gr.Row(elem_classes="button-row"):
                        back_to_upload_btn = gr.Button("← Upload로 돌아가기", variant="secondary")
                        proceed_to_post_btn = gr.Button("후보정으로 이동 →", variant="primary")

                # Post-processing page
                with gr.Group(visible=False) as postprocess_page:
                    gr.Markdown("## ✨ 후보정 패널")
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("### 결과 미리보기")
                            final_result = gr.Image(label="Final Result", interactive=False)
                            with gr.Row(elem_classes="button-row"):
                                # Not wired to any handler yet.
                                save_btn = gr.Button("💾 Save", variant="primary")
                                download_btn = gr.Button("📥 Download", variant="secondary")
                        with gr.Column(elem_classes="postprocess-panel"):
                            gr.Markdown("### 피부결 디테일 보정")
                            skin_detail = gr.Slider(0, 100, 50, label="모공 & 솜털", step=1)
                            eye_detail = gr.Slider(0, 100, 50, label="눈동자 디테일", step=1)
                            gr.Markdown("### 무드 효과 강도")
                            blur_strength = gr.Slider(0, 100, 0, label="Blur (%)", step=1)
                            glitch_strength = gr.Slider(0, 100, 0, label="Glitch (%)", step=1)
                            distortion_strength = gr.Slider(0, 100, 0, label="Distortion (%)", step=1)
                            gr.Markdown("### 컷별 적용")
                            with gr.Row():
                                # Not wired to any handler yet.
                                apply_all_btn = gr.Button("전체 적용", variant="primary")
                                apply_selection_btn = gr.Button("선택 영역만", variant="secondary")
                    with gr.Row(elem_classes="button-row"):
                        back_to_generate_btn = gr.Button("← Generate로 돌아가기", variant="secondary")
                        restart_btn = gr.Button("🔄 처음부터 다시", variant="stop")

        # Page switching: exactly one of the four pages is visible at a time.
        pages = [models_page, upload_page, generate_page, postprocess_page]

        def make_page_switcher(active_page):
            """Return a handler that shows ``active_page`` and hides the others."""
            def switch():
                return tuple(gr.update(visible=(page is active_page)) for page in pages)
            return switch

        show_models = make_page_switcher(models_page)
        show_upload = make_page_switcher(upload_page)
        show_generate = make_page_switcher(generate_page)
        show_postprocess = make_page_switcher(postprocess_page)

        def filter_by_category(category_id):
            """Switch the active category and return its gallery items."""
            app.current_category = category_id
            return app.get_gallery_data(category_id)

        def describe_selection():
            """Return the status text listing every currently selected LoRA."""
            names = []
            for lora_id in app.selected_loras:
                cat_id, _, option_id = lora_id.partition("-")
                option = LORA_CATEGORIES.get(cat_id, {}).get("options", {}).get(option_id)
                if option is not None:
                    names.append(option["name"])
            if not names:
                return "선택된 LoRA가 없습니다"
            return f"✅ 선택됨: {', '.join(names)}"

        def select_model_by_click(evt: gr.SelectData):
            """Toggle the clicked gallery model and refresh the dependent widgets."""
            models = app.get_models_by_category(app.current_category)
            if evt.index < len(models):
                app.select_lora_model(models[evt.index]["id"])
                # Report the real selection; a click may also deselect a model.
                return (
                    app.get_gallery_data(app.current_category),
                    app.get_prompt_display(),
                    describe_selection(),
                )
            return gr.update(), gr.update(), describe_selection()

        def run_generation(prompt_text, steps, strength):
            """Forward the UI values to ``app.generate_image`` by keyword.

            Passing them positionally would put the strength slider into the
            ``cfg_scale`` parameter and leave ``image_strength`` at its default.
            """
            return app.generate_image(
                custom_prompt=prompt_text,
                steps=steps,
                image_strength=strength,
            )

        # Event binding: sidebar / header navigation
        models_btn.click(show_models, outputs=pages)
        upload_nav_btn.click(show_upload, outputs=pages)
        generate_btn.click(show_generate, outputs=pages)

        # Category buttons; the default argument pins each button's own id.
        for category_id, btn in category_buttons.items():
            btn.click(
                lambda cat=category_id: filter_by_category(cat),
                outputs=[model_gallery],
            )

        # Gallery click toggles a model.
        model_gallery.select(
            select_model_by_click,
            outputs=[model_gallery, prompt_display, selected_loras_display],
        )

        # Post-processing inputs.  The slider VALUES must arrive through
        # `inputs`; capturing the components in a lambda would hand the
        # handler the Slider objects themselves.
        post_sliders = [skin_detail, eye_detail, blur_strength, glitch_strength, distortion_strength]

        # In-page navigation
        back_to_models_btn.click(show_models, outputs=pages)
        proceed_to_generate_btn.click(show_generate, outputs=pages)
        back_to_upload_btn.click(show_upload, outputs=pages)
        # Populate the preview as soon as the post-processing page opens.
        proceed_to_post_btn.click(show_postprocess, outputs=pages).then(
            app.apply_post_processing,
            inputs=post_sliders,
            outputs=[final_result],
        )
        back_to_generate_btn.click(show_generate, outputs=pages)
        restart_btn.click(show_models, outputs=pages)

        # Image handling
        upload_image_comp.change(app.upload_image, inputs=[upload_image_comp], outputs=[original_preview, upload_info])
        generate_image_btn.click(
            run_generation,
            inputs=[custom_prompt, steps_slider, image_strength],
            outputs=[generated_result, generation_info],
        )

        # Re-run post-processing whenever any slider moves.
        for slider in post_sliders:
            slider.change(
                app.apply_post_processing,
                inputs=post_sliders,
                outputs=[final_result],
            )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on all interfaces at
    # port 7860 with a share link requested and errors shown in the UI.
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo = create_main_interface()
    demo.launch(**launch_options)