Spaces:
Sleeping
Sleeping
Commit
·
b3dd22d
verified
·
0
Parent(s):
Super-squash branch 'main' using huggingface_hub
Browse files- .gitattributes +35 -0
- README.md +13 -0
- app.py +140 -0
- character_series_dict.csv +0 -0
- danbooru_e621.csv +0 -0
- myt2tmod.py +116 -0
- originalt2t.py +127 -0
- output.py +16 -0
- requirements.txt +11 -0
- t2t.py +41 -0
- t2tmod.py +117 -0
- tag_group.csv +0 -0
- tagger.py +450 -0
- tags.txt +0 -0
- utils.py +45 -0
- v2.py +214 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Natural Text to Danbooru Tags with Transformer V2
|
3 |
+
emoji: 👀📦
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: indigo
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.36.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: openrail
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from v2 import (
|
4 |
+
V2UI,
|
5 |
+
parse_upsampling_output,
|
6 |
+
V2_ALL_MODELS,
|
7 |
+
)
|
8 |
+
from utils import (
|
9 |
+
gradio_copy_text,
|
10 |
+
COPY_ACTION_JS,
|
11 |
+
V2_ASPECT_RATIO_OPTIONS,
|
12 |
+
V2_RATING_OPTIONS,
|
13 |
+
V2_LENGTH_OPTIONS,
|
14 |
+
V2_IDENTITY_OPTIONS
|
15 |
+
)
|
16 |
+
from tagger import (
|
17 |
+
predict_tags,
|
18 |
+
convert_danbooru_to_e621_prompt,
|
19 |
+
remove_specific_prompt,
|
20 |
+
insert_recom_prompt,
|
21 |
+
compose_prompt_to_copy,
|
22 |
+
translate_prompt,
|
23 |
+
)
|
24 |
+
from t2t import predict_text_to_tags
|
25 |
+
|
26 |
+
|
27 |
+
def description_ui():
    """Render the introductory Markdown header of the demo page.

    Must be called while a Gradio layout context is active, because the
    Markdown component is created for its side effect only.
    """
    gr.Markdown(
        """
## Natural Text to Danbooru Tags with Danbooru Tags Transformer V2
Natural text => Prompt => Upsampled longer prompt
- Mod of [ooferdoodles/text2tags-demo](https://huggingface.co/spaces/ooferdoodles/text2tags-demo) and p1atdev's [Danbooru Tags Transformer V2 Demo](https://huggingface.co/spaces/p1atdev/danbooru-tags-transformer-v2)
- It's buggy but seems to work for now.
"""
    )
|
36 |
+
|
37 |
+
|
38 |
+
def main():
    """Build the Gradio Blocks UI, wire its event handlers, and launch it.

    The page has two generation steps: natural text -> tags
    (``predict_text_to_tags``) and tags -> upsampled tags (``v2.on_generate``),
    followed by an e621-style conversion of the upsampled output.

    NOTE(review): the nesting of the ``with`` layout blocks below was
    reconstructed from a listing whose indentation was lost — confirm which
    widgets belong inside each Accordion/Group.
    """

    v2 = V2UI()

    with gr.Blocks() as ui:
        description_ui()

        with gr.Row():
            with gr.Column(scale=2):
                # Step 1: free-form text and the sampling controls passed to
                # predict_text_to_tags.
                with gr.Group():
                    t2t_text = gr.TextArea(label="Natural text", lines=6, placeholder="Minato Aqua from hololive with pink and blue twintails in a blue maid outfit ...", value="", show_copy_button=True)
                    with gr.Accordion(label="Advanced options", open=False):
                        translate_t2t_text_button = gr.Button(value="Translate text to English", size="sm", variant="secondary")
                        # NOTE(review): these ranges allow 0 tokens, top_p above 1 and a
                        # repeat penalty of 0 — presumably rejected downstream; confirm.
                        t2t_max_tokens = gr.Slider(0, 256, step=16, value=128, label='max_tokens')
                        t2t_temperature = gr.Slider(0.001, 2, step=0.1, value=0.7, label='temperature')
                        t2t_top_k = gr.Slider(0, 100, step=5, value=20, label='top_k')
                        t2t_top_p = gr.Slider(0, 2, step=0.05, value=0.95, label='top_p')
                        t2t_repeat_penalty = gr.Slider(0, 5, step=0.1, value=1.1, label='repeat_penalty')
                    t2t_examples = gr.Examples([
                        ["Minato Aqua from hololive with pink and blue twintails in a blue maid outfit"],
                    ],
                    t2t_text,
                    cache_examples=False,
                    )
                    generate_from_text_btn = gr.Button(value="GENERATE TAGS FROM TEXT", size="lg", variant="primary")

                # Step 2: tag inputs for the V2 upsampler. Several widgets are
                # created with visible=False and only carry values.
                with gr.Group():
                    input_character = gr.Textbox(label="Character tags", placeholder="hatsune miku", visible=False)
                    input_copyright = gr.Textbox(label="Copyright tags", placeholder="vocaloid", visible=False)
                    input_general = gr.TextArea(label="General tags", lines=6, placeholder="1girl, ...", value="", show_copy_button=True)
                    input_tags_to_copy = gr.Textbox(value="", visible=False)
                    copy_input_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False, visible=False)
                    tag_type = gr.Radio(label="Output tag conversion", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="e621", visible=False)
                    input_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="explicit")
                    with gr.Accordion(label="Advanced options", open=False):
                        input_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square")
                        input_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="very_long")
                        input_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
                        input_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
                        model_name = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
                        # Hidden helper values fed to insert_recom_prompt below.
                        dummy_np = gr.Textbox(label="Negative prompt", value="", visible=False)
                        recom_animagine = gr.Textbox(label="Animagine reccomended prompt", value="Animagine", visible=False)
                        recom_pony = gr.Textbox(label="Pony reccomended prompt", value="Pony", visible=False)

                    generate_btn = gr.Button(value="GENERATE TAGS", size="lg", variant="primary")

                # Upsampled output in its original tag style.
                with gr.Group():
                    output_text = gr.TextArea(label="Output tags", interactive=False, show_copy_button=True)
                    copy_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
                    elapsed_time_md = gr.Markdown(label="Elapsed time", value="", visible=False)

                # Same output after convert_danbooru_to_e621_prompt.
                with gr.Group():
                    output_text_pony = gr.TextArea(label="Output tags (Pony e621 style)", interactive=False, show_copy_button=True)
                    copy_btn_pony = gr.Button(value="Copy to clipboard", size="sm", interactive=False)

        # The order of this list is the positional argument order handed to
        # v2.on_generate through generate_btn.click below.
        v2.input_components = [
            model_name,
            input_copyright,
            input_character,
            input_general,
            input_rating,
            input_aspect_ratio,
            input_length,
            input_identity,
            input_ban_tags,
        ]

        # Translation overwrites the natural-text box in place.
        translate_t2t_text_button.click(translate_prompt, inputs=[t2t_text], outputs=[t2t_text])

        # Step 1 writes its result into the "General tags" box of step 2.
        generate_from_text_btn.click(
            predict_text_to_tags,
            inputs=[t2t_text, t2t_max_tokens, t2t_temperature, t2t_top_k, t2t_top_p, t2t_repeat_penalty],
            outputs=[
                input_general,
            ],
        )

        copy_input_btn.click(compose_prompt_to_copy, inputs=[input_character, input_copyright, input_general], outputs=[input_tags_to_copy]).then(
            gradio_copy_text, inputs=[input_tags_to_copy], js=COPY_ACTION_JS,
        )

        # Step 2 chain: upsample -> e621 conversion -> insert_recom_prompt on
        # each of the two output boxes.
        generate_btn.click(
            parse_upsampling_output(v2.on_generate),
            inputs=[
                *v2.input_components,
            ],
            outputs=[output_text, elapsed_time_md, copy_btn, copy_btn_pony],
        ).then(
            convert_danbooru_to_e621_prompt, inputs=[output_text, tag_type], outputs=[output_text_pony],
        ).then(
            insert_recom_prompt, inputs=[output_text, dummy_np, recom_animagine], outputs=[output_text, dummy_np],
        ).then(
            insert_recom_prompt, inputs=[output_text_pony, dummy_np, recom_pony], outputs=[output_text_pony, dummy_np],
        )
        copy_btn.click(gradio_copy_text, inputs=[output_text], js=COPY_ACTION_JS)
        copy_btn_pony.click(gradio_copy_text, inputs=[output_text_pony], js=COPY_ACTION_JS)

    ui.launch()
