Spaces:
Running
Running
The input type now supports text files in addition to images.
Browse filesWhen a text file is used as input, you can perform actions such as "reorganize the article" and "Prepend/Append Additional tags".
app.py
CHANGED
|
@@ -17,14 +17,16 @@ from datetime import datetime
|
|
| 17 |
from collections import defaultdict
|
| 18 |
from classifyTags import classify_tags
|
| 19 |
|
| 20 |
-
TITLE = "WaifuDiffusion Tagger multiple images"
|
| 21 |
DESCRIPTION = """
|
| 22 |
-
Demo for the WaifuDiffusion tagger models
|
|
|
|
| 23 |
Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
|
| 24 |
|
|
|
|
| 25 |
Features of This Modified Version:
|
| 26 |
-
- Supports batch processing of multiple images
|
| 27 |
-
- Displays tag results in categorized groups: the generated tags will now be analyzed and categorized into corresponding groups.
|
| 28 |
"""
|
| 29 |
|
| 30 |
# Dataset v3 series of models:
|
|
@@ -124,33 +126,34 @@ class Timer:
|
|
| 124 |
|
| 125 |
def report(self, is_clear_checkpoints = True):
|
| 126 |
# Determine the max label width for alignment
|
| 127 |
-
max_label_length = max(len(label) for label, _ in self.checkpoints)
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
| 134 |
|
| 135 |
if is_clear_checkpoints:
|
| 136 |
-
self.checkpoints.
|
| 137 |
-
self.checkpoint() # Store checkpoints
|
| 138 |
|
| 139 |
def report_all(self):
|
| 140 |
"""Print all recorded checkpoints and total execution time with aligned formatting."""
|
| 141 |
print("\n> Execution Time Report:")
|
| 142 |
|
| 143 |
# Determine the max label width for alignment
|
| 144 |
-
max_label_length = max(len(label) for label, _ in self.checkpoints) if
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
| 154 |
|
| 155 |
self.checkpoints.clear()
|
| 156 |
|
|
@@ -384,12 +387,12 @@ class Predictor:
|
|
| 384 |
|
| 385 |
def create_file(self, text: str, directory: str, fileName: str) -> str:
|
| 386 |
# Write the text to a file
|
| 387 |
-
|
|
|
|
| 388 |
file.write(text)
|
|
|
|
| 389 |
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
def predict(
|
| 393 |
self,
|
| 394 |
gallery,
|
| 395 |
model_repo,
|
|
@@ -404,34 +407,36 @@ class Predictor:
|
|
| 404 |
tag_results,
|
| 405 |
progress=gr.Progress()
|
| 406 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
gallery_len = len(gallery)
|
| 408 |
-
print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
|
| 409 |
|
| 410 |
timer = Timer() # Create a timer
|
| 411 |
progressRatio = 0.5 if llama3_reorganize_model_repo else 1
|
| 412 |
-
progressTotal = gallery_len + 1
|
| 413 |
current_progress = 0
|
| 414 |
|
| 415 |
self.load_model(model_repo)
|
| 416 |
-
current_progress +=
|
| 417 |
progress(current_progress, desc="Initialize wd model finished")
|
| 418 |
timer.checkpoint(f"Initialize wd model")
|
| 419 |
-
|
| 420 |
# Result
|
| 421 |
txt_infos = []
|
| 422 |
output_dir = tempfile.mkdtemp()
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
rating = None
|
| 428 |
-
character_res = None
|
| 429 |
-
general_res = None
|
| 430 |
|
|
|
|
| 431 |
if llama3_reorganize_model_repo:
|
| 432 |
print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
|
| 433 |
llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
|
| 434 |
-
current_progress +=
|
| 435 |
progress(current_progress, desc="Initialize llama3 model finished")
|
| 436 |
timer.checkpoint(f"Initialize llama3 model")
|
| 437 |
|
|
@@ -458,7 +463,7 @@ class Predictor:
|
|
| 458 |
|
| 459 |
input_name = self.model.get_inputs()[0].name
|
| 460 |
label_name = self.model.get_outputs()[0].name
|
| 461 |
-
print(f"Gallery {idx
|
| 462 |
preds = self.model.run([label_name], {input_name: image})[0]
|
| 463 |
|
| 464 |
labels = list(zip(self.tag_names, preds[0].astype(float)))
|
|
@@ -473,9 +478,7 @@ class Predictor:
|
|
| 473 |
if general_mcut_enabled:
|
| 474 |
general_probs = np.array([x[1] for x in general_names])
|
| 475 |
general_thresh = mcut_threshold(general_probs)
|
| 476 |
-
|
| 477 |
-
general_res = [x for x in general_names if x[1] > general_thresh]
|
| 478 |
-
general_res = dict(general_res)
|
| 479 |
|
| 480 |
# Everything else is characters: pick any where prediction confidence > threshold
|
| 481 |
character_names = [labels[i] for i in self.character_indexes]
|
|
@@ -484,16 +487,10 @@ class Predictor:
|
|
| 484 |
character_probs = np.array([x[1] for x in character_names])
|
| 485 |
character_thresh = mcut_threshold(character_probs)
|
| 486 |
character_thresh = max(0.15, character_thresh)
|
| 487 |
-
|
| 488 |
-
character_res = [x for x in character_names if x[1] > character_thresh]
|
| 489 |
-
character_res = dict(character_res)
|
| 490 |
character_list = list(character_res.keys())
|
| 491 |
|
| 492 |
-
sorted_general_list = sorted(
|
| 493 |
-
general_res.items(),
|
| 494 |
-
key=lambda x: x[1],
|
| 495 |
-
reverse=True,
|
| 496 |
-
)
|
| 497 |
sorted_general_list = [x[0] for x in sorted_general_list]
|
| 498 |
#Remove values from character_list that already exist in sorted_general_list
|
| 499 |
character_list = [item for item in character_list if item not in sorted_general_list]
|
|
@@ -503,57 +500,181 @@ class Predictor:
|
|
| 503 |
if append_list:
|
| 504 |
sorted_general_list = [item for item in sorted_general_list if item not in append_list]
|
| 505 |
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
|
|
|
| 511 |
|
| 512 |
-
current_progress += progressRatio/progressTotal
|
| 513 |
-
progress(current_progress, desc=f"
|
| 514 |
-
timer.checkpoint(f"
|
| 515 |
|
| 516 |
-
if
|
| 517 |
print(f"Starting reorganize with llama3...")
|
| 518 |
reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
|
| 519 |
-
reorganize_strings
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
|
|
|
| 523 |
|
| 524 |
-
current_progress += progressRatio/progressTotal
|
| 525 |
-
progress(current_progress, desc=f"
|
| 526 |
-
timer.checkpoint(f"
|
| 527 |
|
| 528 |
txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
|
| 529 |
-
txt_infos.append({"path":txt_file, "name": image_name + ".txt"})
|
| 530 |
|
| 531 |
tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
timer.report()
|
|
|
|
| 533 |
except Exception as e:
|
| 534 |
print(traceback.format_exc())
|
| 535 |
-
print("Error
|
| 536 |
-
|
|
|
|
| 537 |
# Result
|
| 538 |
download = []
|
| 539 |
-
if txt_infos
|
| 540 |
-
|
|
|
|
| 541 |
with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
|
| 542 |
for info in txt_infos:
|
| 543 |
# Get file name from lookup
|
| 544 |
taggers_zip.write(info["path"], arcname=info["name"])
|
| 545 |
download.append(downloadZipPath)
|
| 546 |
|
| 547 |
-
if
|
| 548 |
llama3_reorganize.release_vram()
|
| 549 |
-
del llama3_reorganize
|
| 550 |
|
| 551 |
-
progress(1, desc=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 552 |
timer.report_all() # Print all recorded times
|
| 553 |
-
print("
|
|
|
|
|
|
|
|
|
|
| 554 |
|
| 555 |
-
return download, sorted_general_strings, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
|
| 556 |
-
|
| 557 |
def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
|
| 558 |
if not selected_state:
|
| 559 |
return selected_state
|
|
@@ -590,10 +711,15 @@ def remove_image_from_gallery(gallery: list, selected_image: str):
|
|
| 590 |
if not gallery or not selected_image:
|
| 591 |
return gallery
|
| 592 |
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
gallery
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 597 |
return gallery
|
| 598 |
|
| 599 |
|
|
@@ -605,7 +731,6 @@ def main():
|
|
| 605 |
width: 55.5% !important;
|
| 606 |
}
|
| 607 |
"""
|
| 608 |
-
|
| 609 |
args = parse_args()
|
| 610 |
|
| 611 |
predictor = Predictor()
|
|
@@ -626,34 +751,93 @@ def main():
|
|
| 626 |
SWINV2_MODEL_IS_DSV1_REPO,
|
| 627 |
EVA02_LARGE_MODEL_IS_DSV1_REPO,
|
| 628 |
]
|
| 629 |
-
|
| 630 |
llama_list = [
|
| 631 |
META_LLAMA_3_3B_REPO,
|
| 632 |
META_LLAMA_3_8B_REPO,
|
| 633 |
]
|
| 634 |
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 639 |
gr.Markdown(value=DESCRIPTION)
|
|
|
|
| 640 |
with gr.Row():
|
| 641 |
with gr.Column():
|
| 642 |
submit = gr.Button(value="Submit", variant="primary", size="lg")
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 650 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 651 |
model_repo = gr.Dropdown(
|
| 652 |
dropdown_list,
|
| 653 |
value=EVA02_LARGE_MODEL_DSV3_REPO,
|
| 654 |
-
label="Model",
|
| 655 |
)
|
| 656 |
-
with gr.Row():
|
| 657 |
general_thresh = gr.Slider(
|
| 658 |
0,
|
| 659 |
1,
|
|
@@ -667,7 +851,7 @@ def main():
|
|
| 667 |
label="Use MCut threshold",
|
| 668 |
scale=1,
|
| 669 |
)
|
| 670 |
-
with gr.Row():
|
| 671 |
character_thresh = gr.Slider(
|
| 672 |
0,
|
| 673 |
1,
|
|
@@ -681,18 +865,20 @@ def main():
|
|
| 681 |
label="Use MCut threshold",
|
| 682 |
scale=1,
|
| 683 |
)
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
|
|
|
|
|
|
| 690 |
with gr.Row():
|
| 691 |
llama3_reorganize_model_repo = gr.Dropdown(
|
| 692 |
[None] + llama_list,
|
| 693 |
value=None,
|
| 694 |
-
label="Llama3
|
| 695 |
-
info="
|
| 696 |
)
|
| 697 |
with gr.Row():
|
| 698 |
additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
|
|
@@ -701,6 +887,7 @@ def main():
|
|
| 701 |
clear = gr.ClearButton(
|
| 702 |
components=[
|
| 703 |
gallery,
|
|
|
|
| 704 |
model_repo,
|
| 705 |
general_thresh,
|
| 706 |
general_mcut_enabled,
|
|
@@ -714,14 +901,16 @@ def main():
|
|
| 714 |
variant="secondary",
|
| 715 |
size="lg",
|
| 716 |
)
|
|
|
|
| 717 |
with gr.Column(variant="panel"):
|
| 718 |
download_file = gr.File(label="Output (Download)")
|
| 719 |
-
sorted_general_strings = gr.Textbox(label="Output (string)", show_label=True, show_copy_button=True)
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
|
|
|
| 725 |
clear.add(
|
| 726 |
[
|
| 727 |
download_file,
|
|
@@ -733,35 +922,51 @@ def main():
|
|
| 733 |
unclassified,
|
| 734 |
]
|
| 735 |
)
|
| 736 |
-
|
| 737 |
tag_results = gr.State({})
|
|
|
|
|
|
|
|
|
|
| 738 |
# Define the event listener to add the uploaded image to the gallery
|
| 739 |
image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
|
| 740 |
# When the upload button is clicked, add the new images to the gallery
|
| 741 |
upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
|
| 742 |
# Event to update the selected image when an image is clicked in the gallery
|
| 743 |
-
selected_image = gr.Textbox(label="Selected Image", visible=False)
|
| 744 |
gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, categorized, rating, character_res, general_res, unclassified])
|
| 745 |
# Event to remove a selected image from the gallery
|
| 746 |
remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 747 |
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
|
|
|
|
|
|
|
|
|
| 765 |
|
| 766 |
gr.Examples(
|
| 767 |
[["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
|
|
|
|
| 17 |
from collections import defaultdict
|
| 18 |
from classifyTags import classify_tags
|
| 19 |
|
| 20 |
+
TITLE = "WaifuDiffusion Tagger multiple images/texts"
|
| 21 |
DESCRIPTION = """
|
| 22 |
+
Demo for the WaifuDiffusion tagger models and text processing.
|
| 23 |
+
Select input type below. For images, it will generate tags. For text files, it will process existing tags.
|
| 24 |
Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
|
| 25 |
|
| 26 |
+
This project was duplicated from the Space of [wd-tagger](https://huggingface.co/spaces/SmilingWolf/wd-tagger) by the author SmilingWolf.
|
| 27 |
Features of This Modified Version:
|
| 28 |
+
- Supports batch processing of multiple images or text files.
|
| 29 |
+
- Displays tag results in categorized groups: the generated tags will now be analyzed and categorized into corresponding groups. (for images)
|
| 30 |
"""
|
| 31 |
|
| 32 |
# Dataset v3 series of models:
|
|
|
|
| 126 |
|
| 127 |
def report(self, is_clear_checkpoints = True):
|
| 128 |
# Determine the max label width for alignment
|
| 129 |
+
max_label_length = max(len(label) for label, _ in self.checkpoints) if self.checkpoints else 0
|
| 130 |
+
|
| 131 |
+
if len(self.checkpoints) > 1:
|
| 132 |
+
prev_time = self.checkpoints[0][1]
|
| 133 |
+
for label, curr_time in self.checkpoints[1:]:
|
| 134 |
+
elapsed = curr_time - prev_time
|
| 135 |
+
print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
|
| 136 |
+
prev_time = curr_time
|
| 137 |
|
| 138 |
if is_clear_checkpoints:
|
| 139 |
+
self.checkpoints = [("Start", time.perf_counter())]
|
|
|
|
| 140 |
|
| 141 |
def report_all(self):
|
| 142 |
"""Print all recorded checkpoints and total execution time with aligned formatting."""
|
| 143 |
print("\n> Execution Time Report:")
|
| 144 |
|
| 145 |
# Determine the max label width for alignment
|
| 146 |
+
max_label_length = max(len(label) for label, _ in self.checkpoints) if self.checkpoints else 0
|
| 147 |
+
|
| 148 |
+
if len(self.checkpoints) > 1:
|
| 149 |
+
prev_time = self.start_time
|
| 150 |
+
for label, curr_time in self.checkpoints[1:]:
|
| 151 |
+
elapsed = curr_time - prev_time
|
| 152 |
+
print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
|
| 153 |
+
prev_time = curr_time
|
| 154 |
+
|
| 155 |
+
total_time = self.checkpoints[-1][1] - self.start_time
|
| 156 |
+
print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
|
| 157 |
|
| 158 |
self.checkpoints.clear()
|
| 159 |
|
|
|
|
| 387 |
|
| 388 |
def create_file(self, text: str, directory: str, fileName: str) -> str:
|
| 389 |
# Write the text to a file
|
| 390 |
+
filepath = os.path.join(directory, fileName)
|
| 391 |
+
with open(filepath, 'w', encoding="utf-8") as file:
|
| 392 |
file.write(text)
|
| 393 |
+
return filepath
|
| 394 |
|
| 395 |
+
def predict_from_images(
|
|
|
|
|
|
|
| 396 |
self,
|
| 397 |
gallery,
|
| 398 |
model_repo,
|
|
|
|
| 407 |
tag_results,
|
| 408 |
progress=gr.Progress()
|
| 409 |
):
|
| 410 |
+
if not gallery:
|
| 411 |
+
gr.Warning("No images in the gallery to process.")
|
| 412 |
+
return None, "", "{}", "", "", "", "{}", {}
|
| 413 |
+
|
| 414 |
gallery_len = len(gallery)
|
| 415 |
+
print(f"Predict from images: load model: {model_repo}, gallery length: {gallery_len}")
|
| 416 |
|
| 417 |
timer = Timer() # Create a timer
|
| 418 |
progressRatio = 0.5 if llama3_reorganize_model_repo else 1
|
| 419 |
+
progressTotal = gallery_len + (1 if llama3_reorganize_model_repo else 0) + 1 # +1 for model load
|
| 420 |
current_progress = 0
|
| 421 |
|
| 422 |
self.load_model(model_repo)
|
| 423 |
+
current_progress += 1 / progressTotal
|
| 424 |
progress(current_progress, desc="Initialize wd model finished")
|
| 425 |
timer.checkpoint(f"Initialize wd model")
|
| 426 |
+
|
| 427 |
# Result
|
| 428 |
txt_infos = []
|
| 429 |
output_dir = tempfile.mkdtemp()
|
| 430 |
+
|
| 431 |
+
last_sorted_general_strings = ""
|
| 432 |
+
last_classified_tags, last_unclassified_tags = {}, {}
|
| 433 |
+
last_rating, last_character_res, last_general_res = None, None, None
|
|
|
|
|
|
|
|
|
|
| 434 |
|
| 435 |
+
llama3_reorganize = None
|
| 436 |
if llama3_reorganize_model_repo:
|
| 437 |
print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
|
| 438 |
llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
|
| 439 |
+
current_progress += 1 / progressTotal
|
| 440 |
progress(current_progress, desc="Initialize llama3 model finished")
|
| 441 |
timer.checkpoint(f"Initialize llama3 model")
|
| 442 |
|
|
|
|
| 463 |
|
| 464 |
input_name = self.model.get_inputs()[0].name
|
| 465 |
label_name = self.model.get_outputs()[0].name
|
| 466 |
+
print(f"Gallery {idx+1}/{gallery_len}: Starting run wd model...")
|
| 467 |
preds = self.model.run([label_name], {input_name: image})[0]
|
| 468 |
|
| 469 |
labels = list(zip(self.tag_names, preds[0].astype(float)))
|
|
|
|
| 478 |
if general_mcut_enabled:
|
| 479 |
general_probs = np.array([x[1] for x in general_names])
|
| 480 |
general_thresh = mcut_threshold(general_probs)
|
| 481 |
+
general_res = dict([x for x in general_names if x[1] > general_thresh])
|
|
|
|
|
|
|
| 482 |
|
| 483 |
# Everything else is characters: pick any where prediction confidence > threshold
|
| 484 |
character_names = [labels[i] for i in self.character_indexes]
|
|
|
|
| 487 |
character_probs = np.array([x[1] for x in character_names])
|
| 488 |
character_thresh = mcut_threshold(character_probs)
|
| 489 |
character_thresh = max(0.15, character_thresh)
|
| 490 |
+
character_res = dict([x for x in character_names if x[1] > character_thresh])
|
|
|
|
|
|
|
| 491 |
character_list = list(character_res.keys())
|
| 492 |
|
| 493 |
+
sorted_general_list = sorted(general_res.items(), key=lambda x: x[1], reverse=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
sorted_general_list = [x[0] for x in sorted_general_list]
|
| 495 |
#Remove values from character_list that already exist in sorted_general_list
|
| 496 |
character_list = [item for item in character_list if item not in sorted_general_list]
|
|
|
|
| 500 |
if append_list:
|
| 501 |
sorted_general_list = [item for item in sorted_general_list if item not in append_list]
|
| 502 |
|
| 503 |
+
final_tags_list = prepend_list + sorted_general_list + append_list
|
| 504 |
+
if characters_merge_enabled:
|
| 505 |
+
final_tags_list = character_list + final_tags_list
|
| 506 |
+
|
| 507 |
+
sorted_general_strings = ", ".join(final_tags_list).replace("(", "\(").replace(")", "\)")
|
| 508 |
+
classified_tags, unclassified_tags = classify_tags(final_tags_list)
|
| 509 |
|
| 510 |
+
current_progress += progressRatio / progressTotal
|
| 511 |
+
progress(current_progress, desc=f"Image {idx+1}/{gallery_len}, predict finished")
|
| 512 |
+
timer.checkpoint(f"Image {idx+1}/{gallery_len}, predict finished")
|
| 513 |
|
| 514 |
+
if llama3_reorganize:
|
| 515 |
print(f"Starting reorganize with llama3...")
|
| 516 |
reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
|
| 517 |
+
if reorganize_strings:
|
| 518 |
+
reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
|
| 519 |
+
reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
|
| 520 |
+
reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
|
| 521 |
+
sorted_general_strings += "," + reorganize_strings
|
| 522 |
|
| 523 |
+
current_progress += progressRatio / progressTotal
|
| 524 |
+
progress(current_progress, desc=f"Image {idx+1}/{gallery_len}, llama3 reorganize finished")
|
| 525 |
+
timer.checkpoint(f"Image {idx+1}/{gallery_len}, llama3 reorganize finished")
|
| 526 |
|
| 527 |
txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
|
| 528 |
+
txt_infos.append({"path": txt_file, "name": image_name + ".txt"})
|
| 529 |
|
| 530 |
tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
|
| 531 |
+
|
| 532 |
+
# Store last result for UI display
|
| 533 |
+
last_sorted_general_strings = sorted_general_strings
|
| 534 |
+
last_classified_tags = classified_tags
|
| 535 |
+
last_rating = rating
|
| 536 |
+
last_character_res = character_res
|
| 537 |
+
last_general_res = general_res
|
| 538 |
+
last_unclassified_tags = unclassified_tags
|
| 539 |
timer.report()
|
| 540 |
+
|
| 541 |
except Exception as e:
|
| 542 |
print(traceback.format_exc())
|
| 543 |
+
print("Error predicting image: " + str(e))
|
| 544 |
+
gr.Warning(f"Failed to process image {os.path.basename(value[0])}. Error: {e}")
|
| 545 |
+
|
| 546 |
# Result
|
| 547 |
download = []
|
| 548 |
+
if txt_infos:
|
| 549 |
+
zip_filename = "images-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip"
|
| 550 |
+
downloadZipPath = os.path.join(output_dir, zip_filename)
|
| 551 |
with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
|
| 552 |
for info in txt_infos:
|
| 553 |
# Get file name from lookup
|
| 554 |
taggers_zip.write(info["path"], arcname=info["name"])
|
| 555 |
download.append(downloadZipPath)
|
| 556 |
|
| 557 |
+
if llama3_reorganize:
|
| 558 |
llama3_reorganize.release_vram()
|
|
|
|
| 559 |
|
| 560 |
+
progress(1, desc="Image processing completed")
|
| 561 |
+
timer.report_all()
|
| 562 |
+
print("Image prediction is complete.")
|
| 563 |
+
|
| 564 |
+
return download, last_sorted_general_strings, last_classified_tags, last_rating, last_character_res, last_general_res, last_unclassified_tags, tag_results
|
| 565 |
+
|
| 566 |
+
# NEW: Method to process text files
|
| 567 |
+
def predict_from_text(
|
| 568 |
+
self,
|
| 569 |
+
text_files,
|
| 570 |
+
llama3_reorganize_model_repo,
|
| 571 |
+
additional_tags_prepend,
|
| 572 |
+
additional_tags_append,
|
| 573 |
+
progress=gr.Progress()
|
| 574 |
+
):
|
| 575 |
+
if not text_files:
|
| 576 |
+
gr.Warning("No text files uploaded to process.")
|
| 577 |
+
return None, "", "{}", "", "", "", "{}", {}
|
| 578 |
+
|
| 579 |
+
files_len = len(text_files)
|
| 580 |
+
print(f"Predict from text: processing {files_len} files.")
|
| 581 |
+
|
| 582 |
+
timer = Timer()
|
| 583 |
+
progressRatio = 0.5 if llama3_reorganize_model_repo else 1.0
|
| 584 |
+
progressTotal = files_len + (1 if llama3_reorganize_model_repo else 0)
|
| 585 |
+
current_progress = 0
|
| 586 |
+
|
| 587 |
+
txt_infos = []
|
| 588 |
+
output_dir = tempfile.mkdtemp()
|
| 589 |
+
last_processed_string = ""
|
| 590 |
+
|
| 591 |
+
llama3_reorganize = None
|
| 592 |
+
if llama3_reorganize_model_repo:
|
| 593 |
+
print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
|
| 594 |
+
llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
|
| 595 |
+
current_progress += 1 / progressTotal
|
| 596 |
+
progress(current_progress, desc="Initialize llama3 model finished")
|
| 597 |
+
timer.checkpoint(f"Initialize llama3 model")
|
| 598 |
+
|
| 599 |
+
timer.report()
|
| 600 |
+
|
| 601 |
+
prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
|
| 602 |
+
append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
|
| 603 |
+
if prepend_list and append_list:
|
| 604 |
+
append_list = [item for item in append_list if item not in prepend_list]
|
| 605 |
+
|
| 606 |
+
name_counters = defaultdict(int)
|
| 607 |
+
for idx, file_obj in enumerate(text_files):
|
| 608 |
+
try:
|
| 609 |
+
file_path = file_obj.name
|
| 610 |
+
file_name_base = os.path.splitext(os.path.basename(file_path))[0]
|
| 611 |
+
|
| 612 |
+
name_counters[file_name_base] += 1
|
| 613 |
+
if name_counters[file_name_base] > 1:
|
| 614 |
+
output_file_name = f"{file_name_base}_{name_counters[file_name_base]:02d}.txt"
|
| 615 |
+
else:
|
| 616 |
+
output_file_name = f"{file_name_base}.txt"
|
| 617 |
+
|
| 618 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 619 |
+
original_content = f.read()
|
| 620 |
+
|
| 621 |
+
# Process tags
|
| 622 |
+
tags_list = [tag.strip() for tag in original_content.split(',') if tag.strip()]
|
| 623 |
+
|
| 624 |
+
if prepend_list:
|
| 625 |
+
tags_list = [item for item in tags_list if item not in prepend_list]
|
| 626 |
+
if append_list:
|
| 627 |
+
tags_list = [item for item in tags_list if item not in append_list]
|
| 628 |
+
|
| 629 |
+
final_tags_list = prepend_list + tags_list + append_list
|
| 630 |
+
processed_string = ", ".join(final_tags_list)
|
| 631 |
+
|
| 632 |
+
current_progress += progressRatio / progressTotal
|
| 633 |
+
progress(current_progress, desc=f"File {idx+1}/{files_len}, base processing finished")
|
| 634 |
+
timer.checkpoint(f"File {idx+1}/{files_len}, base processing finished")
|
| 635 |
+
|
| 636 |
+
if llama3_reorganize:
|
| 637 |
+
print(f"Starting reorganize with llama3...")
|
| 638 |
+
reorganize_strings = llama3_reorganize.reorganize(processed_string)
|
| 639 |
+
if reorganize_strings:
|
| 640 |
+
reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
|
| 641 |
+
reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
|
| 642 |
+
reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
|
| 643 |
+
processed_string += "," + reorganize_strings
|
| 644 |
+
|
| 645 |
+
current_progress += progressRatio / progressTotal
|
| 646 |
+
progress(current_progress, desc=f"File {idx+1}/{files_len}, llama3 reorganize finished")
|
| 647 |
+
timer.checkpoint(f"File {idx+1}/{files_len}, llama3 reorganize finished")
|
| 648 |
+
|
| 649 |
+
txt_file_path = self.create_file(processed_string, output_dir, output_file_name)
|
| 650 |
+
txt_infos.append({"path": txt_file_path, "name": output_file_name})
|
| 651 |
+
last_processed_string = processed_string
|
| 652 |
+
timer.report()
|
| 653 |
+
|
| 654 |
+
except Exception as e:
|
| 655 |
+
print(traceback.format_exc())
|
| 656 |
+
print("Error processing text file: " + str(e))
|
| 657 |
+
gr.Warning(f"Failed to process file {os.path.basename(file_obj.name)}. Error: {e}")
|
| 658 |
+
|
| 659 |
+
download = []
|
| 660 |
+
if txt_infos:
|
| 661 |
+
zip_filename = "texts-processed-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip"
|
| 662 |
+
downloadZipPath = os.path.join(output_dir, zip_filename)
|
| 663 |
+
with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as processed_zip:
|
| 664 |
+
for info in txt_infos:
|
| 665 |
+
processed_zip.write(info["path"], arcname=info["name"])
|
| 666 |
+
download.append(downloadZipPath)
|
| 667 |
+
|
| 668 |
+
if llama3_reorganize:
|
| 669 |
+
llama3_reorganize.release_vram()
|
| 670 |
+
|
| 671 |
+
progress(1, desc="Text processing completed")
|
| 672 |
timer.report_all() # Print all recorded times
|
| 673 |
+
print("Text processing is complete.")
|
| 674 |
+
|
| 675 |
+
# Return values in the same structure as the image path, with placeholders for unused outputs
|
| 676 |
+
return download, last_processed_string, "{}", "", "", "", "{}", {}
|
| 677 |
|
|
|
|
|
|
|
| 678 |
def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
|
| 679 |
if not selected_state:
|
| 680 |
return selected_state
|
|
|
|
| 711 |
if not gallery or not selected_image:
|
| 712 |
return gallery
|
| 713 |
|
| 714 |
+
try:
|
| 715 |
+
selected_image = ast.literal_eval(selected_image) #Use ast.literal_eval to parse text into a tuple.
|
| 716 |
+
# Remove the selected image from the gallery
|
| 717 |
+
if selected_image in gallery:
|
| 718 |
+
gallery.remove(selected_image)
|
| 719 |
+
except (ValueError, SyntaxError):
|
| 720 |
+
# Handle cases where the string is not a valid literal
|
| 721 |
+
print(f"Warning: Could not parse selected_image string: {selected_image}")
|
| 722 |
+
|
| 723 |
return gallery
|
| 724 |
|
| 725 |
|
|
|
|
| 731 |
width: 55.5% !important;
|
| 732 |
}
|
| 733 |
"""
|
|
|
|
| 734 |
args = parse_args()
|
| 735 |
|
| 736 |
predictor = Predictor()
|
|
|
|
| 751 |
SWINV2_MODEL_IS_DSV1_REPO,
|
| 752 |
EVA02_LARGE_MODEL_IS_DSV1_REPO,
|
| 753 |
]
|
| 754 |
+
|
| 755 |
llama_list = [
|
| 756 |
META_LLAMA_3_3B_REPO,
|
| 757 |
META_LLAMA_3_8B_REPO,
|
| 758 |
]
|
| 759 |
|
| 760 |
+
# NEW: Wrapper function to decide which prediction method to call
|
| 761 |
+
def run_prediction(
|
| 762 |
+
input_type, gallery, text_files, model_repo, general_thresh,
|
| 763 |
+
general_mcut_enabled, character_thresh, character_mcut_enabled,
|
| 764 |
+
characters_merge_enabled, llama3_reorganize_model_repo,
|
| 765 |
+
additional_tags_prepend, additional_tags_append, tag_results, progress=gr.Progress()
|
| 766 |
+
):
|
| 767 |
+
if input_type == 'Image':
|
| 768 |
+
return predictor.predict_from_images(
|
| 769 |
+
gallery, model_repo, general_thresh, general_mcut_enabled,
|
| 770 |
+
character_thresh, character_mcut_enabled, characters_merge_enabled,
|
| 771 |
+
llama3_reorganize_model_repo, additional_tags_prepend,
|
| 772 |
+
additional_tags_append, tag_results, progress
|
| 773 |
+
)
|
| 774 |
+
else: # 'Text file (.txt)'
|
| 775 |
+
# For text files, some parameters are not used, but we must return
|
| 776 |
+
# a tuple of the same size. `predict_from_text` handles this.
|
| 777 |
+
return predictor.predict_from_text(
|
| 778 |
+
text_files, llama3_reorganize_model_repo,
|
| 779 |
+
additional_tags_prepend, additional_tags_append, progress
|
| 780 |
+
)
|
| 781 |
+
|
| 782 |
+
with gr.Blocks(title=TITLE, css=css) as demo:
|
| 783 |
+
gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
|
| 784 |
gr.Markdown(value=DESCRIPTION)
|
| 785 |
+
|
| 786 |
with gr.Row():
|
| 787 |
with gr.Column():
|
| 788 |
submit = gr.Button(value="Submit", variant="primary", size="lg")
|
| 789 |
+
|
| 790 |
+
# Input type selector
|
| 791 |
+
input_type_radio = gr.Radio(
|
| 792 |
+
choices=['Image', 'Text file (.txt)'],
|
| 793 |
+
value='Image',
|
| 794 |
+
label="Input Type"
|
| 795 |
+
)
|
| 796 |
+
|
| 797 |
+
# Group for image inputs, initially visible
|
| 798 |
+
with gr.Column(visible=True) as image_inputs_group:
|
| 799 |
+
with gr.Column(variant="panel"):
|
| 800 |
+
# Create an Image component for uploading images
|
| 801 |
+
image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
|
| 802 |
+
with gr.Row():
|
| 803 |
+
upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
|
| 804 |
+
remove_button = gr.Button("Remove Selected Image", size="sm")
|
| 805 |
+
gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
|
| 806 |
+
|
| 807 |
+
# NEW: Group for text file inputs, initially hidden
|
| 808 |
+
with gr.Column(visible=False) as text_inputs_group:
|
| 809 |
+
text_files_input = gr.Files(
|
| 810 |
+
label="Upload .txt files",
|
| 811 |
+
file_types=[".txt"],
|
| 812 |
+
file_count="multiple",
|
| 813 |
+
height=500
|
| 814 |
+
)
|
| 815 |
|
| 816 |
+
# NEW: Logic to show/hide input groups based on radio selection
|
| 817 |
+
def change_input_type(input_type):
|
| 818 |
+
is_image = (input_type == 'Image')
|
| 819 |
+
return {
|
| 820 |
+
image_inputs_group: gr.update(visible=is_image),
|
| 821 |
+
text_inputs_group: gr.update(visible=not is_image),
|
| 822 |
+
# Also update visibility of image-specific settings
|
| 823 |
+
model_repo: gr.update(visible=is_image),
|
| 824 |
+
general_thresh_row: gr.update(visible=is_image),
|
| 825 |
+
character_thresh_row: gr.update(visible=is_image),
|
| 826 |
+
characters_merge_enabled: gr.update(visible=is_image),
|
| 827 |
+
categorized: gr.update(visible=is_image),
|
| 828 |
+
rating: gr.update(visible=is_image),
|
| 829 |
+
character_res: gr.update(visible=is_image),
|
| 830 |
+
general_res: gr.update(visible=is_image),
|
| 831 |
+
unclassified: gr.update(visible=is_image),
|
| 832 |
+
}
|
| 833 |
+
|
| 834 |
+
# Image-specific settings
|
| 835 |
model_repo = gr.Dropdown(
|
| 836 |
dropdown_list,
|
| 837 |
value=EVA02_LARGE_MODEL_DSV3_REPO,
|
| 838 |
+
label="Model (for Images)",
|
| 839 |
)
|
| 840 |
+
with gr.Row(visible=True) as general_thresh_row:
|
| 841 |
general_thresh = gr.Slider(
|
| 842 |
0,
|
| 843 |
1,
|
|
|
|
| 851 |
label="Use MCut threshold",
|
| 852 |
scale=1,
|
| 853 |
)
|
| 854 |
+
with gr.Row(visible=True) as character_thresh_row:
|
| 855 |
character_thresh = gr.Slider(
|
| 856 |
0,
|
| 857 |
1,
|
|
|
|
| 865 |
label="Use MCut threshold",
|
| 866 |
scale=1,
|
| 867 |
)
|
| 868 |
+
characters_merge_enabled = gr.Checkbox(
|
| 869 |
+
value=True,
|
| 870 |
+
label="Merge characters into the string output",
|
| 871 |
+
scale=1,
|
| 872 |
+
visible=True,
|
| 873 |
+
)
|
| 874 |
+
|
| 875 |
+
# Common settings
|
| 876 |
with gr.Row():
|
| 877 |
llama3_reorganize_model_repo = gr.Dropdown(
|
| 878 |
[None] + llama_list,
|
| 879 |
value=None,
|
| 880 |
+
label="Use the Llama3 model to reorganize the article",
|
| 881 |
+
info="(Note: very slow)",
|
| 882 |
)
|
| 883 |
with gr.Row():
|
| 884 |
additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
|
|
|
|
| 887 |
clear = gr.ClearButton(
|
| 888 |
components=[
|
| 889 |
gallery,
|
| 890 |
+
text_files_input,
|
| 891 |
model_repo,
|
| 892 |
general_thresh,
|
| 893 |
general_mcut_enabled,
|
|
|
|
| 901 |
variant="secondary",
|
| 902 |
size="lg",
|
| 903 |
)
|
| 904 |
+
|
| 905 |
with gr.Column(variant="panel"):
|
| 906 |
download_file = gr.File(label="Output (Download)")
|
| 907 |
+
sorted_general_strings = gr.Textbox(label="Output (string)", show_label=True, show_copy_button=True, lines=5)
|
| 908 |
+
# Image-specific outputs
|
| 909 |
+
categorized = gr.JSON(label="Categorized (tags)", visible=True)
|
| 910 |
+
rating = gr.Label(label="Rating", visible=True)
|
| 911 |
+
character_res = gr.Label(label="Output (characters)", visible=True)
|
| 912 |
+
general_res = gr.Label(label="Output (tags)", visible=True)
|
| 913 |
+
unclassified = gr.JSON(label="Unclassified (tags)", visible=True)
|
| 914 |
clear.add(
|
| 915 |
[
|
| 916 |
download_file,
|
|
|
|
| 922 |
unclassified,
|
| 923 |
]
|
| 924 |
)
|
| 925 |
+
|
| 926 |
tag_results = gr.State({})
|
| 927 |
+
selected_image = gr.Textbox(label="Selected Image", visible=False)
|
| 928 |
+
|
| 929 |
+
# Event Listeners
|
| 930 |
# Define the event listener to add the uploaded image to the gallery
|
| 931 |
image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
|
| 932 |
# When the upload button is clicked, add the new images to the gallery
|
| 933 |
upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
|
| 934 |
# Event to update the selected image when an image is clicked in the gallery
|
|
|
|
| 935 |
gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, categorized, rating, character_res, general_res, unclassified])
|
| 936 |
# Event to remove a selected image from the gallery
|
| 937 |
remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
|
| 938 |
+
|
| 939 |
+
# Connect the radio button to the visibility function
|
| 940 |
+
input_type_radio.change(
|
| 941 |
+
fn=change_input_type,
|
| 942 |
+
inputs=input_type_radio,
|
| 943 |
+
outputs=[
|
| 944 |
+
image_inputs_group, text_inputs_group, model_repo,
|
| 945 |
+
general_thresh_row, character_thresh_row, characters_merge_enabled,
|
| 946 |
+
categorized, rating, character_res, general_res, unclassified
|
| 947 |
+
]
|
| 948 |
+
)
|
| 949 |
|
| 950 |
+
# submit click now calls the wrapper function
|
| 951 |
+
submit.click(
|
| 952 |
+
fn=run_prediction,
|
| 953 |
+
inputs=[
|
| 954 |
+
input_type_radio,
|
| 955 |
+
gallery,
|
| 956 |
+
text_files_input,
|
| 957 |
+
model_repo,
|
| 958 |
+
general_thresh,
|
| 959 |
+
general_mcut_enabled,
|
| 960 |
+
character_thresh,
|
| 961 |
+
character_mcut_enabled,
|
| 962 |
+
characters_merge_enabled,
|
| 963 |
+
llama3_reorganize_model_repo,
|
| 964 |
+
additional_tags_prepend,
|
| 965 |
+
additional_tags_append,
|
| 966 |
+
tag_results,
|
| 967 |
+
],
|
| 968 |
+
outputs=[download_file, sorted_general_strings, categorized, rating, character_res, general_res, unclassified, tag_results,],
|
| 969 |
+
)
|
| 970 |
|
| 971 |
gr.Examples(
|
| 972 |
[["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
|
webui.bat
CHANGED
|
@@ -1,21 +1,37 @@
|
|
| 1 |
@echo off
|
| 2 |
|
| 3 |
-
:: The source of the webui.bat file is stable-diffusion-webui
|
| 4 |
-
::
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
if not defined PYTHON (set PYTHON=python)
|
| 7 |
-
if not defined
|
|
|
|
| 8 |
|
| 9 |
mkdir tmp 2>NUL
|
| 10 |
|
|
|
|
| 11 |
%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
|
| 12 |
if %ERRORLEVEL% == 0 goto :check_pip
|
| 13 |
echo Couldn't launch python
|
| 14 |
goto :show_stdout_stderr
|
| 15 |
|
| 16 |
:check_pip
|
|
|
|
| 17 |
%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
|
| 18 |
if %ERRORLEVEL% == 0 goto :start_venv
|
|
|
|
| 19 |
if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
|
| 20 |
%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
|
| 21 |
if %ERRORLEVEL% == 0 goto :start_venv
|
|
@@ -23,33 +39,106 @@ echo Couldn't install pip
|
|
| 23 |
goto :show_stdout_stderr
|
| 24 |
|
| 25 |
:start_venv
|
| 26 |
-
if
|
| 27 |
-
if ["%
|
|
|
|
|
|
|
| 28 |
|
|
|
|
| 29 |
dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
|
| 30 |
-
if %ERRORLEVEL% == 0 goto :
|
| 31 |
|
|
|
|
|
|
|
| 32 |
for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
|
| 33 |
echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
|
| 34 |
%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
|
| 35 |
-
if %ERRORLEVEL%
|
| 36 |
-
echo Unable to create venv in directory "%VENV_DIR%"
|
| 37 |
-
goto :show_stdout_stderr
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
set PYTHON="%VENV_DIR%\Scripts\Python.exe"
|
| 41 |
-
echo venv %PYTHON%
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
goto :launch
|
| 45 |
|
| 46 |
:launch
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
| 48 |
pause
|
| 49 |
exit /b
|
| 50 |
|
| 51 |
-
:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
|
|
|
|
|
|
|
|
|
| 53 |
echo.
|
| 54 |
echo exit code: %errorlevel%
|
| 55 |
|
|
@@ -61,13 +150,13 @@ type tmp\stdout.txt
|
|
| 61 |
|
| 62 |
:show_stderr
|
| 63 |
for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
|
| 64 |
-
if %size% equ 0 goto :
|
| 65 |
echo.
|
| 66 |
echo stderr:
|
| 67 |
type tmp\stderr.txt
|
| 68 |
|
| 69 |
:endofscript
|
| 70 |
-
|
| 71 |
echo.
|
| 72 |
echo Launch unsuccessful. Exiting.
|
| 73 |
pause
|
|
|
|
|
|
| 1 |
@echo off
|
| 2 |
|
| 3 |
+
:: The original source of the webui.bat file is stable-diffusion-webui
|
| 4 |
+
:: Modified and enhanced by Gemini with features for venv management and requirements handling.
|
| 5 |
|
| 6 |
+
:: --------- Configuration ---------
|
| 7 |
+
set COMMANDLINE_ARGS=
|
| 8 |
+
:: Define the name of the Launch application
|
| 9 |
+
set APPLICATION_NAME=app.py
|
| 10 |
+
:: Define the name of the virtual environment directory
|
| 11 |
+
set VENV_NAME=venv
|
| 12 |
+
:: Set to 1 to always attempt to update packages from requirements.txt on every launch
|
| 13 |
+
set ALWAYS_UPDATE_REQS=0
|
| 14 |
+
:: ---------------------------------
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
:: Set PYTHON executable if not already defined
|
| 18 |
if not defined PYTHON (set PYTHON=python)
|
| 19 |
+
:: Set VENV_DIR using VENV_NAME if not already defined
|
| 20 |
+
if not defined VENV_DIR (set "VENV_DIR=%~dp0%VENV_NAME%")
|
| 21 |
|
| 22 |
mkdir tmp 2>NUL
|
| 23 |
|
| 24 |
+
:: Check if Python is callable
|
| 25 |
%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
|
| 26 |
if %ERRORLEVEL% == 0 goto :check_pip
|
| 27 |
echo Couldn't launch python
|
| 28 |
goto :show_stdout_stderr
|
| 29 |
|
| 30 |
:check_pip
|
| 31 |
+
:: Check if pip is available
|
| 32 |
%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
|
| 33 |
if %ERRORLEVEL% == 0 goto :start_venv
|
| 34 |
+
:: If pip is not available and PIP_INSTALLER_LOCATION is set, try to install pip
|
| 35 |
if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
|
| 36 |
%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
|
| 37 |
if %ERRORLEVEL% == 0 goto :start_venv
|
|
|
|
| 39 |
goto :show_stdout_stderr
|
| 40 |
|
| 41 |
:start_venv
|
| 42 |
+
:: Skip venv creation/activation if VENV_DIR is explicitly set to "-"
|
| 43 |
+
if ["%VENV_DIR%"] == ["-"] goto :skip_venv_entirely
|
| 44 |
+
:: Skip venv creation/activation if SKIP_VENV is set to "1"
|
| 45 |
+
if ["%SKIP_VENV%"] == ["1"] goto :skip_venv_entirely
|
| 46 |
|
| 47 |
+
:: Check if the venv already exists by looking for Python.exe in its Scripts directory
|
| 48 |
dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
|
| 49 |
+
if %ERRORLEVEL% == 0 goto :activate_venv_and_maybe_update
|
| 50 |
|
| 51 |
+
:: Venv does not exist, create it
|
| 52 |
+
echo Virtual environment not found in "%VENV_DIR%". Creating a new one.
|
| 53 |
for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
|
| 54 |
echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
|
| 55 |
%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
|
| 56 |
+
if %ERRORLEVEL% NEQ 0 (
|
| 57 |
+
echo Unable to create venv in directory "%VENV_DIR%"
|
| 58 |
+
goto :show_stdout_stderr
|
| 59 |
+
)
|
| 60 |
+
echo Venv created.
|
| 61 |
+
|
| 62 |
+
:: Install requirements for the first time if venv was just created
|
| 63 |
+
:: This section handles the initial installation of packages from requirements.txt
|
| 64 |
+
:: immediately after a new virtual environment is created.
|
| 65 |
+
echo Checking for requirements.txt for initial setup in %~dp0
|
| 66 |
+
if exist "%~dp0requirements.txt" (
|
| 67 |
+
echo Found requirements.txt, attempting to install for initial setup...
|
| 68 |
+
call "%VENV_DIR%\Scripts\activate.bat"
|
| 69 |
+
echo Installing packages from requirements.txt ^(initial setup^)...
|
| 70 |
+
"%VENV_DIR%\Scripts\python.exe" -m pip install -r "%~dp0requirements.txt"
|
| 71 |
+
if %ERRORLEVEL% NEQ 0 (
|
| 72 |
+
echo Failed to install requirements during initial setup. Please check the output above.
|
| 73 |
+
pause
|
| 74 |
+
goto :show_stdout_stderr_custom_pip_initial
|
| 75 |
+
)
|
| 76 |
+
echo Initial requirements installed successfully.
|
| 77 |
+
call "%VENV_DIR%\Scripts\deactivate.bat"
|
| 78 |
+
) else (
|
| 79 |
+
echo No requirements.txt found for initial setup, skipping package installation.
|
| 80 |
+
)
|
| 81 |
+
goto :activate_venv_and_maybe_update
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
:activate_venv_and_maybe_update
|
| 85 |
+
:: This label is reached if the venv exists or was just created.
|
| 86 |
+
:: Set PYTHON to point to the venv's Python interpreter.
|
| 87 |
set PYTHON="%VENV_DIR%\Scripts\Python.exe"
|
| 88 |
+
echo Activating venv: %PYTHON%
|
| 89 |
+
|
| 90 |
+
:: Always update requirements if ALWAYS_UPDATE_REQS is 1
|
| 91 |
+
:: This section allows for updating packages from requirements.txt on every launch
|
| 92 |
+
:: if the ALWAYS_UPDATE_REQS variable is set to 1.
|
| 93 |
+
if defined ALWAYS_UPDATE_REQS (
|
| 94 |
+
if "%ALWAYS_UPDATE_REQS%"=="1" (
|
| 95 |
+
echo ALWAYS_UPDATE_REQS is enabled.
|
| 96 |
+
if exist "%~dp0requirements.txt" (
|
| 97 |
+
echo Attempting to update packages from requirements.txt...
|
| 98 |
+
REM No need to call activate.bat here again, PYTHON is already set to the venv's python
|
| 99 |
+
%PYTHON% -m pip install -r "%~dp0requirements.txt"
|
| 100 |
+
if %ERRORLEVEL% NEQ 0 (
|
| 101 |
+
echo Failed to update requirements. Please check the output above.
|
| 102 |
+
pause
|
| 103 |
+
goto :endofscript
|
| 104 |
+
)
|
| 105 |
+
echo Requirements updated successfully.
|
| 106 |
+
) else (
|
| 107 |
+
echo ALWAYS_UPDATE_REQS is enabled, but no requirements.txt found. Skipping update.
|
| 108 |
+
)
|
| 109 |
+
) else (
|
| 110 |
+
echo ALWAYS_UPDATE_REQS is not enabled or not set to 1. Skipping routine update.
|
| 111 |
+
)
|
| 112 |
+
)
|
| 113 |
|
| 114 |
+
goto :launch
|
| 115 |
+
|
| 116 |
+
:skip_venv_entirely
|
| 117 |
+
:: This label is reached if venv usage is explicitly skipped.
|
| 118 |
+
echo Skipping venv.
|
| 119 |
goto :launch
|
| 120 |
|
| 121 |
:launch
|
| 122 |
+
:: Launch the main application
|
| 123 |
+
echo Launching Web UI with arguments: %COMMANDLINE_ARGS% %*
|
| 124 |
+
%PYTHON% %APPLICATION_NAME% %COMMANDLINE_ARGS% %*
|
| 125 |
+
echo Launch finished.
|
| 126 |
pause
|
| 127 |
exit /b
|
| 128 |
|
| 129 |
+
:show_stdout_stderr_custom_pip_initial
|
| 130 |
+
:: Custom error handler for failures during the initial pip install process.
|
| 131 |
+
echo.
|
| 132 |
+
echo exit code ^(pip initial install^): %errorlevel%
|
| 133 |
+
echo Errors during initial pip install. See output above.
|
| 134 |
+
echo.
|
| 135 |
+
echo Launch unsuccessful. Exiting.
|
| 136 |
+
pause
|
| 137 |
+
exit /b
|
| 138 |
|
| 139 |
+
|
| 140 |
+
:show_stdout_stderr
|
| 141 |
+
:: General error handler: displays stdout and stderr from the tmp directory.
|
| 142 |
echo.
|
| 143 |
echo exit code: %errorlevel%
|
| 144 |
|
|
|
|
| 150 |
|
| 151 |
:show_stderr
|
| 152 |
for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
|
| 153 |
+
if %size% equ 0 goto :endofscript
|
| 154 |
echo.
|
| 155 |
echo stderr:
|
| 156 |
type tmp\stderr.txt
|
| 157 |
|
| 158 |
:endofscript
|
|
|
|
| 159 |
echo.
|
| 160 |
echo Launch unsuccessful. Exiting.
|
| 161 |
pause
|
| 162 |
+
exit /b
|