Commit
·
c9de947
1
Parent(s):
37393f2
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ from pathlib import Path
|
|
| 4 |
import argparse
|
| 5 |
import shutil
|
| 6 |
from train_dreambooth import run_training
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
import torch
|
| 9 |
|
|
@@ -47,6 +48,8 @@ def swap_text(option):
|
|
| 47 |
return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. Name the files with the words you would like {mandatory_liability}:", '''<img src="file/trsl_style.png" />''', f"You should name your files with a unique word that represent your concept (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to 512x512.", freeze_for]
|
| 48 |
|
| 49 |
def train(*inputs):
|
|
|
|
|
|
|
| 50 |
file_counter = 0
|
| 51 |
for i, input in enumerate(inputs):
|
| 52 |
if(i < maximum_concepts-1):
|
|
@@ -156,12 +159,23 @@ def train(*inputs):
|
|
| 156 |
max_train_steps=Training_Steps,
|
| 157 |
)
|
| 158 |
run_training(args_general)
|
| 159 |
-
|
| 160 |
shutil.rmtree('instance_images')
|
| 161 |
-
shutil.make_archive("
|
| 162 |
shutil.rmtree("output_model")
|
| 163 |
torch.cuda.empty_cache()
|
| 164 |
-
return [gr.update(visible=True, value="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
with gr.Blocks(css=css) as demo:
|
| 167 |
with gr.Box():
|
|
@@ -252,4 +266,6 @@ with gr.Blocks(css=css) as demo:
|
|
| 252 |
push_button = gr.Button("Push to the Hub")
|
| 253 |
result = gr.File(label="Download the uploaded models (zip file are diffusers weights, *.ckpt are CompVis/AUTOMATIC1111 weights)", visible=True)
|
| 254 |
train_btn.click(fn=train, inputs=is_visible+concept_collection+file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[result, try_your_model, push_to_hub])
|
|
|
|
|
|
|
| 255 |
demo.launch()
|
|
|
|
| 4 |
import argparse
|
| 5 |
import shutil
|
| 6 |
from train_dreambooth import run_training
|
| 7 |
+
from converttosd import convert
|
| 8 |
from PIL import Image
|
| 9 |
import torch
|
| 10 |
|
|
|
|
| 48 |
return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. Name the files with the words you would like {mandatory_liability}:", '''<img src="file/trsl_style.png" />''', f"You should name your files with a unique word that represent your concept (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to 512x512.", freeze_for]
|
| 49 |
|
| 50 |
def train(*inputs):
|
| 51 |
+
if os.path.exists("diffusers_model.zip"): os.remove("diffusers_model.zip")
|
| 52 |
+
if os.path.exists("model.ckpt"): os.remove("model.ckpt")
|
| 53 |
file_counter = 0
|
| 54 |
for i, input in enumerate(inputs):
|
| 55 |
if(i < maximum_concepts-1):
|
|
|
|
| 159 |
max_train_steps=Training_Steps,
|
| 160 |
)
|
| 161 |
run_training(args_general)
|
| 162 |
+
convert("output_model", "model.ckpt")
|
| 163 |
shutil.rmtree('instance_images')
|
| 164 |
+
shutil.make_archive("diffusers_model", 'zip', "output_model")
|
| 165 |
shutil.rmtree("output_model")
|
| 166 |
torch.cuda.empty_cache()
|
| 167 |
+
return [gr.update(visible=True, value=["diffusers_model.zip", "model.ckpt"]), gr.update(visible=True), gr.update(visible=True)]
|
| 168 |
+
|
| 169 |
+
def generate(prompt):
|
| 170 |
+
from diffusers import StableDiffusionPipeline
|
| 171 |
+
|
| 172 |
+
pipe = StableDiffusionPipeline.from_pretrained("./output_model", torch_dtype=torch.float16)
|
| 173 |
+
pipe = pipe.to("cuda")
|
| 174 |
+
image = pipe(prompt).images[0]
|
| 175 |
+
return(image)
|
| 176 |
+
|
| 177 |
+
def push_button(path):
|
| 178 |
+
pass
|
| 179 |
|
| 180 |
with gr.Blocks(css=css) as demo:
|
| 181 |
with gr.Box():
|
|
|
|
| 266 |
push_button = gr.Button("Push to the Hub")
|
| 267 |
result = gr.File(label="Download the uploaded models (zip file are diffusers weights, *.ckpt are CompVis/AUTOMATIC1111 weights)", visible=True)
|
| 268 |
train_btn.click(fn=train, inputs=is_visible+concept_collection+file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[result, try_your_model, push_to_hub])
|
| 269 |
+
generate_button.click(fn=generate, inputs=prompt, outputs=result)
|
| 270 |
+
push_button.click(fn=push_to_hub, inputs=model_repo_tag, outputs=[])
|
| 271 |
demo.launch()
|