|
136 |
+
|
137 |
+
|
138 |
+
if __name__ == "__main__":
|
139 |
+
main()
|
140 |
+
|
character_series_dict.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
danbooru_e621.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
myt2tmod.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tempfile
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
6 |
+
import editdistance
|
7 |
+
|
8 |
+
|
9 |
+
class TaggerLlama:
    """Caption-to-tags generator built on a Hugging Face causal LM checkpoint.

    The model is prompted with ``### Caption:<text>\\n### Tags:`` and its
    comma-separated continuation is snapped onto the tag dictionary loaded
    from ``TAGS_FILE_NAME`` using edit distance.
    """

    MODEL_URL = "John6666/llama-tagger-HF-GPTQ-4bits"
    SAVE_NAME = "llama-tagger-HF-GPTQ-4bits"
    TAGS_FILE_NAME = "tags.txt"
    model = None
    tokenizer = None
    # Memo of raw tag -> accepted dictionary tag, shared by every call that
    # does not supply its own ``cache`` to find_closest_tag.
    _closest_tag_cache = {}

    def __init__(
        self,
    ):
        self.download_model()
        self.tag_list = self.load_tags()

    def download_model(
        self,
        model_url=None,
        save_name=None,
    ):
        """Load the tokenizer and model, keeping a copy in the temp directory.

        Args:
            model_url: Hub id or path to load from; defaults to ``MODEL_URL``.
            save_name: Directory name under the system temp dir used for the
                local copy; defaults to ``SAVE_NAME``.
        """
        model_url = model_url or self.MODEL_URL
        save_name = save_name or self.SAVE_NAME

        save_path = os.path.join(tempfile.gettempdir(), save_name)
        have_local_copy = os.path.exists(save_path)
        if have_local_copy:
            print("Model already exists. Skipping download.")
            source = save_path
        else:
            print("Downloading Model")
            source = model_url

        # Always load: returning early here used to leave model/tokenizer as
        # None, which made predict_tags fail later.
        self.tokenizer = AutoTokenizer.from_pretrained(source)
        self.model = AutoModelForCausalLM.from_pretrained(source, device_map="cuda:0")

        if not have_local_copy:
            # Save to the same path the existence check looks at.
            self.tokenizer.save_pretrained(save_path)
            self.model.save_pretrained(save_path)
            print("Model Downloaded")

    def load_tags(self):
        """Return the tag dictionary stored next to this module.

        Blank lines are dropped. Returns an empty list (after printing the
        error) when the file cannot be read.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        tags_file = os.path.join(module_dir, self.TAGS_FILE_NAME)
        try:
            with open(tags_file, "r", encoding="utf-8") as f:
                stripped = (line.strip() for line in f)
                return [tag for tag in stripped if tag]
        except IOError as e:
            print(f"Error loading tag dictionary: {e}")
            return []

    def preprocess_tag(self, tag):
        """Lower-case ``tag`` and cut anything after its first ``(...)`` group."""
        tag = tag.lower()
        match = re.match(r"^([^()]*\([^()]*\))\s*.*$", tag)
        return match.group(1) if match else tag

    def find_closest_tag(self, tag, threshold, tag_list, cache=None):
        """Return the dictionary tag nearest to ``tag`` by edit distance.

        Args:
            tag: Candidate tag text.
            threshold: Largest edit distance still accepted as a match.
            tag_list: Known tags to search.
            cache: Optional dict memoising accepted matches; the class-level
                cache is used when omitted.

        Returns:
            The closest tag, or ``None`` when nothing is within ``threshold``
            or ``tag_list`` is empty.
        """
        if cache is None:
            cache = TaggerLlama._closest_tag_cache
        if tag in cache:
            return cache[tag]

        # min() raises on an empty sequence; load_tags returns [] on failure.
        if not tag_list:
            return None

        closest_tag = min(tag_list, key=lambda x: editdistance.eval(tag, x))
        if editdistance.eval(tag, closest_tag) <= threshold:
            cache[tag] = closest_tag
            return closest_tag
        return None

    def correct_tags(self, tags, tag_list, preprocess=True):
        """Map raw tags onto ``tag_list`` and return the sorted unique matches."""
        if preprocess:
            tags = (self.preprocess_tag(x) for x in tags)
        corrected_tags = set()
        for tag in tags:
            # An empty fragment (e.g. from a trailing comma) would otherwise
            # match any one-character tag, since the threshold is at least 1.
            if not tag:
                continue
            threshold = max(1, len(tag) - 10)
            closest_tag = self.find_closest_tag(tag, threshold, tag_list)
            if closest_tag:
                corrected_tags.add(closest_tag)
        return sorted(corrected_tags)

    def predict_tags(
        self,
        prompt: str,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        repeat_penalty: float = 1.1,
        top_k: int = 40,
    ):
        """Generate tags for a caption and return them corrected and sorted.

        Args:
            prompt: Caption text.
            max_tokens: Passed to ``generate`` as ``max_new_tokens``.
            temperature, top_p, top_k: Sampling parameters for ``generate``.
            repeat_penalty: Passed to ``generate`` as ``repetition_penalty``.

        Returns:
            Sorted list of tags from ``self.tag_list``.
        """
        prompt = f"### Caption:{prompt}\n### Tags:"

        input_ids = self.tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
        candidate_ids = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("\n"),
            self.tokenizer.convert_tokens_to_ids("### Tags:"),
        ]
        # Same id set as before, minus the duplicate and any unresolved token.
        terminators = list(dict.fromkeys(i for i in candidate_ids if i is not None))

        raw_output = self.model.generate(
            input_ids.to(self.model.device),
            tokenizer=self.tokenizer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repeat_penalty,
            top_k=top_k,
            do_sample=True,
            stop_strings=["\n", "### Tags:"],
            eos_token_id=terminators,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        output = self.tokenizer.batch_decode(raw_output, skip_special_tokens=True)

        # Drop the echoed prompt so only the generated tag list remains.
        raw_preds = re.sub('^.+\n### Tags:(.+?$)', '\\1', output[0])
        pred_tags = [x.strip() for x in raw_preds.split(",")]
        corrected_tags = self.correct_tags(pred_tags, self.tag_list)
        return corrected_tags
|
114 |
+
|
115 |
+
# https://github.com/ooferdoodles1337/text2tags-lib
|
116 |
+
# https://huggingface.co/docs/transformers/main_classes/text_generation
|
originalt2t.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, List
|
2 |
+
import tempfile
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
|
6 |
+
import wget
|
7 |
+
import editdistance
|
8 |
+
from llama_cpp import Llama
|
9 |
+
|
10 |
+
|
11 |
+
class TaggerLlama(Llama):
    """llama.cpp-backed caption-to-tags generator.

    Extends ``Llama`` with download of the GGUF checkpoint, a tag dictionary
    loaded from ``TAGS_FILE_NAME``, and edit-distance correction of the
    generated tags.
    """

    MODEL_URL = "https://huggingface.co/ooferdoodles/llama-tagger-7b/resolve/main/llama-tagger.gguf?download=true"
    SAVE_NAME = "llama-tagger.gguf"
    TAGS_FILE_NAME = "tags.txt"
    # Memo of raw tag -> accepted dictionary tag, shared by every call that
    # does not supply its own ``cache`` to find_closest_tag.
    _closest_tag_cache = {}

    def __init__(
        self,
        model_path: Optional[str] = None,
        **kwargs,
    ):
        """Load the model, downloading it to the temp dir if no path is given.

        Args:
            model_path: Path to a GGUF file; when ``None`` the default
                checkpoint is fetched into the system temp directory.
            **kwargs: Forwarded to ``Llama.__init__``.
        """
        if model_path is None:
            model_path = os.path.join(tempfile.gettempdir(), self.SAVE_NAME)
            self.download_model()
        super().__init__(model_path, **kwargs)
        self.tag_list = self.load_tags()

    def download_model(
        self,
        model_url=None,
        save_name=None,
    ):
        """Download the checkpoint into the temp dir unless it is already there."""
        model_url = model_url or self.MODEL_URL
        save_name = save_name or self.SAVE_NAME

        save_path = os.path.join(tempfile.gettempdir(), save_name)
        if os.path.exists(save_path):
            print("Model already exists. Skipping download.")
            return
        print("Downloading Model")
        wget.download(model_url, out=save_path)
        print("Model Downloaded")

    def load_tags(self):
        """Return the tag dictionary stored next to this module.

        Blank lines are dropped. Returns an empty list (after printing the
        error) when the file cannot be read.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        tags_file = os.path.join(module_dir, self.TAGS_FILE_NAME)
        try:
            with open(tags_file, "r", encoding="utf-8") as f:
                stripped = (line.strip() for line in f)
                return [tag for tag in stripped if tag]
        except IOError as e:
            print(f"Error loading tag dictionary: {e}")
            return []

    def preprocess_tag(self, tag):
        """Lower-case ``tag`` and cut anything after its first ``(...)`` group."""
        tag = tag.lower()
        match = re.match(r"^([^()]*\([^()]*\))\s*.*$", tag)
        return match.group(1) if match else tag

    def find_closest_tag(self, tag, threshold, tag_list, cache=None):
        """Return the dictionary tag nearest to ``tag`` by edit distance.

        Returns ``None`` when nothing is within ``threshold`` or ``tag_list``
        is empty. Accepted matches are memoised in ``cache`` (the class-level
        cache when omitted).
        """
        if cache is None:
            cache = TaggerLlama._closest_tag_cache
        if tag in cache:
            return cache[tag]

        # min() raises on an empty sequence; load_tags returns [] on failure.
        if not tag_list:
            return None

        closest_tag = min(tag_list, key=lambda x: editdistance.eval(tag, x))
        if editdistance.eval(tag, closest_tag) <= threshold:
            cache[tag] = closest_tag
            return closest_tag
        return None

    def correct_tags(self, tags, tag_list, preprocess=True):
        """Map raw tags onto ``tag_list`` and return the sorted unique matches."""
        if preprocess:
            tags = (self.preprocess_tag(x) for x in tags)
        corrected_tags = set()
        for tag in tags:
            # An empty fragment (e.g. from a trailing comma) would otherwise
            # match any one-character tag, since the threshold is at least 1.
            if not tag:
                continue
            threshold = max(1, len(tag) - 10)
            closest_tag = self.find_closest_tag(tag, threshold, tag_list)
            if closest_tag:
                corrected_tags.add(closest_tag)
        return sorted(corrected_tags)

    def predict_tags(
        self,
        prompt: str,
        suffix: Optional[str] = None,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        logprobs: Optional[int] = None,
        echo: bool = False,
        # Was "/n", which never matches a newline; the list is only read here.
        stop: Optional[List[str]] = ["\n", "### Tags:"],
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
    ):
        """Generate tags for a caption and return them corrected and sorted.

        ``prompt`` is the caption text; every other argument is forwarded
        unchanged to ``create_completion``.
        """
        prompt = f"### Caption: {prompt}\n### Tags: "

        output = self.create_completion(
            prompt=prompt,
            suffix=suffix,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
            stop=stop,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            repeat_penalty=repeat_penalty,
            top_k=top_k,
            stream=stream,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
        )
        raw_preds = output["choices"][0]["text"]
        pred_tags = [x.strip() for x in raw_preds.split(",")]
        corrected_tags = self.correct_tags(pred_tags, self.tag_list)
        return corrected_tags
