# Last commit dfc82b7 ("mini ui") by Bobby; file size 16.8 kB.
# Deployment switches. `prod` selects the launch mode at the bottom of the
# file (and the port used there); `show_options` controls whether the
# "Advanced options" accordion is shown/opened in the UI.
prod = False
show_options = False
port = 8081 if prod else 8080
import os
import gc
import random
import time
import gradio as gr
import numpy as np
import imageio
import torch
from PIL import Image
from typing import Tuple
from diffusers import (
ControlNetModel,
DPMSolverMultistepScheduler,
StableDiffusionControlNetPipeline,
)
from preprocess import Preprocessor
# Upper bound (inclusive) for randomly drawn seeds: the largest signed 32-bit int.
MAX_SEED = np.iinfo(np.int32).max
# Optional API token taken from the environment; None when the variable is unset.
API_KEY = os.getenv("API_KEY")
print("CUDA version:", torch.version.cuda)
print("loading pipe")
# Flipped to True by process_image on its first invocation.
compiled = False
import spaces
# One-time model loading. Per Gradio's reload-mode documentation, code under
# `gr.NO_RELOAD` is skipped when the app hot-reloads, so these heavy loads
# run only once.
if gr.NO_RELOAD:
    # NOTE(review): the return value is discarded, so this call has no visible effect.
    torch.cuda.max_memory_allocated(device="cuda")
    # Preprocessor that turns the input photo into the ControlNet conditioning
    # image (process_image re-loads "NormalBae" before every call as well).
    preprocessor = Preprocessor()
    preprocessor.load("NormalBae")
    # Controlnet Normal
    model_id = "lllyasviel/control_v11p_sd15_normalbae"
    print("initializing controlnet")
    controlnet = ControlNetModel.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        # NOTE(review): `attn_implementation` is a transformers-style option;
        # confirm diffusers' ControlNetModel honours it rather than ignoring it.
        attn_implementation="flash_attention_2",
    ).to("cuda")
    # Scheduler
    # Config is taken from the SD 1.5 repo's "scheduler" subfolder and
    # overridden with the keyword arguments below.
    scheduler = DPMSolverMultistepScheduler.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        solver_order=2,
        subfolder="scheduler",
        use_karras_sigmas=True,
        final_sigmas_type="sigma_min",
        algorithm_type="sde-dpmsolver++",
        prediction_type="epsilon",
        thresholding=False,
        # NOTE(review): `denoise_final` and `device_map` do not look like
        # DPMSolverMultistepScheduler config options -- confirm they are not
        # silently ignored.
        denoise_final=True,
        device_map="cuda",
    )
    # Stable Diffusion Pipeline URL
    # base_model_url = "https://huggingface.co/broyang/hentaidigitalart_v20/blob/main/realcartoon3d_v15.safetensors"
    base_model_url = "https://huggingface.co/Lykon/AbsoluteReality/blob/main/AbsoluteReality_1.8.1_pruned.safetensors"
    # Build the ControlNet pipeline from a single checkpoint file, reusing the
    # controlnet and scheduler created above, in fp16.
    pipe = StableDiffusionControlNetPipeline.from_single_file(
        base_model_url,
        # safety_checker=None,
        # load_safety_checker=True,
        controlnet=controlnet,
        scheduler=scheduler,
        torch_dtype=torch.float16,
    )
    # Textual-inversion embeddings. Each is registered under `token`; a prompt
    # must contain that exact token string for the embedding to take effect.
    # NOTE(review): the HDA_* tokens do not appear in any prompt this file
    # actually sends to the pipeline -- confirm they are still needed, since
    # every load adds startup time.
    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="EasyNegativeV2.safetensors", token="EasyNegativeV2",)
    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="badhandv4.pt", token="badhandv4")
    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="fcNeg-neg.pt", token="fcNeg-neg")
    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Ahegao.pt", token="HDA_Ahegao")
    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Bondage.pt", token="HDA_Bondage")
    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_pet_play.pt", token="HDA_pet_play")
    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_unconventional maid.pt", token="HDA_unconventional_maid")
    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NakedHoodie.pt", token="HDA_NakedHoodie")
    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NunDress.pt", token="HDA_NunDress")
    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Shibari.pt", token="HDA_Shibari")
    # Move the assembled pipeline to the GPU.
    pipe.to("cuda")
    # experimental speedup?
    # pipe.compile()
    # torch.cuda.empty_cache()
    # gc.collect()
    print("---------------Loaded controlnet pipeline---------------")
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed to use for a generation.

    Returns a fresh random integer in ``[0, MAX_SEED]`` (both ends inclusive)
    when ``randomize_seed`` is truthy; otherwise returns ``seed`` unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def get_additional_prompt():
    """Return a randomized outfit prompt.

    One top, one bottom and one accessory are drawn (in that order) with
    ``random.choice`` and joined with ", " after a fixed quality prefix; the
    result always ends with ", score_9".
    """
    base = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
    tops = ("tank top", "blouse", "button up shirt", "sweater", "corset top")
    # "short skirt" is listed twice, which doubles its chance of being drawn.
    bottoms = (
        "short skirt", "athletic shorts", "jean shorts", "pleated skirt",
        "short skirt", "leggings", "high-waisted shorts",
    )
    accessories = (
        "knee-high boots", "gloves", "Thigh-high stockings", "Garter belt",
        "choker", "necklace", "headband", "headphones",
    )
    picks = [random.choice(options) for options in (tops, bottoms, accessories)]
    return ", ".join([base, *picks, "score_9"])
