Update app.py
Browse files
app.py
CHANGED
|
@@ -316,74 +316,71 @@ def randomize_loras(selected_indices, loras_state):
|
|
| 316 |
random_prompt = random.choice(prompt_values)
|
| 317 |
return selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, lora_image_1, lora_image_2, lora_image_3, lora_image_4, random_prompt
|
| 318 |
|
| 319 |
-
def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
|
| 320 |
-
if custom_lora:
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
"title": title,
|
| 331 |
-
"repo": repo,
|
| 332 |
-
"weights": path,
|
| 333 |
-
"trigger_word": trigger_word
|
| 334 |
-
}
|
| 335 |
-
print(f"New LoRA: {new_item}")
|
| 336 |
-
existing_item_index = len(current_loras)
|
| 337 |
-
current_loras.append(new_item)
|
| 338 |
-
|
| 339 |
-
# Update gallery
|
| 340 |
-
gallery_items = [(item["image"], item["title"]) for item in current_loras]
|
| 341 |
-
# Update selected_indices if there's room
|
| 342 |
-
if len(selected_indices) < 4:
|
| 343 |
-
selected_indices.append(existing_item_index)
|
| 344 |
-
else:
|
| 345 |
-
gr.Warning("You can select up to 4 LoRAs, remove one to select a new one.")
|
| 346 |
-
|
| 347 |
-
# Update selected_info and images
|
| 348 |
-
selected_info_1 = "Select a Celebrity as LoRA 1"
|
| 349 |
-
selected_info_2 = "Select a LoRA 2"
|
| 350 |
-
selected_info_3 = "Select a LoRA 3"
|
| 351 |
-
selected_info_4 = "Select a LoRA 4"
|
| 352 |
-
lora_scale_1 = 1.15
|
| 353 |
-
lora_scale_2 = 1.15
|
| 354 |
-
lora_scale_3 = 0.65
|
| 355 |
-
lora_scale_4 = 0.65
|
| 356 |
-
lora_image_1 = None
|
| 357 |
-
lora_image_2 = None
|
| 358 |
-
lora_image_3 = None
|
| 359 |
-
lora_image_4 = None
|
| 360 |
-
if len(selected_indices) >= 1:
|
| 361 |
-
lora1 = current_loras[selected_indices[0]]
|
| 362 |
-
selected_info_1 = f"### LoRA 1 Selected: {lora1['title']} ✨"
|
| 363 |
-
lora_image_1 = lora1['image'] if lora1['image'] else None
|
| 364 |
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
selected_info_2 = f"### LoRA 2 Selected: {lora2['title']} ✨"
|
| 368 |
-
lora_image_2 = lora2['image'] if lora2['image'] else None
|
| 369 |
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
|
| 388 |
def remove_custom_lora(selected_indices, current_loras, gallery):
|
| 389 |
if current_loras:
|
|
@@ -519,10 +516,10 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
|
|
| 519 |
|
| 520 |
run_lora.zerogpu = True
|
| 521 |
|
| 522 |
-
def get_huggingface_safetensors(link):
|
| 523 |
split_link = link.split("/")
|
| 524 |
if len(split_link) == 2:
|
| 525 |
-
model_card = ModelCard.load(link)
|
| 526 |
base_model = model_card.data.get("base_model")
|
| 527 |
print(f"Base model: {base_model}")
|
| 528 |
if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
|
|
@@ -530,7 +527,7 @@ def get_huggingface_safetensors(link):
|
|
| 530 |
image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
|
| 531 |
trigger_word = model_card.data.get("instance_prompt", "")
|
| 532 |
image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
|
| 533 |
-
fs = HfFileSystem()
|
| 534 |
safetensors_name = None
|
| 535 |
try:
|
| 536 |
list_of_files = fs.ls(link, detail=False)
|
|
@@ -549,24 +546,22 @@ def get_huggingface_safetensors(link):
|
|
| 549 |
else:
|
| 550 |
raise gr.Error("Invalid Hugging Face repository link")
|
| 551 |
|
| 552 |
-
def check_custom_model(link):
|
| 553 |
if link.endswith(".safetensors"):
|
| 554 |
-
# Treat as direct link to the LoRA weights
|
| 555 |
title = os.path.basename(link)
|
| 556 |
repo = link
|
| 557 |
-
path = None
|
| 558 |
trigger_word = ""
|
| 559 |
image_url = None
|
| 560 |
return title, repo, path, trigger_word, image_url
|
| 561 |
elif link.startswith("https://"):
|
| 562 |
if "huggingface.co" in link:
|
| 563 |
link_split = link.split("huggingface.co/")
|
| 564 |
-
return get_huggingface_safetensors(link_split[1])
|
| 565 |
else:
|
| 566 |
raise Exception("Unsupported URL")
|
| 567 |
else:
|
| 568 |
-
|
| 569 |
-
return get_huggingface_safetensors(link)
|
| 570 |
|
| 571 |
def update_history(new_image, history):
|
| 572 |
"""Updates the history gallery with the new image."""
|
|
|
|
| 316 |
random_prompt = random.choice(prompt_values)
|
| 317 |
return selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, lora_image_1, lora_image_2, lora_image_3, lora_image_4, random_prompt
|
| 318 |
|
| 319 |
+
def add_custom_lora(custom_lora, selected_indices, current_loras, gallery, request: gr.Request = None):
|
| 320 |
+
if not custom_lora:
|
| 321 |
+
return current_loras, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
| 322 |
+
|
| 323 |
+
try:
|
| 324 |
+
# Retrieve user token if running in Spaces
|
| 325 |
+
user_token = request.headers.get("Authorization", "").replace("Bearer ", "") if request else None
|
| 326 |
+
|
| 327 |
+
# Check and load custom LoRA
|
| 328 |
+
title, repo, path, trigger_word, image = check_custom_model(custom_lora, token=user_token)
|
| 329 |
+
print(f"Loaded custom LoRA: {repo}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
+
# Check if the LoRA already exists in the current list
|
| 332 |
+
existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
|
|
|
|
|
|
|
| 333 |
|
| 334 |
+
if existing_item_index is None:
|
| 335 |
+
# Download if a direct .safetensors URL
|
| 336 |
+
if repo.endswith(".safetensors") and repo.startswith("http"):
|
| 337 |
+
repo = download_file(repo)
|
| 338 |
+
|
| 339 |
+
# Add the new LoRA
|
| 340 |
+
new_item = {
|
| 341 |
+
"image": image or "/home/user/app/custom.png",
|
| 342 |
+
"title": title,
|
| 343 |
+
"repo": repo,
|
| 344 |
+
"weights": path,
|
| 345 |
+
"trigger_word": trigger_word,
|
| 346 |
+
}
|
| 347 |
+
print(f"New LoRA: {new_item}")
|
| 348 |
+
existing_item_index = len(current_loras)
|
| 349 |
+
current_loras.append(new_item)
|
| 350 |
|
| 351 |
+
# Update gallery items
|
| 352 |
+
gallery_items = [(item["image"], item["title"]) for item in current_loras]
|
| 353 |
+
|
| 354 |
+
# Update selected indices
|
| 355 |
+
if len(selected_indices) < 4:
|
| 356 |
+
selected_indices.append(existing_item_index)
|
| 357 |
+
else:
|
| 358 |
+
raise gr.Error("You can select up to 4 LoRAs. Please remove one to add a new one.")
|
| 359 |
+
|
| 360 |
+
# Update selection info and images
|
| 361 |
+
selected_info = [f"Select a LoRA {i + 1}" for i in range(4)]
|
| 362 |
+
lora_images = [None] * 4
|
| 363 |
+
lora_scales = [1.15, 1.15, 0.65, 0.65]
|
| 364 |
+
|
| 365 |
+
for idx, sel_idx in enumerate(selected_indices[:4]):
|
| 366 |
+
lora = current_loras[sel_idx]
|
| 367 |
+
selected_info[idx] = f"### LoRA {idx + 1} Selected: {lora['title']} ✨"
|
| 368 |
+
lora_images[idx] = lora.get("image")
|
| 369 |
+
|
| 370 |
+
print("Finished adding custom LoRA")
|
| 371 |
+
return (
|
| 372 |
+
current_loras,
|
| 373 |
+
gr.update(value=gallery_items),
|
| 374 |
+
*selected_info,
|
| 375 |
+
selected_indices,
|
| 376 |
+
*lora_scales,
|
| 377 |
+
*lora_images,
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
except Exception as e:
|
| 381 |
+
print(e)
|
| 382 |
+
return (current_loras, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(),gr.update(),
|
| 383 |
+
)
|
| 384 |
|
| 385 |
def remove_custom_lora(selected_indices, current_loras, gallery):
|
| 386 |
if current_loras:
|
|
|
|
| 516 |
|
| 517 |
run_lora.zerogpu = True
|
| 518 |
|
| 519 |
+
def get_huggingface_safetensors(link, token=None):
|
| 520 |
split_link = link.split("/")
|
| 521 |
if len(split_link) == 2:
|
| 522 |
+
model_card = ModelCard.load(link, use_auth_token=token)
|
| 523 |
base_model = model_card.data.get("base_model")
|
| 524 |
print(f"Base model: {base_model}")
|
| 525 |
if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
|
|
|
|
| 527 |
image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
|
| 528 |
trigger_word = model_card.data.get("instance_prompt", "")
|
| 529 |
image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
|
| 530 |
+
fs = HfFileSystem(token=token)
|
| 531 |
safetensors_name = None
|
| 532 |
try:
|
| 533 |
list_of_files = fs.ls(link, detail=False)
|
|
|
|
| 546 |
else:
|
| 547 |
raise gr.Error("Invalid Hugging Face repository link")
|
| 548 |
|
| 549 |
+
def check_custom_model(link, token=None):
|
| 550 |
if link.endswith(".safetensors"):
|
|
|
|
| 551 |
title = os.path.basename(link)
|
| 552 |
repo = link
|
| 553 |
+
path = None
|
| 554 |
trigger_word = ""
|
| 555 |
image_url = None
|
| 556 |
return title, repo, path, trigger_word, image_url
|
| 557 |
elif link.startswith("https://"):
|
| 558 |
if "huggingface.co" in link:
|
| 559 |
link_split = link.split("huggingface.co/")
|
| 560 |
+
return get_huggingface_safetensors(link_split[1], token=token)
|
| 561 |
else:
|
| 562 |
raise Exception("Unsupported URL")
|
| 563 |
else:
|
| 564 |
+
return get_huggingface_safetensors(link, token=token)
|
|
|
|
| 565 |
|
| 566 |
def update_history(new_image, history):
|
| 567 |
"""Updates the history gallery with the new image."""
|