|
output.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
class UpsamplingOutput:
    """Plain record holding the result of one tag-upsampling run.

    NOTE(review): field meanings below are read from the names only —
    confirm against the code that builds this object.
    """

    # Tags added by the upsampler (presumably the newly generated part).
    upsampled_tags: str

    # Inputs echoed back, one comma-separated string per category (assumed).
    copyright_tags: str
    character_tags: str
    general_tags: str
    rating_tag: str
    aspect_ratio_tag: str
    length_tag: str
    identity_tag: str

    # Defaults to 0.0 when the caller does not record a duration.
    elapsed_time: float = 0.0
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
editdistance
|
2 |
+
transformers
|
3 |
+
accelerate
|
4 |
+
sentencepiece
|
5 |
+
auto-gptq
|
6 |
+
optimum
|
7 |
+
httpx==0.13.3
|
8 |
+
httpcore
|
9 |
+
googletrans==4.0.0rc1
|
10 |
+
optimum[onnxruntime]
|
11 |
+
dartrs
|
t2t.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from t2tmod import TaggerLlama
|
2 |
+
import spaces
|
3 |
+
|
4 |
+
|
5 |
+
def translate_text(text = ""):
    """Return ``text`` translated to English when it contains CJK/kana characters.

    Text without such characters is returned unchanged and no translation
    library is imported. If the translation call fails, the failure is
    printed and the original text is returned.
    """
    def translate_to_english(prompt):
        # Imported lazily so plain-English input never needs these packages.
        import httpcore
        # NOTE(review): compatibility shim kept as-is; presumably works around
        # googletrans expecting this attribute on httpcore — confirm.
        setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
        from googletrans import Translator
        translator = Translator()
        try:
            return translator.translate(prompt, src='auto', dest='en').text
        except Exception as e:
            # Best-effort: fall back to the untranslated text, but say why.
            print(f"Translation failed, using the original text: {e}")
            return prompt

    def is_japanese(s):
        # Also true for any CJK unified ideograph, not only Japanese kana.
        import unicodedata
        markers = ("CJK UNIFIED", "HIRAGANA", "KATAKANA")
        return any(
            marker in unicodedata.name(ch, "")
            for ch in s
            for marker in markers
        )

    return translate_to_english(text) if is_japanese(text) else text