def get_prompt(prompt, additional_prompt):
    """Build the positive prompt sent to the diffusion pipeline.

    Args:
        prompt: Style or description text. An empty string or ``None`` falls
            back to the default "boho chic" style.
        additional_prompt: Optional extra text (the UI's "Additional prompt"
            box). Blank or ``None`` adds nothing. It is placed right after the
            style, before the long photographic suffix, so it is less likely
            to be cut off by the text encoder's prompt-length limit.

    Returns:
        The complete prompt string, which is also printed for logging.
    """
    interior = "design-style interior designed (interior space), captured with a DSLR camera using f/10 aperture, 1/60 sec shutter speed, ISO 400, 20mm focal length, tungsten white balance, (sharp focus), professional photography, high-resolution, 8k, Pulitzer Prize-winning"

    # `not prompt` also covers None. Previously None fell through to the
    # formatting branch and produced the literal text "None,...".
    if not prompt:
        head, separator = "boho chic", " "
    else:
        head, separator = prompt, ","

    # The additional prompt used to be accepted but silently ignored.
    extra = str(additional_prompt).strip() if additional_prompt else ""
    if extra:
        head = f"{head},{extra}"

    full_prompt = f"{head}{separator}{interior}"
    print(full_prompt)
    return full_prompt
# Preset design styles for the UI radio group: "name" is the label shown to
# the user, "prompt" is the text apply_style() returns for that label.
style_list = [
    {"name": "None", "prompt": ""},
    {"name": "Minimalistic", "prompt": "Minimalistic"},
    {"name": "Boho Chic", "prompt": "boho chic"},
    {"name": "Saudi Prince Gold", "prompt": "saudi prince gold"},
    {"name": "Modern Farmhouse", "prompt": "modern farmhouse"},
    {"name": "Neoclassical", "prompt": "Neoclassical"},
    {"name": "Eclectic", "prompt": "Eclectic"},
    {"name": "Parisian White", "prompt": "Parisian White"},
    {"name": "Hollywood Glam", "prompt": "Hollywood Glam"},
    {"name": "Scandinavian", "prompt": "Scandinavian"},
]

# Label -> prompt text lookup.
styles = {entry["name"]: entry["prompt"] for entry in style_list}
# Radio choices, in declaration order.
STYLE_NAMES = [*styles]
def apply_style(style_name):
    """Return the prompt text for a preset design style.

    Args:
        style_name: A style label, normally one of ``STYLE_NAMES``.

    Returns:
        The preset's prompt text. Unknown labels (including ``None``) fall
        back to ``"boho chic"``.
    """
    # The old `if style_name in styles:` guard made the .get() default
    # unreachable: unknown names returned None, which then leaked into the
    # generated prompt.
    return styles.get(style_name, "boho chic")
