Commit
•
7967a47
1
Parent(s):
cb9bf15
Remove GPU attribution if CUDA error
Browse files
app.py
CHANGED
@@ -35,15 +35,8 @@ else:
|
|
35 |
is_shared_ui = False
|
36 |
is_gpu_associated = torch.cuda.is_available()
|
37 |
|
38 |
-
|
39 |
-
.instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}
|
40 |
-
.arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}
|
41 |
-
#component-4, #component-3, #component-10{min-height: 0}
|
42 |
-
.duplicate-button img{margin: 0}
|
43 |
-
'''
|
44 |
-
maximum_concepts = 3
|
45 |
|
46 |
-
#Pre download the files
|
47 |
if(is_gpu_associated):
|
48 |
model_v1 = snapshot_download(repo_id="multimodalart/sd-fine-tunable")
|
49 |
model_v2 = snapshot_download(repo_id="stabilityai/stable-diffusion-2-1", ignore_patterns=["*.ckpt", "*.safetensors"])
|
@@ -51,8 +44,25 @@ if(is_gpu_associated):
|
|
51 |
safety_checker = snapshot_download(repo_id="multimodalart/sd-sc")
|
52 |
model_to_load = model_v1
|
53 |
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
def swap_text(option, base):
|
58 |
resize_width = 768 if base == "v2-1-768" else 512
|
@@ -60,7 +70,7 @@ def swap_text(option, base):
|
|
60 |
if(option == "object"):
|
61 |
instance_prompt_example = "cttoy"
|
62 |
freeze_for = 30
|
63 |
-
return [f"You are going to train `object`(s), upload 5-10 images of each object you are planning on training on from different angles/perspectives. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file
|
64 |
elif(option == "person"):
|
65 |
instance_prompt_example = "julcto"
|
66 |
freeze_for = 70
|
@@ -70,27 +80,17 @@ def swap_text(option, base):
|
|
70 |
prior_preservation_box_update = gr.update(visible=show_prior_preservation)
|
71 |
else:
|
72 |
prior_preservation_box_update = gr.update(visible=show_prior_preservation, value=False)
|
73 |
-
return [f"You are going to train a `person`(s), upload 10-20 images of each person you are planning on training on from different angles/perspectives. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file
|
74 |
elif(option == "style"):
|
75 |
instance_prompt_example = "trsldamrl"
|
76 |
freeze_for = 10
|
77 |
-
return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file
|
78 |
-
|
79 |
-
def swap_base_model(selected_model):
|
80 |
-
if(is_gpu_associated):
|
81 |
-
global model_to_load
|
82 |
-
if(selected_model == "v1-5"):
|
83 |
-
model_to_load = model_v1
|
84 |
-
elif(selected_model == "v2-1-768"):
|
85 |
-
model_to_load = model_v2
|
86 |
-
else:
|
87 |
-
model_to_load = model_v2_512
|
88 |
|
89 |
def count_files(*inputs):
|
90 |
file_counter = 0
|
91 |
concept_counter = 0
|
92 |
for i, input in enumerate(inputs):
|
93 |
-
if(i < maximum_concepts
|
94 |
files = inputs[i]
|
95 |
if(files):
|
96 |
concept_counter+=1
|
@@ -133,6 +133,9 @@ def update_steps(*files_list):
|
|
133 |
file_counter+=len(files)
|
134 |
return(gr.update(value=file_counter*200))
|
135 |
|
|
|
|
|
|
|
136 |
def pad_image(image):
|
137 |
w, h = image.size
|
138 |
if w == h:
|
@@ -163,7 +166,34 @@ def validate_model_upload(hf_token, model_name):
|
|
163 |
if(model_name == ""):
|
164 |
raise gr.Error("Please fill in your model's name")
|
165 |
|
166 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
if is_shared_ui:
|
168 |
raise gr.Error("This Space only works in duplicated instances")
|
169 |
if not is_gpu_associated:
|
@@ -171,6 +201,9 @@ def train(*inputs):
|
|
171 |
hf_token = inputs[-5]
|
172 |
model_name = inputs[-7]
|
173 |
if(is_spaces):
|
|
|
|
|
|
|
174 |
remove_attribution_after = inputs[-6]
|
175 |
else:
|
176 |
remove_attribution_after = False
|
@@ -191,7 +224,6 @@ def train(*inputs):
|
|
191 |
if os.path.exists("model.ckpt"): os.remove("model.ckpt")
|
192 |
if os.path.exists("hastrained.success"): os.remove("hastrained.success")
|
193 |
file_counter = 0
|
194 |
-
which_model = inputs[-10]
|
195 |
resolution = 512 if which_model != "v2-1-768" else 768
|
196 |
for i, input in enumerate(inputs):
|
197 |
if(i < maximum_concepts-1):
|
@@ -261,38 +293,22 @@ def train(*inputs):
|
|
261 |
print("Starting single training...")
|
262 |
lock_file = open("intraining.lock", "w")
|
263 |
lock_file.close()
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
mixed_precision="fp16",
|
281 |
-
train_batch_size=1,
|
282 |
-
gradient_accumulation_steps=1,
|
283 |
-
use_8bit_adam=True,
|
284 |
-
learning_rate=2e-6,
|
285 |
-
lr_scheduler="polynomial",
|
286 |
-
lr_warmup_steps = 0,
|
287 |
-
max_train_steps=Training_Steps,
|
288 |
-
num_class_images=200,
|
289 |
-
gradient_checkpointing=gradient_checkpointing,
|
290 |
-
cache_latents=cache_latents,
|
291 |
-
)
|
292 |
-
print("Starting multi-training...")
|
293 |
-
lock_file = open("intraining.lock", "w")
|
294 |
-
lock_file.close()
|
295 |
-
run_training(args_general)
|
296 |
gc.collect()
|
297 |
torch.cuda.empty_cache()
|
298 |
if(which_model == "v1-5"):
|
@@ -302,6 +318,7 @@ def train(*inputs):
|
|
302 |
shutil.copy(f"model_index.json", "output_model/model_index.json")
|
303 |
|
304 |
if(not remove_attribution_after):
|
|
|
305 |
print("Archiving model file...")
|
306 |
with tarfile.open("diffusers_model.tar", "w") as tar:
|
307 |
tar.add("output_model", arcname=os.path.basename("output_model"))
|
@@ -310,6 +327,7 @@ def train(*inputs):
|
|
310 |
trained_file.close()
|
311 |
print("Training completed!")
|
312 |
return [
|
|
|
313 |
gr.update(visible=True, value=["diffusers_model.tar"]), #result
|
314 |
gr.update(visible=True), #try_your_model
|
315 |
gr.update(visible=True), #push_to_hub
|
@@ -320,10 +338,7 @@ def train(*inputs):
|
|
320 |
else:
|
321 |
where_to_upload = inputs[-8]
|
322 |
push(model_name, where_to_upload, hf_token, which_model, True)
|
323 |
-
|
324 |
-
headers = { "authorization" : f"Bearer {hf_token}"}
|
325 |
-
body = {'flavor': 'cpu-basic'}
|
326 |
-
requests.post(hardware_url, json = body, headers=headers)
|
327 |
|
328 |
pipe_is_set = False
|
329 |
def generate(prompt, steps):
|
@@ -338,7 +353,7 @@ def generate(prompt, steps):
|
|
338 |
|
339 |
image = pipe(prompt, num_inference_steps=steps).images[0]
|
340 |
return(image)
|
341 |
-
|
342 |
def push(model_name, where_to_upload, hf_token, which_model, comes_from_automated=False):
|
343 |
validate_model_upload(hf_token, model_name)
|
344 |
if(not os.path.exists("model.ckpt")):
|
@@ -425,7 +440,10 @@ Sample pictures of:
|
|
425 |
extra_message = "Don't forget to remove the GPU attribution after you play with it."
|
426 |
else:
|
427 |
extra_message = "The GPU has been removed automatically as requested, and you can try the model via the model page"
|
428 |
-
|
|
|
|
|
|
|
429 |
print("Model uploaded successfully!")
|
430 |
return [gr.update(visible=True, value=f"Successfully uploaded your model. Access it [here](https://huggingface.co/{model_id})"), gr.update(visible=True, value=["diffusers_model.tar", "model.ckpt"])]
|
431 |
|
@@ -488,8 +506,8 @@ with gr.Blocks(css=css) as demo:
|
|
488 |
<div class="gr-prose" style="max-width: 80%">
|
489 |
<h2>Attention - This Space doesn't work in this shared UI</h2>
|
490 |
<p>For it to work, you can either run locally or duplicate the Space and run it on your own profile using a (paid) private T4-small or A10G-small GPU for training. A T4 costs US$0.60/h, so it should cost < US$1 to train most models using default settings with it! <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></p>
|
491 |
-
<img class="instruction" src="file
|
492 |
-
<img class="arrow" src="file
|
493 |
</div>
|
494 |
''')
|
495 |
elif(is_spaces):
|
@@ -519,14 +537,15 @@ with gr.Blocks(css=css) as demo:
|
|
519 |
|
520 |
with gr.Row() as what_are_you_training:
|
521 |
type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
|
522 |
-
|
523 |
-
|
|
|
524 |
#Very hacky approach to emulate dynamically created Gradio components
|
525 |
with gr.Row() as upload_your_concept:
|
526 |
with gr.Column():
|
527 |
thing_description = gr.Markdown("You are going to train an `object`, please upload 5-10 images of the object you are planning on training on from different angles/perspectives. You must have the right to do so and you are liable for the images you use, example")
|
528 |
thing_experimental = gr.Checkbox(label="Improve faces (prior preservation) - can take longer training but can improve faces", visible=False, value=False)
|
529 |
-
thing_image_example = gr.HTML('''<img src="file
|
530 |
things_naming = gr.Markdown("You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `cttoy` here). Images will be automatically cropped to 512x512.")
|
531 |
|
532 |
with gr.Column():
|
@@ -588,6 +607,7 @@ with gr.Blocks(css=css) as demo:
|
|
588 |
training_summary_token = gr.Textbox(label="Hugging Face Write Token", type="password", visible=True)
|
589 |
|
590 |
train_btn = gr.Button("Start Training")
|
|
|
591 |
if(is_shared_ui):
|
592 |
training_ongoing = gr.Markdown("## This Space only works in duplicated instances. Please duplicate it and try again!", visible=False)
|
593 |
elif(not is_gpu_associated):
|
@@ -595,6 +615,7 @@ with gr.Blocks(css=css) as demo:
|
|
595 |
else:
|
596 |
training_ongoing = gr.Markdown("## Training is ongoing ⌛... You can close this tab if you like or just wait. If you did not check the `Remove GPU After training`, you can come back here to try your model and upload it after training. Don't forget to remove the GPU attribution after you are done. ", visible=False)
|
597 |
|
|
|
598 |
#Post-training UI
|
599 |
completed_training = gr.Markdown('''# ✅ Training completed.
|
600 |
### Don't forget to remove the GPU attribution after you are done trying and uploading your model''', visible=False)
|
@@ -624,9 +645,10 @@ with gr.Blocks(css=css) as demo:
|
|
624 |
type_of_thing.change(fn=swap_text, inputs=[type_of_thing, base_model_to_use], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder, thing_experimental], queue=False, show_progress=False)
|
625 |
|
626 |
#Swap the base model
|
|
|
627 |
base_model_to_use.change(fn=swap_text, inputs=[type_of_thing, base_model_to_use], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder, thing_experimental], queue=False, show_progress=False)
|
|
|
628 |
base_model_to_use.change(fn=swap_base_model, inputs=base_model_to_use, outputs=[])
|
629 |
-
|
630 |
#Update the summary box below the UI according to how many images are uploaded and whether users are using custom settings or not
|
631 |
for file in file_collection:
|
632 |
#file.change(fn=update_steps,inputs=file_collection, outputs=steps)
|
@@ -641,10 +663,12 @@ with gr.Blocks(css=css) as demo:
|
|
641 |
if(is_spaces):
|
642 |
training_summary_checkbox.change(fn=checkbox_swap, inputs=training_summary_checkbox, outputs=[training_summary_token_message, training_summary_token, training_summary_model_name, training_summary_where_to_upload],queue=False, show_progress=False)
|
643 |
#Add a message for while it is in training
|
644 |
-
|
|
|
645 |
|
646 |
#The main train function
|
647 |
-
train_btn.click(
|
|
|
648 |
|
649 |
#Button to generate an image from your trained model after training
|
650 |
generate_button.click(fn=generate, inputs=[prompt, inference_steps], outputs=result_image, queue=False)
|
|
|
35 |
is_shared_ui = False
|
36 |
is_gpu_associated = torch.cuda.is_available()
|
37 |
|
38 |
+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
|
|
40 |
if(is_gpu_associated):
|
41 |
model_v1 = snapshot_download(repo_id="multimodalart/sd-fine-tunable")
|
42 |
model_v2 = snapshot_download(repo_id="stabilityai/stable-diffusion-2-1", ignore_patterns=["*.ckpt", "*.safetensors"])
|
|
|
44 |
safety_checker = snapshot_download(repo_id="multimodalart/sd-sc")
|
45 |
model_to_load = model_v1
|
46 |
|
47 |
+
def swap_base_model(selected_model):
|
48 |
+
if(is_gpu_associated):
|
49 |
+
global model_to_load
|
50 |
+
if(selected_model == "v1-5"):
|
51 |
+
model_to_load = model_v1
|
52 |
+
elif(selected_model == "v2-1-768"):
|
53 |
+
model_to_load = model_v2
|
54 |
+
else:
|
55 |
+
model_to_load = model_v2_512
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
css = '''
|
60 |
+
.instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}
|
61 |
+
.arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}
|
62 |
+
#component-4, #component-3, #component-10{min-height: 0}
|
63 |
+
.duplicate-button img{margin: 0}
|
64 |
+
'''
|
65 |
+
maximum_concepts = 3
|
66 |
|
67 |
def swap_text(option, base):
|
68 |
resize_width = 768 if base == "v2-1-768" else 512
|
|
|
70 |
if(option == "object"):
|
71 |
instance_prompt_example = "cttoy"
|
72 |
freeze_for = 30
|
73 |
+
return [f"You are going to train `object`(s), upload 5-10 images of each object you are planning on training on from different angles/perspectives. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file=cat-toy.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to {resize_width}x{resize_width}.", freeze_for, gr.update(visible=False)]
|
74 |
elif(option == "person"):
|
75 |
instance_prompt_example = "julcto"
|
76 |
freeze_for = 70
|
|
|
80 |
prior_preservation_box_update = gr.update(visible=show_prior_preservation)
|
81 |
else:
|
82 |
prior_preservation_box_update = gr.update(visible=show_prior_preservation, value=False)
|
83 |
+
return [f"You are going to train a `person`(s), upload 10-20 images of each person you are planning on training on from different angles/perspectives. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file=person.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to {resize_width}x{resize_width}.", freeze_for, prior_preservation_box_update]
|
84 |
elif(option == "style"):
|
85 |
instance_prompt_example = "trsldamrl"
|
86 |
freeze_for = 10
|
87 |
+
return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file=trsl_style.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to {resize_width}x{resize_width}", freeze_for, gr.update(visible=False)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
def count_files(*inputs):
|
90 |
file_counter = 0
|
91 |
concept_counter = 0
|
92 |
for i, input in enumerate(inputs):
|
93 |
+
if(i < maximum_concepts):
|
94 |
files = inputs[i]
|
95 |
if(files):
|
96 |
concept_counter+=1
|
|
|
133 |
file_counter+=len(files)
|
134 |
return(gr.update(value=file_counter*200))
|
135 |
|
136 |
+
def visualise_progress_bar():
|
137 |
+
return gr.update(visible=True)
|
138 |
+
|
139 |
def pad_image(image):
|
140 |
w, h = image.size
|
141 |
if w == h:
|
|
|
166 |
if(model_name == ""):
|
167 |
raise gr.Error("Please fill in your model's name")
|
168 |
|
169 |
+
def swap_hardware(hf_token, hardware="cpu-basic"):
|
170 |
+
hardware_url = f"https://huggingface.co/spaces/{os.environ['SPACE_ID']}/hardware"
|
171 |
+
headers = { "authorization" : f"Bearer {hf_token}"}
|
172 |
+
body = {'flavor': hardware}
|
173 |
+
requests.post(hardware_url, json = body, headers=headers)
|
174 |
+
|
175 |
+
def swap_sleep_time(hf_token,sleep_time):
|
176 |
+
sleep_time_url = f"https://huggingface.co/api/spaces/{os.environ['SPACE_ID']}/sleeptime"
|
177 |
+
headers = { "authorization" : f"Bearer {hf_token}"}
|
178 |
+
body = {'seconds':sleep_time}
|
179 |
+
requests.post(sleep_time_url,json=body,headers=headers)
|
180 |
+
|
181 |
+
def get_sleep_time(hf_token):
|
182 |
+
sleep_time_url = f"https://huggingface.co/api/spaces/{os.environ['SPACE_ID']}"
|
183 |
+
headers = { "authorization" : f"Bearer {hf_token}"}
|
184 |
+
response = requests.get(sleep_time_url,headers=headers)
|
185 |
+
return response.json()['runtime']['gcTimeout']
|
186 |
+
|
187 |
+
def write_to_community(title, description,hf_token):
|
188 |
+
from huggingface_hub import HfApi
|
189 |
+
api = HfApi()
|
190 |
+
api.create_discussion(repo_id=os.environ['SPACE_ID'], title=title, description=description,repo_type="space", token=hf_token)
|
191 |
+
|
192 |
+
def train(progress=gr.Progress(track_tqdm=True), *inputs):
|
193 |
+
which_model = inputs[-10]
|
194 |
+
if(which_model == ""):
|
195 |
+
raise gr.Error("You forgot to select a base model to use")
|
196 |
+
|
197 |
if is_shared_ui:
|
198 |
raise gr.Error("This Space only works in duplicated instances")
|
199 |
if not is_gpu_associated:
|
|
|
201 |
hf_token = inputs[-5]
|
202 |
model_name = inputs[-7]
|
203 |
if(is_spaces):
|
204 |
+
sleep_time = get_sleep_time(hf_token)
|
205 |
+
if sleep_time:
|
206 |
+
swap_sleep_time(hf_token, -1)
|
207 |
remove_attribution_after = inputs[-6]
|
208 |
else:
|
209 |
remove_attribution_after = False
|
|
|
224 |
if os.path.exists("model.ckpt"): os.remove("model.ckpt")
|
225 |
if os.path.exists("hastrained.success"): os.remove("hastrained.success")
|
226 |
file_counter = 0
|
|
|
227 |
resolution = 512 if which_model != "v2-1-768" else 768
|
228 |
for i, input in enumerate(inputs):
|
229 |
if(i < maximum_concepts-1):
|
|
|
293 |
print("Starting single training...")
|
294 |
lock_file = open("intraining.lock", "w")
|
295 |
lock_file.close()
|
296 |
+
try:
|
297 |
+
run_training(args_general)
|
298 |
+
except Exception as e:
|
299 |
+
if(is_spaces):
|
300 |
+
title="There was an error on during your training"
|
301 |
+
description=f'''
|
302 |
+
Unfortunately there was an error during training your {model_name} model.
|
303 |
+
Please check it out below. Feel free to report this issue to [Dreambooth Training](https://huggingface.co/spaces/multimodalart/dreambooth-training):
|
304 |
+
```
|
305 |
+
{str(e)}
|
306 |
+
```
|
307 |
+
'''
|
308 |
+
swap_hardware(hf_token, "cpu-basic")
|
309 |
+
write_to_community(title,description,hf_token)
|
310 |
+
|
311 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
gc.collect()
|
313 |
torch.cuda.empty_cache()
|
314 |
if(which_model == "v1-5"):
|
|
|
318 |
shutil.copy(f"model_index.json", "output_model/model_index.json")
|
319 |
|
320 |
if(not remove_attribution_after):
|
321 |
+
swap_sleep_time(hf_token, sleep_time)
|
322 |
print("Archiving model file...")
|
323 |
with tarfile.open("diffusers_model.tar", "w") as tar:
|
324 |
tar.add("output_model", arcname=os.path.basename("output_model"))
|
|
|
327 |
trained_file.close()
|
328 |
print("Training completed!")
|
329 |
return [
|
330 |
+
gr.update(visible=False), #progress_bar
|
331 |
gr.update(visible=True, value=["diffusers_model.tar"]), #result
|
332 |
gr.update(visible=True), #try_your_model
|
333 |
gr.update(visible=True), #push_to_hub
|
|
|
338 |
else:
|
339 |
where_to_upload = inputs[-8]
|
340 |
push(model_name, where_to_upload, hf_token, which_model, True)
|
341 |
+
swap_hardware(hf_token, "cpu-basic")
|
|
|
|
|
|
|
342 |
|
343 |
pipe_is_set = False
|
344 |
def generate(prompt, steps):
|
|
|
353 |
|
354 |
image = pipe(prompt, num_inference_steps=steps).images[0]
|
355 |
return(image)
|
356 |
+
|
357 |
def push(model_name, where_to_upload, hf_token, which_model, comes_from_automated=False):
|
358 |
validate_model_upload(hf_token, model_name)
|
359 |
if(not os.path.exists("model.ckpt")):
|
|
|
440 |
extra_message = "Don't forget to remove the GPU attribution after you play with it."
|
441 |
else:
|
442 |
extra_message = "The GPU has been removed automatically as requested, and you can try the model via the model page"
|
443 |
+
title=f"Your model {model_name} has finished trained from the Dreambooth Train Spaces!"
|
444 |
+
description=f"Your model has been successfully uploaded to: https://huggingface.co/{model_id}. {extra_message}"
|
445 |
+
write_to_community(title, description, hf_token)
|
446 |
+
#api.create_discussion(repo_id=os.environ['SPACE_ID'], title=f"Your model {model_name} has finished trained from the Dreambooth Train Spaces!", description=f"Your model has been successfully uploaded to: https://huggingface.co/{model_id}. {extra_message}",repo_type="space", token=hf_token)
|
447 |
print("Model uploaded successfully!")
|
448 |
return [gr.update(visible=True, value=f"Successfully uploaded your model. Access it [here](https://huggingface.co/{model_id})"), gr.update(visible=True, value=["diffusers_model.tar", "model.ckpt"])]
|
449 |
|
|
|
506 |
<div class="gr-prose" style="max-width: 80%">
|
507 |
<h2>Attention - This Space doesn't work in this shared UI</h2>
|
508 |
<p>For it to work, you can either run locally or duplicate the Space and run it on your own profile using a (paid) private T4-small or A10G-small GPU for training. A T4 costs US$0.60/h, so it should cost < US$1 to train most models using default settings with it! <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></p>
|
509 |
+
<img class="instruction" src="file=duplicate.png">
|
510 |
+
<img class="arrow" src="file=arrow.png" />
|
511 |
</div>
|
512 |
''')
|
513 |
elif(is_spaces):
|
|
|
537 |
|
538 |
with gr.Row() as what_are_you_training:
|
539 |
type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
|
540 |
+
with gr.Column():
|
541 |
+
base_model_to_use = gr.Dropdown(label="Which base model would you like to use?", choices=["v1-5", "v2-1-512", "v2-1-768"], value="v1-5", interactive=True)
|
542 |
+
|
543 |
#Very hacky approach to emulate dynamically created Gradio components
|
544 |
with gr.Row() as upload_your_concept:
|
545 |
with gr.Column():
|
546 |
thing_description = gr.Markdown("You are going to train an `object`, please upload 5-10 images of the object you are planning on training on from different angles/perspectives. You must have the right to do so and you are liable for the images you use, example")
|
547 |
thing_experimental = gr.Checkbox(label="Improve faces (prior preservation) - can take longer training but can improve faces", visible=False, value=False)
|
548 |
+
thing_image_example = gr.HTML('''<img src="file=cat-toy.png" />''')
|
549 |
things_naming = gr.Markdown("You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `cttoy` here). Images will be automatically cropped to 512x512.")
|
550 |
|
551 |
with gr.Column():
|
|
|
607 |
training_summary_token = gr.Textbox(label="Hugging Face Write Token", type="password", visible=True)
|
608 |
|
609 |
train_btn = gr.Button("Start Training")
|
610 |
+
progress_bar = gr.Textbox(visible=False)
|
611 |
if(is_shared_ui):
|
612 |
training_ongoing = gr.Markdown("## This Space only works in duplicated instances. Please duplicate it and try again!", visible=False)
|
613 |
elif(not is_gpu_associated):
|
|
|
615 |
else:
|
616 |
training_ongoing = gr.Markdown("## Training is ongoing ⌛... You can close this tab if you like or just wait. If you did not check the `Remove GPU After training`, you can come back here to try your model and upload it after training. Don't forget to remove the GPU attribution after you are done. ", visible=False)
|
617 |
|
618 |
+
|
619 |
#Post-training UI
|
620 |
completed_training = gr.Markdown('''# ✅ Training completed.
|
621 |
### Don't forget to remove the GPU attribution after you are done trying and uploading your model''', visible=False)
|
|
|
645 |
type_of_thing.change(fn=swap_text, inputs=[type_of_thing, base_model_to_use], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder, thing_experimental], queue=False, show_progress=False)
|
646 |
|
647 |
#Swap the base model
|
648 |
+
|
649 |
base_model_to_use.change(fn=swap_text, inputs=[type_of_thing, base_model_to_use], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder, thing_experimental], queue=False, show_progress=False)
|
650 |
+
#base_model_to_use.change(fn=visualise_progress_bar, inputs=[], outputs=progress_bar)
|
651 |
base_model_to_use.change(fn=swap_base_model, inputs=base_model_to_use, outputs=[])
|
|
|
652 |
#Update the summary box below the UI according to how many images are uploaded and whether users are using custom settings or not
|
653 |
for file in file_collection:
|
654 |
#file.change(fn=update_steps,inputs=file_collection, outputs=steps)
|
|
|
663 |
if(is_spaces):
|
664 |
training_summary_checkbox.change(fn=checkbox_swap, inputs=training_summary_checkbox, outputs=[training_summary_token_message, training_summary_token, training_summary_model_name, training_summary_where_to_upload],queue=False, show_progress=False)
|
665 |
#Add a message for while it is in training
|
666 |
+
|
667 |
+
#train_btn.click(lambda:gr.update(visible=True), inputs=None, outputs=training_ongoing)
|
668 |
|
669 |
#The main train function
|
670 |
+
train_btn.click(lambda:gr.update(visible=True), inputs=[], outputs=progress_bar)
|
671 |
+
train_btn.click(fn=train, inputs=is_visible+concept_collection+file_collection+[base_model_to_use]+[thing_experimental]+[training_summary_where_to_upload]+[training_summary_model_name]+[training_summary_checkbox]+[training_summary_token]+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[progress_bar, result, try_your_model, push_to_hub, convert_button, training_ongoing, completed_training], queue=False)
|
672 |
|
673 |
#Button to generate an image from your trained model after training
|
674 |
generate_button.click(fn=generate, inputs=[prompt, inference_steps], outputs=result_image, queue=False)
|