|
26 |
+
|
27 |
+
|
28 |
+
t2t_model = TaggerLlama()
|
29 |
+
|
30 |
+
@spaces.GPU()
def predict_text_to_tags(input_text: str="", max_tokens: int=128, temperature: float=0.8, top_k: int=40, top_p: float=0.95, repeat_penalty: float=1.1):
    """Turn natural text into a comma-separated tag string.

    The text first goes through ``translate_text``; the tags predicted by
    ``t2t_model`` are joined with ", " and have underscores replaced by
    spaces. When translation changed the text, the translated text is
    placed in front of the tags.
    """
    english_text = translate_text(input_text)
    predicted = t2t_model.predict_tags(
        english_text,
        max_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repeat_penalty=repeat_penalty,
    )
    tag_string = ', '.join(predicted).replace("_", " ")

    if english_text == input_text:
        return tag_string
    return english_text + ', ' + tag_string
|
t2tmod.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tempfile
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
6 |
+
import editdistance
|
7 |
+
|
8 |
+
|
9 |
+
class TaggerLlama:
    """Caption-to-tags generator built on a Hugging Face causal LM checkpoint.

    The model is prompted with ``### Caption:<text>\\n### Tags:`` and its
    comma-separated continuation is snapped onto the tag dictionary loaded
    from ``TAGS_FILE_NAME`` using edit distance.
    """

    MODEL_URL = "John6666/llama-tagger-HF-GPTQ-4bits"
    SAVE_NAME = "llama-tagger-HF-GPTQ-4bits"
    TAGS_FILE_NAME = "tags.txt"
    model = None
    tokenizer = None
    # Memo of raw tag -> accepted dictionary tag, shared by every call that
    # does not supply its own ``cache`` to find_closest_tag.
    _closest_tag_cache = {}

    def __init__(
        self,
    ):
        self.download_model()
        self.tag_list = self.load_tags()

    def download_model(
        self,
        model_url=None,
        save_name=None,
    ):
        """Load the tokenizer and model, keeping a copy in the temp directory.

        Args:
            model_url: Hub id or path to load from; defaults to ``MODEL_URL``.
            save_name: Directory name under the system temp dir used for the
                local copy; defaults to ``SAVE_NAME``.
        """
        model_url = model_url or self.MODEL_URL
        save_name = save_name or self.SAVE_NAME

        save_path = os.path.join(tempfile.gettempdir(), save_name)
        have_local_copy = os.path.exists(save_path)
        if have_local_copy:
            print("Model already exists. Skipping download.")
            source = save_path
        else:
            print("Downloading Model")
            source = model_url

        # Always load: returning early here used to leave model/tokenizer as
        # None, which made predict_tags fail later.
        self.tokenizer = AutoTokenizer.from_pretrained(source)
        self.model = AutoModelForCausalLM.from_pretrained(source, device_map="cuda:0")

        if not have_local_copy:
            # Save to the same path the existence check looks at.
            self.tokenizer.save_pretrained(save_path)
            self.model.save_pretrained(save_path)
            print("Model Downloaded")

    def load_tags(self):
        """Return the tag dictionary stored next to this module.

        Blank lines are dropped. Returns an empty list (after printing the
        error) when the file cannot be read.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        tags_file = os.path.join(module_dir, self.TAGS_FILE_NAME)
        try:
            with open(tags_file, "r", encoding="utf-8") as f:
                stripped = (line.strip() for line in f)
                return [tag for tag in stripped if tag]
        except IOError as e:
            print(f"Error loading tag dictionary: {e}")
            return []

    def preprocess_tag(self, tag):
        """Lower-case ``tag`` and cut anything after its first ``(...)`` group."""
        tag = tag.lower()
        match = re.match(r"^([^()]*\([^()]*\))\s*.*$", tag)
        return match.group(1) if match else tag

    def find_closest_tag(self, tag, threshold, tag_list, cache=None):
        """Return the dictionary tag nearest to ``tag`` by edit distance.

        Args:
            tag: Candidate tag text.
            threshold: Largest edit distance still accepted as a match.
            tag_list: Known tags to search.
            cache: Optional dict memoising accepted matches; the class-level
                cache is used when omitted.

        Returns:
            The closest tag, or ``None`` when nothing is within ``threshold``
            or ``tag_list`` is empty.
        """
        if cache is None:
            cache = TaggerLlama._closest_tag_cache
        if tag in cache:
            return cache[tag]

        # min() raises on an empty sequence; load_tags returns [] on failure.
        if not tag_list:
            return None

        closest_tag = min(tag_list, key=lambda x: editdistance.eval(tag, x))
        if editdistance.eval(tag, closest_tag) <= threshold:
            cache[tag] = closest_tag
            return closest_tag
        return None

    def correct_tags(self, tags, tag_list, preprocess=True):
        """Map raw tags onto ``tag_list`` and return the sorted unique matches."""
        if preprocess:
            tags = (self.preprocess_tag(x) for x in tags)
        corrected_tags = set()
        for tag in tags:
            # An empty fragment (e.g. from a trailing comma) would otherwise
            # match any one-character tag, since the threshold is at least 1.
            if not tag:
                continue
            threshold = max(1, len(tag) - 10)
            closest_tag = self.find_closest_tag(tag, threshold, tag_list)
            if closest_tag:
                corrected_tags.add(closest_tag)
        return sorted(corrected_tags)

    def predict_tags(
        self,
        prompt: str,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        repeat_penalty: float = 1.1,
        top_k: int = 40,
    ):
        """Generate tags for a caption and return them corrected and sorted.

        Args:
            prompt: Caption text; surrounding whitespace is stripped.
            max_tokens: Passed to ``generate`` as ``max_new_tokens``.
            temperature, top_p, top_k: Sampling parameters for ``generate``.
            repeat_penalty: Passed to ``generate`` as ``repetition_penalty``.

        Returns:
            Sorted list of tags from ``self.tag_list``.
        """
        prompt = f"### Caption:{prompt.strip()}\n### Tags:"

        input_ids = self.tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")
        candidate_ids = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("\n"),
            self.tokenizer.convert_tokens_to_ids("### Tags:"),
        ]
        # Same id set as before, minus the duplicate and any unresolved token.
        terminators = list(dict.fromkeys(i for i in candidate_ids if i is not None))

        raw_output = self.model.generate(
            input_ids.to(self.model.device),
            tokenizer=self.tokenizer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repeat_penalty,
            top_k=top_k,
            do_sample=True,
            stop_strings=["\n", "### Tags:"],
            eos_token_id=terminators,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        output = self.tokenizer.batch_decode(raw_output, skip_special_tokens=True)

        # Drop the echoed prompt so only the generated tag list remains.
        raw_preds = re.sub('^.+\n### Tags:(.+?$)', '\\1', output[0])
        # Existing workaround ("to avoid a mysterious bug"): when the list
        # starts with "1boy", drop the first two entries and the last one.
        pieces = raw_preds.split(",")
        if pieces[0].strip() == "1boy":
            raw_preds = ",".join(pieces[2:-1])
        pred_tags = [x.strip() for x in raw_preds.split(",")]
        corrected_tags = self.correct_tags(pred_tags, self.tag_list)
        return corrected_tags
|
115 |
+
|
116 |
+
# https://github.com/ooferdoodles1337/text2tags-lib
|
117 |
+
# https://huggingface.co/docs/transformers/main_classes/text_generation
|
tag_group.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tagger.py
ADDED
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import torch
|
3 |
+
import gradio as gr
|
4 |
+
import spaces # ZERO GPU
|
5 |
+
|
6 |
+
from transformers import (
|
7 |
+
AutoImageProcessor,
|
8 |
+
AutoModelForImageClassification,
|
9 |
+
)
|
10 |
+
|
11 |
+
# Checkpoints available for the image tagger; the first entry is the one used.
WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
WD_MODEL_NAME = WD_MODEL_NAMES[0]

# Loaded once at import time. trust_remote_code=True allows code shipped with
# the checkpoint to run, so only trusted checkpoints belong in the list above.
wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
# Prefer the GPU when one is visible, otherwise stay on CPU.
wd_model.to("cuda" if torch.cuda.is_available() else "cpu")
wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
|
17 |
+
|
18 |
+
|
19 |
+
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
|
20 |
+
return (
|
21 |
+
[f"1{noun}"]
|
22 |
+
+ [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
|
23 |
+
+ [f"{maximum+1}+{noun}s"]
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
# Every head-count tag for girls, boys and "other", plus "no humans".
# convert_danbooru_to_e621_prompt checks tags against this list to pull the
# people tags out from the rest.
PEOPLE_TAGS = (
    _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
)
|
30 |
+
|
31 |
+
|
32 |
+
# Rating name -> prompt text. NOTE(review): keys look like the tagger's rating
# labels; the consumer is outside this view — confirm.
RATING_MAP = {
    "general": "safe",
    "sensitive": "sensitive",
    "questionable": "nsfw",
    "explicit": "explicit, nsfw",
}
# Danbooru-style rating text -> e621-style rating tag, used by
# convert_danbooru_to_e621_prompt.
# NOTE(review): that function splits the prompt on "," before the lookup, so
# keys containing ", " presumably can never match a single tag — confirm.
DANBOORU_TO_E621_RATING_MAP = {
    "safe": "rating_safe",
    "sensitive": "rating_safe",
    "nsfw": "rating_explicit",
    "explicit, nsfw": "rating_explicit",
    "explicit": "rating_explicit",
    "rating:safe": "rating_safe",
    "rating:general": "rating_safe",
    "rating:sensitive": "rating_safe",
    "rating:questionable, nsfw": "rating_explicit",
    "rating:explicit, nsfw": "rating_explicit",
}
|
50 |
+
|
51 |
+
|
52 |
+
def load_dict_from_csv(filename):
    """Read a two-column comma-separated file into a dict.

    Each line is split on "," after stripping; the first field becomes the
    key and the second the value (extra fields are ignored, later duplicates
    overwrite earlier ones). Lines with fewer than two fields, such as blank
    lines, are skipped instead of raising ``IndexError``.

    Args:
        filename: Path of the UTF-8 encoded file.

    Returns:
        The key -> value mapping.
    """
    mapping = {}
    with open(filename, 'r', encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split(',')
            if len(parts) < 2:
                continue
            mapping[parts[0]] = parts[1]
    return mapping
|
60 |
+
|
61 |
+
|
62 |
+
anime_series_dict = load_dict_from_csv('character_series_dict.csv')
|
63 |
+
|
64 |
+
|
65 |
+
def character_list_to_series_list(character_list):
    """Map character tags to the series they belong to.

    For a tag ending in ``)``, the series is the text after the last ``(``
    with every ``)`` removed. Any other tag is looked up in
    ``anime_series_dict``. Tags that yield no series are skipped; the result
    keeps input order and may contain duplicates.
    """
    found_series = []
    for name in character_list:
        if name.endswith(")"):
            # "character (series)" style tag: the trailing parenthetical wins.
            series = name.rsplit("(", 1)[-1].replace(")", "")
        else:
            series = anime_series_dict.get(name, "")
        if series:
            found_series.append(series)
    return found_series
|
82 |
+
|
83 |
+
|
84 |
+
def danbooru_to_e621(dtag, e621_dict):
    """Replace the first word runs of a Danbooru tag with their e621 names.

    The tag is scanned for runs of word characters and spaces. At most the
    first two runs are looked up in ``e621_dict`` (after stripping and
    turning underscores into spaces); a run with no non-empty dictionary
    entry is left as it is.

    Args:
        dtag: the Danbooru tag text.
        e621_dict: mapping from Danbooru tag to e621 tag.

    Returns:
        The tag with the matched runs substituted.
    """
    import re

    def d_to_e(match):
        found = match.group(0)
        # Fall back to the original text when there is no (non-empty) entry.
        return e621_dict.get(found.strip().replace("_", " "), "") or found

    # ``count`` is passed by keyword: passing it positionally is deprecated
    # since Python 3.13.
    return re.sub(r'[\w ]+', d_to_e, dtag, count=2)
|
97 |
+
|
98 |
+
|
99 |
+
danbooru_to_e621_dict = load_dict_from_csv('danbooru_e621.csv')
|
100 |
+
|
101 |
+
|
102 |
+
def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "danbooru"):
    """Convert a comma-separated Danbooru prompt to e621 vocabulary.

    When ``prompt_type`` is ``"danbooru"`` the prompt is returned untouched.
    Otherwise every tag is stripped, underscores become spaces, and the tag
    is passed through ``danbooru_to_e621``. The output lists people tags
    first, then all other tags, then at most one rating tag (the e621 rating
    of the first rating tag encountered).

    Returns:
        The converted prompt, tags joined with ``", "``.
    """
    if prompt_type == "danbooru":
        return input_prompt

    people_tags: list[str] = []
    other_tags: list[str] = []
    rating_tag = ""

    for raw_tag in (input_prompt.split(",") if input_prompt else []):
        tag = danbooru_to_e621(raw_tag.strip().replace("_", " "), danbooru_to_e621_dict)
        if tag in PEOPLE_TAGS:
            people_tags.append(tag)
        elif tag in DANBOORU_TO_E621_RATING_MAP:
            # Only the first rating counts; look up the same key that was tested.
            if not rating_tag:
                rating_tag = DANBOORU_TO_E621_RATING_MAP[tag]
        else:
            other_tags.append(tag)

    rating_tags = [rating_tag] if rating_tag else []
    return ", ".join(people_tags + other_tags + rating_tags)
|
127 |
+
|
128 |
+
|
129 |
+
def translate_prompt(prompt: str = ""):
    """Translate the Japanese parts of a comma-separated prompt to English.

    The prompt is split on ``,`` and each piece is stripped. A piece that
    contains at least one character whose Unicode name includes
    ``CJK UNIFIED``, ``HIRAGANA`` or ``KATAKANA`` is sent to googletrans;
    every other piece is kept as it is. If the translation call raises, the
    piece is kept untranslated. Pieces are re-joined with ``", "``.
    """
    import unicodedata

    japanese_markers = ("CJK UNIFIED", "HIRAGANA", "KATAKANA")

    def contains_japanese(text):
        return any(
            marker in unicodedata.name(ch, "")
            for ch in text
            for marker in japanese_markers
        )

    def translate_to_english(text):
        import httpcore
        # NOTE(review): presumably a compatibility shim so googletrans imports
        # against the installed httpcore - confirm before removing.
        setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
        from googletrans import Translator
        translator = Translator()
        try:
            return translator.translate(text, src='auto', dest='en').text
        except Exception:
            # Best effort: keep the original text when translation fails.
            return text

    pieces = (chunk.strip() for chunk in prompt.split(","))
    return ", ".join(
        translate_to_english(piece) if contains_japanese(piece) else piece
        for piece in pieces
    )
|
159 |
+
|
160 |
+
|
161 |
+
def translate_prompt_to_ja(prompt: str = ""):
    """Translate the non-Japanese parts of a comma-separated prompt to Japanese.

    The prompt is split on ``,`` and each piece is stripped. A piece that
    already contains a character whose Unicode name includes ``CJK UNIFIED``,
    ``HIRAGANA`` or ``KATAKANA`` is kept; every other piece is sent to
    googletrans (English to Japanese). If the translation call raises, the
    piece is kept untranslated. Pieces are re-joined with ``", "``.
    """
    import unicodedata

    japanese_markers = ("CJK UNIFIED", "HIRAGANA", "KATAKANA")

    def contains_japanese(text):
        return any(
            marker in unicodedata.name(ch, "")
            for ch in text
            for marker in japanese_markers
        )

    def translate_to_japanese(text):
        import httpcore
        # NOTE(review): presumably a compatibility shim so googletrans imports
        # against the installed httpcore - confirm before removing.
        setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
        from googletrans import Translator
        translator = Translator()
        try:
            return translator.translate(text, src='en', dest='ja').text
        except Exception:
            # Best effort: keep the original text when translation fails.
            return text

    pieces = (chunk.strip() for chunk in prompt.split(","))
    return ", ".join(
        piece if contains_japanese(piece) else translate_to_japanese(piece)
        for piece in pieces
    )
|
191 |
+
|
192 |
+
|
193 |
+
def tags_to_ja(itag, dict):
    """Replace the first word runs of a tag with their Japanese names.

    The tag is scanned for runs of word characters and spaces. At most the
    first two runs are looked up in ``dict`` (after stripping and turning
    underscores into spaces); a run with no non-empty entry is left as it is.

    Args:
        itag: the tag text.
        dict: mapping from tag to Japanese text (the parameter name shadows
            the builtin but is kept for callers that pass it by keyword).

    Returns:
        The tag with the matched runs substituted.
    """
    import re

    def t_to_j(match):
        found = match.group(0)
        # Fall back to the original text when there is no (non-empty) entry.
        return dict.get(found.strip().replace("_", " "), "") or found

    # ``count`` is passed by keyword: passing it positionally is deprecated
    # since Python 3.13.
    return re.sub(r'[\w ]+', t_to_j, itag, count=2)
|
206 |
+
|
207 |
+
|
208 |
+
def convert_tags_to_ja(input_prompt: str = ""):
    """Translate every tag of a comma-separated prompt via the tag CSV.

    Each tag is stripped, underscores become spaces, and the tag is passed
    through ``tags_to_ja``. Results are joined with ``", "``.
    """
    tags = input_prompt.split(",") if input_prompt else []

    # The dictionary is re-read from disk on every call.
    # NOTE(review): 'all_tags_ja_ext.csv' is not among the files listed in
    # this commit - confirm it ships with the app.
    tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')

    return ", ".join(
        tags_to_ja(tag.strip().replace("_", " "), tags_to_ja_dict)
        for tag in tags
    )
|
220 |
+
|
221 |
+
|
222 |
+
def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
    """Strip, then optionally re-append, the recommended quality tags.

    All Animagine and Pony preset tags are first removed from both prompts.
    When ``type`` is ``"Animagine"`` or ``"Pony"`` the matching preset is
    appended. Duplicates are dropped (first occurrence kept). If a prompt had
    no tags of its own and ``type`` is not ``"None"``, an empty trailing
    element is added, which leaves a trailing ``", "`` in the joined string.

    Returns:
        ``(prompt, neg_prompt)`` as ``", "``-joined strings.
    """
    def split_tags(text):
        if text == "":
            return []
        return [piece.strip() for piece in text.split(",")]

    def without(items, banned):
        return [item for item in items if item not in banned]

    def dedupe(items):
        # dict preserves insertion order, so the first occurrence wins.
        return list(dict.fromkeys(items))

    animagine_ps = split_tags("anime artwork, anime style, key visual, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
    animagine_nps = split_tags("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
    pony_ps = split_tags("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
    pony_nps = split_tags("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")

    positives = without(split_tags(prompt), animagine_ps + pony_ps)
    negatives = without(split_tags(neg_prompt), animagine_nps + pony_nps)

    preset_requested = type != "None"
    tail_p = [""] if preset_requested and not positives else []
    tail_np = [""] if preset_requested and not negatives else []

    presets = {
        "Animagine": (animagine_ps, animagine_nps),
        "Pony": (pony_ps, pony_nps),
    }
    extra_ps, extra_nps = presets.get(type, ([], []))
    positives = positives + extra_ps
    negatives = negatives + extra_nps

    return (
        ", ".join(dedupe(positives) + tail_p),
        ", ".join(dedupe(negatives) + tail_np),
    )
|
256 |
+
|
257 |
+
|
258 |
+
tag_group_dict = load_dict_from_csv('tag_group.csv')
|
259 |
+
|
260 |
+
|
261 |
+
def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
    """Filter a comma-separated prompt down to the requested tag categories.

    With ``keep_tags == "all"`` the prompt is returned unchanged. Otherwise
    each tag is stripped and underscores become spaces; people tags are
    always kept and listed first. A remaining tag is dropped when it is
    ``solo``, when its group in ``tag_group_dict`` is outside the kept
    groups, when ``keep_tags == "body"`` and it matches the clothing
    pattern, or when it matches the background pattern. An unknown
    ``keep_tags`` value uses the ``"body"`` group selection.
    """
    import re

    if keep_tags == "all":
        return input_prompt

    clothing_pattern = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
    background_pattern = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')

    un_tags = ['solo']
    group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
    keep_group_dict = {
        "body": ['groups', 'body_parts'],
        "dress": ['groups', 'body_parts', 'attire'],
        "all": group_list,
    }
    keep_group = keep_group_dict.get(keep_tags, keep_group_dict["body"])
    # Every group that is not explicitly kept.
    explicit_group = set(group_list) ^ set(keep_group)

    def should_keep(tag):
        if tag in un_tags or tag_group_dict.get(tag, "") in explicit_group:
            return False
        if keep_tags == "body" and clothing_pattern.search(tag):
            return False
        if background_pattern.search(tag):
            return False
        return True

    people_tags: list[str] = []
    other_tags: list[str] = []
    for raw_tag in (input_prompt.split(",") if input_prompt else []):
        tag = raw_tag.strip().replace("_", " ")
        if tag in PEOPLE_TAGS:
            people_tags.append(tag)
        elif should_keep(tag):
            other_tags.append(tag)

    return ", ".join(people_tags + other_tags)
|
311 |
+
|
312 |
+
|
313 |
+
def sort_taglist(tags: list[str]):
    """Reorder tags into a canonical category order.

    Each tag is stripped and underscores become spaces, then it is put in the
    first matching bucket: people tag, rating tag, grouped tag (per
    ``tag_group_dict``), character, series, other. The result is
    characters + series + people + grouped tags (in ``group_list`` order) +
    others + at most one rating tag. A first rating of ``"explicit"`` is
    emitted as ``"explicit, nsfw"``. Grouped tags whose group is not in
    ``group_list`` are not emitted.
    """
    if not tags:
        return []

    group_list = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']

    known_grouped = set(tag_group_dict.keys())
    known_characters = set(anime_series_dict.keys())
    known_series = set(anime_series_dict.values())
    known_ratings = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())

    character_tags: list[str] = []
    series_tags: list[str] = []
    people_tags: list[str] = []
    other_tags: list[str] = []
    rating_tags: list[str] = []
    tags_by_group = {}

    for raw_tag in tags:
        tag = raw_tag.strip().replace("_", " ")
        if tag in PEOPLE_TAGS:
            people_tags.append(tag)
        elif tag in known_ratings:
            rating_tags.append(tag)
        elif tag in known_grouped:
            tags_by_group.setdefault(tag_group_dict[tag], []).append(tag)
        elif tag in known_characters:
            character_tags.append(tag)
        elif tag in known_series:
            series_tags.append(tag)
        else:
            other_tags.append(tag)

    grouped_in_order: list[str] = []
    for group_name in group_list:
        grouped_in_order.extend(tags_by_group.get(group_name, []))

    # Keep only the first rating tag.
    rating_tags = rating_tags[:1]
    if rating_tags and rating_tags[0] == "explicit":
        rating_tags = ["explicit, nsfw"]

    return character_tags + series_tags + people_tags + grouped_in_order + other_tags + rating_tags
|
355 |
+
|
356 |
+
|
357 |
+
def sort_tags(tags: str):
    """Sort a comma-separated tag string with ``sort_taglist``.

    Tags are stripped and empty entries dropped before sorting; the sorted
    tags are joined with ``", "``. An empty input yields ``""``.
    """
    if not tags:
        return ""
    stripped = (piece.strip() for piece in tags.split(","))
    taglist = [tag for tag in stripped if tag != ""]
    return ", ".join(sort_taglist(taglist))
|
364 |
+
|
365 |
+
|
366 |
+
def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
    """Split label scores into rating, character and general dictionaries.

    Labels are visited from highest to lowest score. ``rating:`` and
    ``character:`` labels go to their own dictionaries with that marker text
    removed; everything else is a general tag. Character and general entries
    scoring below their threshold are dropped; ratings are never filtered.

    Returns:
        ``(rating, character, general)``, each ordered by descending score.
    """
    ranked = sorted(results.items(), key=lambda pair: pair[1], reverse=True)

    rating = {}
    all_characters = {}
    all_general = {}
    for label, score in ranked:
        if label.startswith("rating:"):
            rating[label.replace("rating:", "")] = score
        elif label.startswith("character:"):
            all_characters[label.replace("character:", "")] = score
        else:
            all_general[label] = score

    character = {
        name: score for name, score in all_characters.items()
        if score >= character_threshold
    }
    general = {
        name: score for name, score in all_general.items()
        if score >= general_threshold
    }
    return rating, character, general
|
389 |
+
|
390 |
+
|
391 |
+
def gen_prompt(rating: list[str], character: list[str], general: list[str]):
    """Join the general tags into a prompt, people tags first.

    Only ``general`` contributes to the output. ``rating`` and ``character``
    are accepted for interface compatibility but are not used. The former
    ``RATING_MAP[rating[0]]`` lookup was removed because its result was never
    used and it could raise on an empty or unknown rating.

    Returns:
        The tags joined with ``", "``; relative order within the people and
        non-people groups is preserved.
    """
    people_tags = [tag for tag in general if tag in PEOPLE_TAGS]
    other_tags = [tag for tag in general if tag not in PEOPLE_TAGS]
    return ", ".join(people_tags + other_tags)
|
405 |
+
|
406 |
+
|
407 |
+
@spaces.GPU()
def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
    """Run the WD tagger on an image and build the UI outputs.

    Args:
        image: the image to tag.
        general_threshold: minimum score for a general tag to be kept.
        character_threshold: minimum score for a character tag to be kept.

    Returns:
        A 4-tuple: the first series derived from the detected characters (or
        ``""``), the detected characters joined with ``", "``, the general
        tag prompt, and a ``gr.update(interactive=True)``.
    """
    inputs = wd_processor.preprocess(image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = wd_model(**inputs.to(wd_model.device, wd_model.dtype))
        # Sigmoid scores for the first (only) image, converted in one go
        # instead of one tensor element at a time.
        probabilities = torch.sigmoid(outputs.logits[0]).float().tolist()

    results = {
        wd_model.config.id2label[i]: score for i, score in enumerate(probabilities)
    }

    # rating, character, general
    rating, character, general = postprocess_results(
        results, general_threshold, character_threshold
    )

    prompt = gen_prompt(
        list(rating.keys()), list(character.keys()), list(general.keys())
    )

    output_series_list = character_list_to_series_list(character.keys())
    output_series_tag = output_series_list[0] if output_series_list else ""

    return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True),
|
436 |
+
|
437 |
+
|
438 |
+
def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3, character_threshold: float = 0.8):
    """Run ``predict_tags`` unless the WD tagger is deselected in ``algo``.

    When ``algo`` is non-empty and does not contain ``"Use WD Tagger"``, the
    image is not analysed and ``input_tags`` is passed through as the prompt.
    An empty ``algo`` still runs the tagger.
    """
    tagger_disabled = bool(algo) and "Use WD Tagger" not in algo
    if tagger_disabled:
        return "", "", input_tags, gr.update(interactive=True),
    return predict_tags(image, general_threshold, character_threshold)
|
442 |
+
|
443 |
+
|
444 |
+
def compose_prompt_to_copy(character: str, series: str, general: str):
    """Concatenate character, series and general tags into one string.

    Each non-empty argument is split on ``,`` and the pieces are re-joined
    with ``,`` in the order character, series, general. Whitespace around
    the pieces is left exactly as given.
    """
    tags = []
    for field in (character, series, general):
        if field:
            tags.extend(field.split(","))
    return ",".join(tags)
|
tags.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
utils.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
|
3 |
+
|
4 |
+
|
5 |
+
V2_ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = [
|
6 |
+
"ultra_wide",
|
7 |
+
"wide",
|
8 |
+
"square",
|
9 |
+
"tall",
|
10 |
+
"ultra_tall",
|
11 |
+
]
|
12 |
+
V2_RATING_OPTIONS: list[RatingTag] = [
|
13 |
+
"sfw",
|
14 |
+
"general",
|
15 |
+
"sensitive",
|
16 |
+
"nsfw",
|
17 |
+
"questionable",
|
18 |
+
"explicit",
|
19 |
+
]
|
20 |
+
V2_LENGTH_OPTIONS: list[LengthTag] = [
|
21 |
+
"very_short",
|
22 |
+
"short",
|
23 |
+
"medium",
|
24 |
+
"long",
|
25 |
+
"very_long",
|
26 |
+
]
|
27 |
+
V2_IDENTITY_OPTIONS: list[IdentityTag] = [
|
28 |
+
"none",
|
29 |
+
"lax",
|
30 |
+
"strict",
|
31 |
+
]
|
32 |
+
|
33 |
+
|
34 |
+
# ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
|
35 |
+
def gradio_copy_text(_text: None):
    """Show a "Copied!" notification; the argument is ignored."""
    message = "Copied!"
    gr.Info(message)
|
37 |
+
|
38 |
+
|
39 |
+
COPY_ACTION_JS = """\
|
40 |
+
(inputs, _outputs) => {
|
41 |
+
// inputs is the string value of the input_text
|
42 |
+
if (inputs.trim() !== "") {
|
43 |
+
navigator.clipboard.writeText(inputs);
|
44 |
+
}
|
45 |
+
}"""
|
v2.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from typing import Callable
|
5 |
+
|
6 |
+
from dartrs.v2 import (
|
7 |
+
V2Model,
|
8 |
+
MixtralModel,
|
9 |
+
MistralModel,
|
10 |
+
compose_prompt,
|
11 |
+
LengthTag,
|
12 |
+
AspectRatioTag,
|
13 |
+
RatingTag,
|
14 |
+
IdentityTag,
|
15 |
+
)
|
16 |
+
from dartrs.dartrs import DartTokenizer
|
17 |
+
from dartrs.utils import get_generation_config
|
18 |
+
|
19 |
+
|
20 |
+
import gradio as gr
|
21 |
+
from gradio.components import Component
|
22 |
+
|
23 |
+
try:
    import spaces
except ImportError:

    class spaces:
        """Stand-in used when the ``spaces`` package cannot be imported."""

        @staticmethod
        def GPU(*args, **kwargs):
            """No-op replacement for the ``spaces.GPU`` decorator.

            Works both as ``@spaces.GPU`` (bare) and as ``@spaces.GPU(...)``
            (called); in either form the decorated function is returned
            unchanged.
            """
            # Bare use: the function itself is the only argument.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]
            return lambda func: func
|
30 |
+
|
31 |
+
|
32 |
+
from output import UpsamplingOutput
|
33 |
+
|
34 |
+
|
35 |
+
HF_TOKEN = os.getenv("HF_TOKEN", None)
|
36 |
+
|
37 |
+
V2_ALL_MODELS = {
|
38 |
+
"dart-v2-moe-sft": {
|
39 |
+
"repo": "p1atdev/dart-v2-moe-sft",
|
40 |
+
"type": "sft",
|
41 |
+
"class": MixtralModel,
|
42 |
+
},
|
43 |
+
"dart-v2-sft": {
|
44 |
+
"repo": "p1atdev/dart-v2-sft",
|
45 |
+
"type": "sft",
|
46 |
+
"class": MistralModel,
|
47 |
+
},
|
48 |
+
}
|
49 |
+
|
50 |
+
|
51 |
+
def prepare_models(model_config: dict):
    """Load the tokenizer and model described by one ``V2_ALL_MODELS`` entry.

    Both are loaded from ``model_config["repo"]`` with ``HF_TOKEN`` as the
    auth token; the model class comes from ``model_config["class"]``.

    Returns:
        A dict with the keys ``"tokenizer"`` and ``"model"``.
    """
    repo_id = model_config["repo"]
    # Dict values are evaluated in order: tokenizer first, then model.
    return {
        "tokenizer": DartTokenizer.from_pretrained(repo_id, auth_token=HF_TOKEN),
        "model": model_config["class"].from_pretrained(repo_id, auth_token=HF_TOKEN),
    }
|
60 |
+
|
61 |
+
|
62 |
+
def normalize_tags(tokenizer: DartTokenizer, tags: str):
    """Tokenize ``tags`` and re-join the tokens, dropping unknown tokens."""
    known_tokens = (
        token for token in tokenizer.tokenize(tags) if token != "<|unk|>"
    )
    return ", ".join(known_tokens)
|
65 |
+
|
66 |
+
|
67 |
+
@torch.no_grad()
def generate_tags(
    model: V2Model,
    tokenizer: DartTokenizer,
    prompt: str,
    ban_token_ids: list[int],
):
    """Generate upsampled tags for ``prompt`` with fixed sampling settings.

    Sampling uses temperature 1, top-p 0.9, top-k 100 and at most 256 new
    tokens; ``ban_token_ids`` is forwarded to the generation config.

    Returns:
        Whatever ``model.generate`` returns for that configuration.
    """
    generation_config = get_generation_config(
        prompt,
        tokenizer=tokenizer,
        temperature=1,
        top_p=0.9,
        top_k=100,
        max_new_tokens=256,
        ban_token_ids=ban_token_ids,
    )
    return model.generate(generation_config)
|
87 |
+
|
88 |
+
|
89 |
+
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
|
90 |
+
return (
|
91 |
+
[f"1{noun}"]
|
92 |
+
+ [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
|
93 |
+
+ [f"{maximum+1}+{noun}s"]
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
PEOPLE_TAGS = (
|
98 |
+
_people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
|
99 |
+
)
|
100 |
+
|
101 |
+
|
102 |
+
def gen_prompt_text(output: UpsamplingOutput):
    """Assemble the final prompt string from an upsampling result.

    Order: people tags from the general tags (e.g. ``1girl``), character
    tags, copyright tags, the remaining general tags, the upsampled tags,
    and the rating tag. Each part is stripped, empty parts are dropped, and
    the rest is joined with ``", "``.
    """
    general = [piece.strip() for piece in output.general_tags.split(",")]
    people_tags = [tag for tag in general if tag in PEOPLE_TAGS]
    other_general_tags = [tag for tag in general if tag not in PEOPLE_TAGS]

    ordered_parts = [
        *people_tags,
        output.character_tags,
        output.copyright_tags,
        *other_general_tags,
        output.upsampled_tags,
        output.rating_tag,
    ]
    return ", ".join(
        part.strip() for part in ordered_parts if part.strip() != ""
    )
|
128 |
+
|
129 |
+
|
130 |
+
def elapsed_time_format(elapsed_time: float) -> str:
    """Format a duration in seconds as ``Elapsed: <x.xx> seconds``."""
    message = f"Elapsed: {elapsed_time:.2f} seconds"
    return message
|
132 |
+
|
133 |
+
|
134 |
+
def parse_upsampling_output(
    upsampler: Callable[..., UpsamplingOutput],
):
    """Wrap an upsampler so that it returns values for the Gradio outputs.

    Args:
        upsampler: callable producing an ``UpsamplingOutput``.

    Returns:
        A function that forwards its positional arguments to ``upsampler``
        and returns four values: the prompt text, the formatted elapsed
        time, and two ``gr.update(interactive=True)`` updates.
    """

    # The annotation lists all four returned values (it previously declared
    # only three).
    def _parse_upsampling_output(*args) -> tuple[str, str, dict, dict]:
        output = upsampler(*args)

        return (
            gen_prompt_text(output),
            elapsed_time_format(output.elapsed_time),
            gr.update(interactive=True),
            gr.update(interactive=True),
        )

    return _parse_upsampling_output
|
148 |
+
|
149 |
+
|
150 |
+
class V2UI:
    """Holds the currently loaded Dart v2 model and runs tag generation."""

    # Name of the loaded model; None until the first generation request.
    model_name: str | None = None
    model: V2Model
    tokenizer: DartTokenizer

    input_components: list[Component] = []
    generate_btn: gr.Button

    def on_generate(
        self,
        model_name: str,
        copyright_tags: str,
        character_tags: str,
        general_tags: str,
        rating_tag: RatingTag,
        aspect_ratio_tag: AspectRatioTag,
        length_tag: LengthTag,
        identity_tag: IdentityTag,
        ban_tags: str,
        *args,
    ) -> UpsamplingOutput:
        """Generate upsampled tags and return them with the inputs and timing.

        The model and tokenizer are (re)loaded when no model is loaded yet or
        when ``model_name`` differs from the loaded one. ``ban_tags`` is
        stripped and encoded into banned token ids. Only the generation call
        itself is timed. Extra positional arguments are ignored.
        """
        needs_reload = self.model_name is None or self.model_name != model_name
        if needs_reload:
            loaded = prepare_models(V2_ALL_MODELS[model_name])
            self.model = loaded["model"]
            self.tokenizer = loaded["tokenizer"]
            # Recorded last, so a failed load does not mark the model as loaded.
            self.model_name = model_name

        # The input tags are used as given; they are not passed through
        # normalize_tags.
        banned_ids = self.tokenizer.encode(ban_tags.strip())

        composed = compose_prompt(
            prompt=general_tags,
            copyright=copyright_tags,
            character=character_tags,
            rating=rating_tag,
            aspect_ratio=aspect_ratio_tag,
            length=length_tag,
            identity=identity_tag,
        )

        started_at = time.time()
        generated = generate_tags(
            self.model,
            self.tokenizer,
            composed,
            banned_ids,
        )
        duration = time.time() - started_at

        return UpsamplingOutput(
            upsampled_tags=generated,
            copyright_tags=copyright_tags,
            character_tags=character_tags,
            general_tags=general_tags,
            rating_tag=rating_tag,
            aspect_ratio_tag=aspect_ratio_tag,
            length_tag=length_tag,
            identity_tag=identity_tag,
            elapsed_time=duration,
        )
|
214 |
+
|