# Page CSS handed to gr.Blocks: centre h1-h3 headings, cap the app width at
# 600px, and hide the Gradio footer.
css = """
h1 {
text-align: center;
display:block;
}
h2 {
text-align: center;
display:block;
}
h3 {
text-align: center;
display:block;
}
.gradio-container{max-width: 600px !important}
footer {visibility: hidden}
"""
# Gradio UI: advanced-option controls, prompt/style inputs, input/output
# images, and the event wiring that drives process_image().
with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
    # ------------------------------------------------------------------
    # Advanced options (accordion hidden unless show_options is True)
    # ------------------------------------------------------------------
    with gr.Row():
        with gr.Accordion("Advanced options", open=show_options, visible=show_options):
            num_images = gr.Slider(
                label="Images", minimum=1, maximum=4, value=1, step=1
            )
            image_resolution = gr.Slider(
                label="Image resolution",
                minimum=256,
                maximum=1024,
                value=768,
                step=256,
            )
            preprocess_resolution = gr.Slider(
                label="Preprocess resolution",
                minimum=128,
                maximum=1024,
                value=768,
                step=1,
            )
            num_steps = gr.Slider(
                label="Number of steps", minimum=1, maximum=100, value=15, step=1
            )  # 20/4.5 or 12 without lora, 4 with lora
            guidance_scale = gr.Slider(
                label="Guidance scale", minimum=0.1, maximum=30.0, value=5.5, step=0.1
            )  # 5 without lora, 2 with lora
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            a_prompt = gr.Textbox(
                label="Additional prompt",
                value="",
            )
            n_prompt = gr.Textbox(
                label="Negative prompt",
                # Textual-inversion embeddings only apply on an exact token
                # match. The fcNeg embedding is registered as "fcNeg-neg", so
                # the previous plain "fcNeg" never triggered it.
                value="EasyNegativeV2, fcNeg-neg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
            )

    # ------------------------------------------------------------------
    # Text inputs
    # ------------------------------------------------------------------
    with gr.Row():
        gr.Text(label="Interior Design Style Examples", value="Eclectic, Maximalist, Bohemian, Scandinavian, Minimalist, Rustic, Modern Farmhouse, Contemporary, Luxury, Airbnb, Boho Chic, Midcentury Modern, Art Deco, Zen, Beach, Neoclassical, Industrial, Biophilic, Eco-friendly, Hollywood Glam, Parisian White, Saudi Prince Gold, French Country, Monster Energy Drink, Cyberpunk, Vaporwave, Baroque, etc.\n\nPro tip: add a color to customize it! You can also describe the furniture type.")
    with gr.Column():
        prompt = gr.Textbox(
            label="Description",
            placeholder="boho chic",
        )
    with gr.Row(visible=True):
        style_selection = gr.Radio(
            show_label=True,
            container=True,
            interactive=True,
            choices=STYLE_NAMES,
            # The default must be one of the choices (the style *labels*).
            # The lowercase prompt text "boho chic" matched no choice, so no
            # style was selected and apply_style() could not resolve it.
            value="Boho Chic",
            label="Design Styles",
        )

    # ------------------------------------------------------------------
    # Input / output images
    # ------------------------------------------------------------------
    with gr.Row():
        with gr.Column():
            image = gr.Image(
                label="Input",
                sources=["upload"],
                show_label=True,
                mirror_webcam=True,
                format="webp",
            )
            # Re-run generation on the current input (shown once a result exists).
            # NOTE(review): this button has the same label as use_ai_button --
            # confirm the intended caption.
            with gr.Column():
                run_button = gr.Button(value="Use this one", size="lg", visible=False)
        with gr.Column():
            result = gr.Image(
                label="Output",
                interactive=False,
                format="webp",
                show_share_button=False,
            )
            # Promote the output to the input and regenerate from it.
            with gr.Column():
                use_ai_button = gr.Button(value="Use this one", size="lg", visible=False)

    # Input components, in the positional order process_image() expects.
    config = [
        image,
        style_selection,
        prompt,
        a_prompt,
        n_prompt,
        num_images,
        image_resolution,
        preprocess_resolution,
        num_steps,
        guidance_scale,
        seed,
    ]

    with gr.Row():
        helper_text = gr.Markdown("## Tap and hold (on mobile) to save the image.", visible=True)

    # ------------------------------------------------------------------
    # Events
    # ------------------------------------------------------------------
    # Generate on a new upload, on Enter in the description box, or on the run button.
    @gr.on(triggers=[image.upload, prompt.submit, run_button.click], inputs=config, outputs=result, show_progress="minimal")
    def auto_process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
        return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)

    def submit(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
        """Regenerate from the current input (second step of "Use this one")."""
        return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)

    def update_input(current_result=None):
        """Return this session's latest output so it replaces the input image.

        The image comes from the session's own ``result`` component, not from
        ``temp_image.jpg``: process_image writes that fixed path on every call,
        so concurrent users would overwrite (and receive) each other's images.
        """
        if current_result is None:
            print("No AI Image Available")
            return gr.update()  # leave the current input untouched
        print("Updating image to AI Temp Image")
        return current_result

    # "Use this one": first copy the output into the input, THEN regenerate.
    # Previously these were two independent listeners on the same click. Both
    # read their inputs at click time, so the regeneration used the OLD input.
    use_ai_button.click(
        fn=update_input, inputs=result, outputs=image, show_progress="hidden"
    ).then(
        fn=submit, inputs=config, outputs=result, show_progress="minimal"
    )

    # Hide the action buttons while a generation is in flight.
    @gr.on(triggers=[image.upload, use_ai_button.click, run_button.click], inputs=None, outputs=[run_button, use_ai_button], show_progress="hidden")
    def turn_buttons_off():
        return gr.update(visible=False), gr.update(visible=False)

    # Show them again whenever a new result is displayed.
    @gr.on(triggers=[result.change], inputs=None, outputs=[use_ai_button, run_button], show_progress="hidden")
    def turn_buttons_on():
        return gr.update(visible=True), gr.update(visible=True)
@spaces.GPU(duration=12)
@torch.inference_mode()
def process_image(
    image,
    style_selection,
    prompt,
    a_prompt,
    n_prompt,
    num_images,
    image_resolution,
    preprocess_resolution,
    num_steps,
    guidance_scale,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a restyled interior image from an input photo.

    The input is turned into a NormalBae control image by ``preprocessor`` and
    passed, together with the text prompt, to the ControlNet pipeline ``pipe``.

    Args:
        image: Input image as delivered by the Gradio ``gr.Image`` component.
        style_selection: Style label. Anything other than ``"None"``/``None``
            replaces ``prompt`` with that style's prompt text.
        prompt: Free-text description, used when no style is selected.
        a_prompt: Extra prompt text forwarded to ``get_prompt``.
        n_prompt: Negative prompt.
        num_images: Images per prompt requested from the pipeline. Only the
            first is returned.
        image_resolution: Output resolution forwarded to the preprocessor.
        preprocess_resolution: Detection resolution for the preprocessor.
        num_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        seed: Currently ignored; a fresh random seed is drawn on every call.
        progress: Gradio progress tracker (tracks tqdm bars in the pipeline).

    Returns:
        The first image in ``pipe(...).images``. It is presumably a PIL image,
        since it is ``.save()``-d below.
    """
    global compiled

    print("processing image")
    start = time.time()
    preprocessor.load("NormalBae")
    control_image = preprocessor(
        image=image,
        image_resolution=image_resolution,
        detect_resolution=preprocess_resolution,
    )

    # A selected style overrides the free-text description. The old test
    # (`!= "None" or != None`) was always true, so the description was never
    # used. If the style cannot be resolved, the description is kept.
    if style_selection not in (None, "None"):
        styled = apply_style(style_selection)
        if styled is not None:
            prompt = styled
    custom_prompt = str(get_prompt(prompt, a_prompt))
    negative_prompt = str(n_prompt)

    seed = random.randint(0, MAX_SEED)
    # torch.cuda.manual_seed() returns None (it only reseeds the global CUDA
    # RNG), so the old code passed generator=None. Use a dedicated, seeded
    # generator instead.
    generator = torch.Generator(device="cuda").manual_seed(seed)

    if not compiled:
        print("Not Compiled")
        compiled = True

    results = pipe(
        prompt=custom_prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images,
        num_inference_steps=num_steps,
        generator=generator,
        image=control_image,
    ).images[0]
    print(f"\n-------------------------Processed in: {time.time() - start:.2f} seconds-------------------------")

    timestamp = int(time.time())
    # `prompt` may now be user-typed text. Reduce it to a bounded set of safe
    # filename characters so it cannot traverse directories ("../") or
    # exceed filesystem name-length limits.
    safe_prompt = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "-_ " else "_"
        for ch in str(prompt)
    )[:64]
    results_path = f"./{timestamp}_out_{safe_prompt}.jpg"
    results.save(results_path)
    # Fixed scratch copy. NOTE(review): this path is shared by every session.
    results.save("temp_image.jpg")
    return results
# Start the app. Production binds to localhost on `port` with a bounded
# request queue; otherwise Gradio's default host/port are used with the
# API closed and hidden.
if prod:
    queue_opts = {"max_size": 20}
    launch_opts = {"server_name": "localhost", "server_port": port}
else:
    queue_opts = {"api_open": False}
    launch_opts = {"show_api": False}
demo.queue(**queue_opts).launch(**launch_opts)