John6666 committed on
Commit
b3dd22d
·
verified ·
0 Parent(s):

Super-squash branch 'main' using huggingface_hub

Browse files
Files changed (16) hide show
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +140 -0
  4. character_series_dict.csv +0 -0
  5. danbooru_e621.csv +0 -0
  6. myt2tmod.py +116 -0
  7. originalt2t.py +127 -0
  8. output.py +16 -0
  9. requirements.txt +11 -0
  10. t2t.py +41 -0
  11. t2tmod.py +117 -0
  12. tag_group.csv +0 -0
  13. tagger.py +450 -0
  14. tags.txt +0 -0
  15. utils.py +45 -0
  16. v2.py +214 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Natural Text to Danbooru Tags with Transformer V2
3
+ emoji: 👀📦
4
+ colorFrom: red
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 4.36.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: openrail
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from v2 import (
4
+ V2UI,
5
+ parse_upsampling_output,
6
+ V2_ALL_MODELS,
7
+ )
8
+ from utils import (
9
+ gradio_copy_text,
10
+ COPY_ACTION_JS,
11
+ V2_ASPECT_RATIO_OPTIONS,
12
+ V2_RATING_OPTIONS,
13
+ V2_LENGTH_OPTIONS,
14
+ V2_IDENTITY_OPTIONS
15
+ )
16
+ from tagger import (
17
+ predict_tags,
18
+ convert_danbooru_to_e621_prompt,
19
+ remove_specific_prompt,
20
+ insert_recom_prompt,
21
+ compose_prompt_to_copy,
22
+ translate_prompt,
23
+ )
24
+ from t2t import predict_text_to_tags
25
+
26
+
27
def description_ui():
    """Render the Markdown header shown at the top of the demo page."""
    header_markdown = """
    ## Natural Text to Danbooru Tags with Danbooru Tags Transformer V2
    Natural text => Prompt => Upsampled longer prompt
    - Mod of [ooferdoodles/text2tags-demo](https://huggingface.co/spaces/ooferdoodles/text2tags-demo) and p1atdev's [Danbooru Tags Transformer V2 Demo](https://huggingface.co/spaces/p1atdev/danbooru-tags-transformer-v2)
    - It's buggy but seems to work for now.
    """
    gr.Markdown(header_markdown)
36
+
37
+
38
def main():
    """Build the Gradio UI, wire up the event handlers and launch the app.

    The page has three stages: natural text -> tags (``predict_text_to_tags``),
    tags -> upsampled tags (``v2.on_generate``), and post-processing of the
    upsampled tags into a danbooru-style and an e621-style prompt.

    NOTE(review): the nesting of the ``with`` blocks was reconstructed from a
    copy of this file that had lost its indentation; confirm the layout
    against the running Space.
    """
    v2 = V2UI()

    with gr.Blocks() as ui:
        description_ui()

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Group():
                    t2t_text = gr.TextArea(label="Natural text", lines=6, placeholder="Minato Aqua from hololive with pink and blue twintails in a blue maid outfit ...", value="", show_copy_button=True)
                    with gr.Accordion(label="Advanced options", open=False):
                        translate_t2t_text_button = gr.Button(value="Translate text to English", size="sm", variant="secondary")
                        # Slider ranges are limited to values the text-generation
                        # backend accepts: at least one new token, top_p within
                        # [0, 1] and a strictly positive repetition penalty.
                        t2t_max_tokens = gr.Slider(16, 256, step=16, value=128, label='max_tokens')
                        t2t_temperature = gr.Slider(0.001, 2, step=0.1, value=0.7, label='temperature')
                        t2t_top_k = gr.Slider(0, 100, step=5, value=20, label='top_k')
                        t2t_top_p = gr.Slider(0, 1, step=0.05, value=0.95, label='top_p')
                        t2t_repeat_penalty = gr.Slider(0.1, 5, step=0.1, value=1.1, label='repeat_penalty')
                    t2t_examples = gr.Examples(
                        [
                            ["Minato Aqua from hololive with pink and blue twintails in a blue maid outfit"],
                        ],
                        t2t_text,
                        cache_examples=False,
                    )
                    generate_from_text_btn = gr.Button(value="GENERATE TAGS FROM TEXT", size="lg", variant="primary")

                with gr.Group():
                    input_character = gr.Textbox(label="Character tags", placeholder="hatsune miku", visible=False)
                    input_copyright = gr.Textbox(label="Copyright tags", placeholder="vocaloid", visible=False)
                    input_general = gr.TextArea(label="General tags", lines=6, placeholder="1girl, ...", value="", show_copy_button=True)
                    input_tags_to_copy = gr.Textbox(value="", visible=False)
                    copy_input_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False, visible=False)
                    tag_type = gr.Radio(label="Output tag conversion", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="e621", visible=False)
                    input_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="explicit")
                    with gr.Accordion(label="Advanced options", open=False):
                        input_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square")
                        input_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="very_long")
                        input_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
                        input_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costume, ...", value="censored")
                        model_name = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
                    # Hidden helper fields: they only carry constant arguments
                    # into insert_recom_prompt in the event chain below.
                    dummy_np = gr.Textbox(label="Negative prompt", value="", visible=False)
                    recom_animagine = gr.Textbox(label="Animagine recommended prompt", value="Animagine", visible=False)
                    recom_pony = gr.Textbox(label="Pony recommended prompt", value="Pony", visible=False)

                    generate_btn = gr.Button(value="GENERATE TAGS", size="lg", variant="primary")

                with gr.Group():
                    output_text = gr.TextArea(label="Output tags", interactive=False, show_copy_button=True)
                    copy_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
                    elapsed_time_md = gr.Markdown(label="Elapsed time", value="", visible=False)

                with gr.Group():
                    output_text_pony = gr.TextArea(label="Output tags (Pony e621 style)", interactive=False, show_copy_button=True)
                    copy_btn_pony = gr.Button(value="Copy to clipboard", size="sm", interactive=False)

        # Order matters: these components are passed positionally to
        # v2.on_generate via the generate_btn handler below.
        v2.input_components = [
            model_name,
            input_copyright,
            input_character,
            input_general,
            input_rating,
            input_aspect_ratio,
            input_length,
            input_identity,
            input_ban_tags,
        ]

        translate_t2t_text_button.click(translate_prompt, inputs=[t2t_text], outputs=[t2t_text])

        generate_from_text_btn.click(
            predict_text_to_tags,
            inputs=[t2t_text, t2t_max_tokens, t2t_temperature, t2t_top_k, t2t_top_p, t2t_repeat_penalty],
            outputs=[
                input_general,
            ],
        )

        copy_input_btn.click(compose_prompt_to_copy, inputs=[input_character, input_copyright, input_general], outputs=[input_tags_to_copy]).then(
            gradio_copy_text, inputs=[input_tags_to_copy], js=COPY_ACTION_JS,
        )

        # Upsample first, derive the e621-style prompt from the raw result,
        # then append the recommended quality tags to each prompt.
        generate_btn.click(
            parse_upsampling_output(v2.on_generate),
            inputs=[
                *v2.input_components,
            ],
            outputs=[output_text, elapsed_time_md, copy_btn, copy_btn_pony],
        ).then(
            convert_danbooru_to_e621_prompt, inputs=[output_text, tag_type], outputs=[output_text_pony],
        ).then(
            insert_recom_prompt, inputs=[output_text, dummy_np, recom_animagine], outputs=[output_text, dummy_np],
        ).then(
            insert_recom_prompt, inputs=[output_text_pony, dummy_np, recom_pony], outputs=[output_text_pony, dummy_np],
        )
        copy_btn.click(gradio_copy_text, inputs=[output_text], js=COPY_ACTION_JS)
        copy_btn_pony.click(gradio_copy_text, inputs=[output_text_pony], js=COPY_ACTION_JS)

    ui.launch()
136
+
137
+
138
+ if __name__ == "__main__":
139
+ main()
140
+
character_series_dict.csv ADDED
The diff for this file is too large to render. See raw diff
 
danbooru_e621.csv ADDED
The diff for this file is too large to render. See raw diff
 
myt2tmod.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import os
3
+ import re
4
+
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ import editdistance
7
+
8
+
9
class TaggerLlama:
    """Turn a natural-language caption into a list of known tags.

    A causal language model (loaded through ``transformers``) is prompted with
    ``### Caption:...`` / ``### Tags:`` and its comma-separated completion is
    snapped to the closest entries of the tag list by edit distance.
    """

    MODEL_URL = "John6666/llama-tagger-HF-GPTQ-4bits"
    SAVE_NAME = "llama-tagger-HF-GPTQ-4bits"
    TAGS_FILE_NAME = "tags.txt"
    model = None
    tokenizer = None
    # Memo of raw tag -> corrected tag, shared by all instances. It is keyed
    # by the tag only, so it assumes one tag list per process.
    _closest_tag_cache = {}

    def __init__(
        self,
    ):
        self.download_model()
        self.tag_list = self.load_tags()

    def download_model(
        self,
        model_url=None,
        save_name=None,
    ):
        """Load the tokenizer and model, keeping a copy in the temp directory.

        Args:
            model_url: Model id or path to load; defaults to ``MODEL_URL``.
            save_name: Directory name (inside the system temp directory) of
                the local copy; defaults to ``SAVE_NAME``.
        """
        model_url = model_url or self.MODEL_URL
        save_name = save_name or self.SAVE_NAME

        save_path = os.path.join(tempfile.gettempdir(), save_name)
        if os.path.isdir(save_path):
            # The model must still be loaded here; returning early would
            # leave self.model / self.tokenizer as None.
            print("Model already exists. Skipping download.")
            source = save_path
        else:
            print("Downloading Model")
            source = model_url

        self.tokenizer = AutoTokenizer.from_pretrained(source)
        self.model = AutoModelForCausalLM.from_pretrained(source, device_map="cuda:0")

        if source != save_path:
            # Save to the same path that the existence check above looks at.
            self.tokenizer.save_pretrained(save_path)
            self.model.save_pretrained(save_path)
            print("Model Downloaded")

    def load_tags(self):
        """Return the tags listed (one per line) in the file next to this module.

        Returns an empty list, after printing the error, when the file cannot
        be read.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        tags_file = os.path.join(module_dir, self.TAGS_FILE_NAME)
        try:
            with open(tags_file, "r", encoding="utf-8") as f:
                return [line.strip() for line in f]
        except IOError as e:
            print(f"Error loading tag dictionary: {e}")
            return []

    def preprocess_tag(self, tag):
        """Lower-case ``tag`` and drop anything after its first ``(...)`` group."""
        tag = tag.lower()
        match = re.match(r"^([^()]*\([^()]*\))\s*.*$", tag)
        return match.group(1) if match else tag

    def find_closest_tag(self, tag, threshold, tag_list, cache=None):
        """Return the entry of ``tag_list`` nearest to ``tag``, or ``None``.

        Args:
            tag: Candidate tag text.
            threshold: Largest edit distance still accepted as a match.
            tag_list: Known tags to match against.
            cache: Optional memo dict; defaults to a class-level cache.

        Returns:
            The closest known tag, or ``None`` when ``tag_list`` is empty or
            the best match is further away than ``threshold``.
        """
        if not tag_list:
            # min() raises on an empty sequence; load_tags() returns [] when
            # the tag file is missing.
            return None
        if cache is None:
            cache = self._closest_tag_cache
        if tag in cache:
            return cache[tag]

        closest_tag = min(tag_list, key=lambda x: editdistance.eval(tag, x))
        if editdistance.eval(tag, closest_tag) <= threshold:
            cache[tag] = closest_tag
            return closest_tag
        return None

    def correct_tags(self, tags, tag_list, preprocess=True):
        """Map raw tags onto known tags; return them sorted and de-duplicated."""
        if preprocess:
            tags = (self.preprocess_tag(x) for x in tags)
        corrected_tags = set()
        for tag in tags:
            # Longer tags may differ from the known tag by more edits.
            threshold = max(1, len(tag) - 10)
            closest_tag = self.find_closest_tag(tag, threshold, tag_list)
            if closest_tag:
                corrected_tags.add(closest_tag)
        return sorted(corrected_tags)

    def predict_tags(
        self,
        prompt: str,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        repeat_penalty: float = 1.1,
        top_k: int = 40,
    ):
        """Generate tags for the caption ``prompt``.

        Args:
            prompt: Natural-language caption.
            max_tokens: Maximum number of newly generated tokens.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            repeat_penalty: Repetition penalty passed to ``generate``.
            top_k: Top-k sampling cut-off.

        Returns:
            Sorted list of corrected tags.
        """
        prompt = f"### Caption:{prompt}\n### Tags:"

        input_ids = self.tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
        # Same stop tokens as before, without duplicates or unresolved ids.
        terminators = []
        for token_id in (
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("\n"),
            self.tokenizer.convert_tokens_to_ids("### Tags:"),
        ):
            if token_id is not None and token_id not in terminators:
                terminators.append(token_id)

        # UI sliders may deliver ints where the sampler expects floats (and
        # vice versa), so coerce every numeric setting explicitly.
        raw_output = self.model.generate(
            input_ids.to(self.model.device),
            tokenizer=self.tokenizer,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            repetition_penalty=float(repeat_penalty),
            top_k=int(top_k),
            do_sample=True,
            stop_strings=["\n", "### Tags:"],
            eos_token_id=terminators,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        output = self.tokenizer.batch_decode(raw_output, skip_special_tokens=True)

        # DOTALL lets the leading ".+" span captions that contain newlines;
        # without it such captions never match and leak into the tag list.
        raw_preds = re.sub('^.+\n### Tags:(.+?$)', '\\1', output[0], flags=re.DOTALL)
        pred_tags = [x.strip() for x in raw_preds.split(",")]
        corrected_tags = self.correct_tags(pred_tags, self.tag_list)
        return corrected_tags
114
+
115
+ # https://github.com/ooferdoodles1337/text2tags-lib
116
+ # https://huggingface.co/docs/transformers/main_classes/text_generation
originalt2t.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List
2
+ import tempfile
3
+ import os
4
+ import re
5
+
6
+ import wget
7
+ import editdistance
8
+ from llama_cpp import Llama
9
+
10
+
11
class TaggerLlama(Llama):
    """``llama_cpp`` model that turns a caption into a list of known tags.

    The GGUF model is downloaded to the system temp directory on first use.
    The comma-separated completion is snapped to the closest entries of the
    tag list by edit distance.
    """

    MODEL_URL = "https://huggingface.co/ooferdoodles/llama-tagger-7b/resolve/main/llama-tagger.gguf?download=true"
    SAVE_NAME = "llama-tagger.gguf"
    TAGS_FILE_NAME = "tags.txt"
    # Memo of raw tag -> corrected tag, shared by all instances. It is keyed
    # by the tag only, so it assumes one tag list per process.
    _closest_tag_cache = {}

    def __init__(
        self,
        model_path: str = None,
        **kwargs,
    ):
        """Load the model from ``model_path``, downloading it when omitted."""
        if model_path is None:
            model_path = os.path.join(tempfile.gettempdir(), self.SAVE_NAME)
            self.download_model()
        super().__init__(model_path, **kwargs)
        self.tag_list = self.load_tags()

    def download_model(
        self,
        model_url=None,
        save_name=None,
    ):
        """Download the model file into the temp directory unless it exists.

        Args:
            model_url: URL to fetch; defaults to ``MODEL_URL``.
            save_name: File name inside the temp directory; defaults to
                ``SAVE_NAME``.
        """
        model_url = model_url or self.MODEL_URL
        save_name = save_name or self.SAVE_NAME

        save_path = os.path.join(tempfile.gettempdir(), save_name)
        if os.path.exists(save_path):
            print("Model already exists. Skipping download.")
            return
        print("Downloading Model")
        wget.download(model_url, out=save_path)
        print("Model Downloaded")

    def load_tags(self):
        """Return the tags listed (one per line) in the file next to this module.

        Returns an empty list, after printing the error, when the file cannot
        be read.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        tags_file = os.path.join(module_dir, self.TAGS_FILE_NAME)
        try:
            with open(tags_file, "r", encoding="utf-8") as f:
                return [line.strip() for line in f]
        except IOError as e:
            print(f"Error loading tag dictionary: {e}")
            return []

    def preprocess_tag(self, tag):
        """Lower-case ``tag`` and drop anything after its first ``(...)`` group."""
        tag = tag.lower()
        match = re.match(r"^([^()]*\([^()]*\))\s*.*$", tag)
        return match.group(1) if match else tag

    def find_closest_tag(self, tag, threshold, tag_list, cache=None):
        """Return the entry of ``tag_list`` nearest to ``tag``, or ``None``.

        ``None`` is returned when ``tag_list`` is empty or when the best match
        is further away than ``threshold`` edits.
        """
        if not tag_list:
            # min() raises on an empty sequence; load_tags() returns [] when
            # the tag file is missing.
            return None
        if cache is None:
            cache = self._closest_tag_cache
        if tag in cache:
            return cache[tag]

        closest_tag = min(tag_list, key=lambda x: editdistance.eval(tag, x))
        if editdistance.eval(tag, closest_tag) <= threshold:
            cache[tag] = closest_tag
            return closest_tag
        return None

    def correct_tags(self, tags, tag_list, preprocess=True):
        """Map raw tags onto known tags; return them sorted and de-duplicated."""
        if preprocess:
            tags = (self.preprocess_tag(x) for x in tags)
        corrected_tags = set()
        for tag in tags:
            # Longer tags may differ from the known tag by more edits.
            threshold = max(1, len(tag) - 10)
            closest_tag = self.find_closest_tag(tag, threshold, tag_list)
            if closest_tag:
                corrected_tags.add(closest_tag)
        return sorted(corrected_tags)

    def predict_tags(
        self,
        prompt: str,
        suffix: Optional[str] = None,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        logprobs: Optional[int] = None,
        echo: bool = False,
        # "\n" (not "/n"): generation should end at the first line break.
        stop: Optional[List[str]] = ["\n", "### Tags:"],
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
    ):
        """Generate tags for the caption ``prompt``.

        All keyword arguments besides ``prompt`` are forwarded unchanged to
        ``create_completion``.

        Returns:
            Sorted list of corrected tags.
        """
        prompt = f"### Caption: {prompt}\n### Tags: "

        output = self.create_completion(
            prompt=prompt,
            suffix=suffix,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
            stop=stop,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            repeat_penalty=repeat_penalty,
            top_k=top_k,
            stream=stream,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
        )
        raw_preds = output["choices"][0]["text"]
        pred_tags = [x.strip() for x in raw_preds.split(",")]
        corrected_tags = self.correct_tags(pred_tags, self.tag_list)
        return corrected_tags
output.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+
4
@dataclass
class UpsamplingOutput:
    """Result record of one tag-upsampling run."""

    # Tags produced by the upsampling step.
    upsampled_tags: str

    # Values that accompanied the run, grouped by kind.
    copyright_tags: str
    character_tags: str
    general_tags: str
    rating_tag: str
    aspect_ratio_tag: str
    length_tag: str
    identity_tag: str

    # Presumably seconds spent generating - TODO confirm the unit with the producer.
    elapsed_time: float = 0.0
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ editdistance
2
+ transformers
3
+ accelerate
4
+ sentencepiece
5
+ auto-gptq
6
+ optimum
7
+ httpx==0.13.3
8
+ httpcore
9
+ googletrans==4.0.0rc1
10
+ optimum[onnxruntime]
11
+ dartrs
t2t.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from t2tmod import TaggerLlama
2
+ import spaces
3
+
4
+
5
def translate_text(text = ""):
    """Translate ``text`` to English when it contains Japanese/CJK characters.

    Text without such characters is returned unchanged. Translation is best
    effort: if the translator call fails, the error is printed and the
    original text is returned.
    """
    def translate_to_english(prompt):
        import httpcore
        # Compatibility shim: googletrans looks up this attribute on httpcore.
        setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
        from googletrans import Translator
        translator = Translator()
        try:
            # Use the helper's own argument rather than the enclosing variable.
            return translator.translate(prompt, src='auto', dest='en').text
        except Exception as e:
            print(f"Translation failed, using the original text: {e}")
            return prompt

    def is_japanese(s):
        import unicodedata
        for ch in s:
            name = unicodedata.name(ch, "")
            if "CJK UNIFIED" in name or "HIRAGANA" in name or "KATAKANA" in name:
                return True
        return False

    return translate_to_english(text) if is_japanese(text) else text
26
+
27
+
28
+ t2t_model = TaggerLlama()
29
+
30
@spaces.GPU()
def predict_text_to_tags(input_text: str="", max_tokens: int=128, temperature: float=0.8, top_k: int=40, top_p: float=0.95, repeat_penalty: float=1.1):
    """Convert natural text into a comma-separated tag string.

    The text is first passed through ``translate_text``. When that changed
    the text, the translated text is prepended to the generated tags.
    Underscores in tags are shown as spaces.
    """
    english_text = translate_text(input_text)
    predicted = t2t_model.predict_tags(
        english_text,
        max_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repeat_penalty=repeat_penalty,
    )
    tag_text = ', '.join(predicted).replace("_", " ")

    if english_text == input_text:
        return tag_text
    return english_text + ', ' + tag_text
t2tmod.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import os
3
+ import re
4
+
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ import editdistance
7
+
8
+
9
class TaggerLlama:
    """Turn a natural-language caption into a list of known tags.

    A causal language model (loaded through ``transformers``) is prompted with
    ``### Caption:...`` / ``### Tags:`` and its comma-separated completion is
    snapped to the closest entries of the tag list by edit distance.
    """

    MODEL_URL = "John6666/llama-tagger-HF-GPTQ-4bits"
    SAVE_NAME = "llama-tagger-HF-GPTQ-4bits"
    TAGS_FILE_NAME = "tags.txt"
    model = None
    tokenizer = None
    # Memo of raw tag -> corrected tag, shared by all instances. It is keyed
    # by the tag only, so it assumes one tag list per process.
    _closest_tag_cache = {}

    def __init__(
        self,
    ):
        self.download_model()
        self.tag_list = self.load_tags()

    def download_model(
        self,
        model_url=None,
        save_name=None,
    ):
        """Load the tokenizer and model, keeping a copy in the temp directory.

        Args:
            model_url: Model id or path to load; defaults to ``MODEL_URL``.
            save_name: Directory name (inside the system temp directory) of
                the local copy; defaults to ``SAVE_NAME``.
        """
        model_url = model_url or self.MODEL_URL
        save_name = save_name or self.SAVE_NAME

        save_path = os.path.join(tempfile.gettempdir(), save_name)
        if os.path.isdir(save_path):
            # The model must still be loaded here; returning early would
            # leave self.model / self.tokenizer as None.
            print("Model already exists. Skipping download.")
            source = save_path
        else:
            print("Downloading Model")
            source = model_url

        self.tokenizer = AutoTokenizer.from_pretrained(source)
        self.model = AutoModelForCausalLM.from_pretrained(source, device_map="cuda:0")

        if source != save_path:
            # Save to the same path that the existence check above looks at.
            self.tokenizer.save_pretrained(save_path)
            self.model.save_pretrained(save_path)
            print("Model Downloaded")

    def load_tags(self):
        """Return the tags listed (one per line) in the file next to this module.

        Returns an empty list, after printing the error, when the file cannot
        be read.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        tags_file = os.path.join(module_dir, self.TAGS_FILE_NAME)
        try:
            with open(tags_file, "r", encoding="utf-8") as f:
                return [line.strip() for line in f]
        except IOError as e:
            print(f"Error loading tag dictionary: {e}")
            return []

    def preprocess_tag(self, tag):
        """Lower-case ``tag`` and drop anything after its first ``(...)`` group."""
        tag = tag.lower()
        match = re.match(r"^([^()]*\([^()]*\))\s*.*$", tag)
        return match.group(1) if match else tag

    def find_closest_tag(self, tag, threshold, tag_list, cache=None):
        """Return the entry of ``tag_list`` nearest to ``tag``, or ``None``.

        Args:
            tag: Candidate tag text.
            threshold: Largest edit distance still accepted as a match.
            tag_list: Known tags to match against.
            cache: Optional memo dict; defaults to a class-level cache.

        Returns:
            The closest known tag, or ``None`` when ``tag_list`` is empty or
            the best match is further away than ``threshold``.
        """
        if not tag_list:
            # min() raises on an empty sequence; load_tags() returns [] when
            # the tag file is missing.
            return None
        if cache is None:
            cache = self._closest_tag_cache
        if tag in cache:
            return cache[tag]

        closest_tag = min(tag_list, key=lambda x: editdistance.eval(tag, x))
        if editdistance.eval(tag, closest_tag) <= threshold:
            cache[tag] = closest_tag
            return closest_tag
        return None

    def correct_tags(self, tags, tag_list, preprocess=True):
        """Map raw tags onto known tags; return them sorted and de-duplicated."""
        if preprocess:
            tags = (self.preprocess_tag(x) for x in tags)
        corrected_tags = set()
        for tag in tags:
            # Longer tags may differ from the known tag by more edits.
            threshold = max(1, len(tag) - 10)
            closest_tag = self.find_closest_tag(tag, threshold, tag_list)
            if closest_tag:
                corrected_tags.add(closest_tag)
        return sorted(corrected_tags)

    def predict_tags(
        self,
        prompt: str,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        repeat_penalty: float = 1.1,
        top_k: int = 40,
    ):
        """Generate tags for the caption ``prompt``.

        Args:
            prompt: Natural-language caption; surrounding whitespace is removed.
            max_tokens: Maximum number of newly generated tokens.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            repeat_penalty: Repetition penalty passed to ``generate``.
            top_k: Top-k sampling cut-off.

        Returns:
            Sorted list of corrected tags.
        """
        prompt = f"### Caption:{prompt.strip()}\n### Tags:"

        input_ids = self.tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")
        # Same stop tokens as before, without duplicates or unresolved ids.
        terminators = []
        for token_id in (
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("\n"),
            self.tokenizer.convert_tokens_to_ids("### Tags:"),
        ):
            if token_id is not None and token_id not in terminators:
                terminators.append(token_id)

        # UI sliders may deliver ints where the sampler expects floats (and
        # vice versa), so coerce every numeric setting explicitly.
        raw_output = self.model.generate(
            input_ids.to(self.model.device),
            tokenizer=self.tokenizer,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            repetition_penalty=float(repeat_penalty),
            top_k=int(top_k),
            do_sample=True,
            stop_strings=["\n", "### Tags:"],
            eos_token_id=terminators,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        output = self.tokenizer.batch_decode(raw_output, skip_special_tokens=True)

        # DOTALL lets the leading ".+" span captions that contain newlines;
        # without it such captions never match and leak into the tag list.
        raw_preds = re.sub('^.+\n### Tags:(.+?$)', '\\1', output[0], flags=re.DOTALL)
        # Workaround kept from the original ("to avoid a mysterious bug"):
        # when the output starts with "1boy", drop the first two and the last
        # comma-separated items.
        pieces = raw_preds.split(",")
        if pieces[0].strip() == "1boy":
            raw_preds = ",".join(pieces[2:-1])
        pred_tags = [x.strip() for x in raw_preds.split(",")]
        corrected_tags = self.correct_tags(pred_tags, self.tag_list)
        return corrected_tags
115
+
116
+ # https://github.com/ooferdoodles1337/text2tags-lib
117
+ # https://huggingface.co/docs/transformers/main_classes/text_generation
tag_group.csv ADDED
The diff for this file is too large to render. See raw diff
 
tagger.py ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import gradio as gr
4
+ import spaces # ZERO GPU
5
+
6
+ from transformers import (
7
+ AutoImageProcessor,
8
+ AutoModelForImageClassification,
9
+ )
10
+
11
+ WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
12
+ WD_MODEL_NAME = WD_MODEL_NAMES[0]
13
+
14
+ wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
15
+ wd_model.to("cuda" if torch.cuda.is_available() else "cpu")
16
+ wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
17
+
18
+
19
+ def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
20
+ return (
21
+ [f"1{noun}"]
22
+ + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
23
+ + [f"{maximum+1}+{noun}s"]
24
+ )
25
+
26
+
27
+ PEOPLE_TAGS = (
28
+ _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
29
+ )
30
+
31
+
32
+ RATING_MAP = {
33
+ "general": "safe",
34
+ "sensitive": "sensitive",
35
+ "questionable": "nsfw",
36
+ "explicit": "explicit, nsfw",
37
+ }
38
+ DANBOORU_TO_E621_RATING_MAP = {
39
+ "safe": "rating_safe",
40
+ "sensitive": "rating_safe",
41
+ "nsfw": "rating_explicit",
42
+ "explicit, nsfw": "rating_explicit",
43
+ "explicit": "rating_explicit",
44
+ "rating:safe": "rating_safe",
45
+ "rating:general": "rating_safe",
46
+ "rating:sensitive": "rating_safe",
47
+ "rating:questionable, nsfw": "rating_explicit",
48
+ "rating:explicit, nsfw": "rating_explicit",
49
+ }
50
+
51
+
52
def load_dict_from_csv(filename):
    """Load a two-column, comma-separated file into a dictionary.

    The first field of each line is the key and the second is the value; any
    further fields are ignored. Lines with fewer than two fields (blank lines
    included) are skipped. When a key repeats, the last line wins.

    Args:
        filename: Path of the UTF-8 encoded file to read.

    Returns:
        Dict mapping first-column text to second-column text.

    Raises:
        OSError: If the file cannot be opened.
    """
    mapping = {}
    with open(filename, 'r', encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split(',')
            if len(parts) < 2:
                # Blank or single-column line: nothing to map.
                continue
            mapping[parts[0]] = parts[1]
    return mapping
60
+
61
+
62
+ anime_series_dict = load_dict_from_csv('character_series_dict.csv')
63
+
64
+
65
def character_list_to_series_list(character_list):
    """Return the series tag of each character tag that has one.

    A character tag ending in ``)`` yields the text of its last parenthesised
    part. Any other tag is looked up in ``anime_series_dict``. Tags without a
    series are skipped.
    """
    series_lookup = anime_series_dict
    found_series = []
    for name in character_list:
        if name.endswith(")"):
            series = name.split("(")[-1].replace(")", "")
        else:
            series = series_lookup.get(name, "")
        if series:
            found_series.append(series)
    return found_series
82
+
83
+
84
def danbooru_to_e621(dtag, e621_dict):
    """Replace danbooru tag text inside ``dtag`` with its e621 equivalent.

    The first two runs of word characters/spaces in ``dtag`` are each looked
    up in ``e621_dict`` (after trimming and turning ``_`` into spaces). A run
    with a non-empty mapping is replaced by it; anything else is left as is.

    Args:
        dtag: Tag text to convert.
        e621_dict: Mapping from danbooru tag to e621 tag.

    Returns:
        The converted tag text.
    """
    import re

    def d_to_e(match):
        found = match.group(0)
        etag = e621_dict.get(found.strip().replace("_", " "), "")
        return etag if etag else found

    # Pass ``count`` by keyword: passing it positionally is deprecated.
    return re.sub(r'[\w ]+', d_to_e, dtag, count=2)
97
+
98
+
99
+ danbooru_to_e621_dict = load_dict_from_csv('danbooru_e621.csv')
100
+
101
+
102
def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "danbooru"):
    """Convert a comma-separated danbooru prompt to e621 style.

    With ``prompt_type == "danbooru"`` the prompt is returned untouched.
    Otherwise each tag is converted with ``danbooru_to_e621`` and the output
    is ordered as people tags, other tags, then at most one rating tag.
    """
    if prompt_type == "danbooru":
        return input_prompt

    people: list[str] = []
    others: list[str] = []
    ratings: list[str] = []

    lookup = danbooru_to_e621_dict
    raw_tags = input_prompt.split(",") if input_prompt else []
    for raw in raw_tags:
        converted = danbooru_to_e621(raw.strip().replace("_", " "), lookup)
        if converted in PEOPLE_TAGS:
            people.append(converted)
        elif converted in DANBOORU_TO_E621_RATING_MAP:
            ratings.append(DANBOORU_TO_E621_RATING_MAP.get(converted.replace(" ", ""), ""))
        else:
            others.append(converted)

    # Keep only the first rating that was seen.
    ratings = ratings[:1]
    if ratings and ratings[0] == "explicit":
        ratings = ["explicit, nsfw"]

    return ", ".join(people + others + ratings)
127
+
128
+
129
def translate_prompt(prompt: str = ""):
    """Translate the Japanese/CJK items of a comma-separated prompt to English.

    Each item is trimmed. Items containing CJK, hiragana or katakana
    characters go through the translator, and the item is kept as is if
    translation fails. The items are re-joined with ``", "``.
    """
    def contains_japanese(segment):
        import unicodedata
        markers = ("CJK UNIFIED", "HIRAGANA", "KATAKANA")
        return any(
            marker in unicodedata.name(char, "")
            for char in segment
            for marker in markers
        )

    def to_english(segment):
        import httpcore
        setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
        from googletrans import Translator
        translator = Translator()
        try:
            return translator.translate(segment, src='auto', dest='en').text
        except Exception:
            return segment

    translated = []
    for piece in prompt.split(","):
        piece = piece.strip()
        translated.append(to_english(piece) if contains_japanese(piece) else piece)

    return ", ".join(translated)
159
+
160
+
161
def translate_prompt_to_ja(prompt: str = ""):
    """Translate the non-Japanese items of a comma-separated prompt to Japanese.

    Each item is trimmed. Items without CJK, hiragana or katakana characters
    go through the translator, and the item is kept as is if translation
    fails. The items are re-joined with ``", "``.
    """
    def contains_japanese(segment):
        import unicodedata
        markers = ("CJK UNIFIED", "HIRAGANA", "KATAKANA")
        return any(
            marker in unicodedata.name(char, "")
            for char in segment
            for marker in markers
        )

    def to_japanese(segment):
        import httpcore
        setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
        from googletrans import Translator
        translator = Translator()
        try:
            return translator.translate(segment, src='en', dest='ja').text
        except Exception:
            return segment

    translated = []
    for piece in prompt.split(","):
        piece = piece.strip()
        translated.append(piece if contains_japanese(piece) else to_japanese(piece))

    return ", ".join(translated)
191
+
192
+
193
def tags_to_ja(itag, dict):
    """Replace tag text inside ``itag`` with its entry from ``dict``.

    The first two runs of word characters/spaces in ``itag`` are each looked
    up in ``dict`` (after trimming and turning ``_`` into spaces). A run with
    a non-empty mapping is replaced by it; anything else is left as is.

    Args:
        itag: Tag text to convert.
        dict: Mapping from tag to its replacement. The name shadows the
            builtin but is kept for callers that pass it by keyword.

    Returns:
        The converted tag text.
    """
    import re

    def t_to_j(match):
        found = match.group(0)
        ja = dict.get(found.strip().replace("_", " "), "")
        return ja if ja else found

    # Pass ``count`` by keyword: passing it positionally is deprecated.
    return re.sub(r'[\w ]+', t_to_j, itag, count=2)
206
+
207
+
208
def convert_tags_to_ja(input_prompt: str = ""):
    """Convert every tag of a comma-separated prompt with ``tags_to_ja``.

    The mapping is read from ``all_tags_ja_ext.csv``. An empty prompt returns
    ``""`` without touching the file. If the file cannot be read, the error
    is printed and the tags are returned normalized but unconverted.
    """
    tags = input_prompt.split(",") if input_prompt else []
    if not tags:
        return ""

    try:
        tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
    except OSError as e:
        # Best effort: a missing dictionary should not break the caller.
        print(f"Error loading tag translation dictionary: {e}")
        tags_to_ja_dict = {}

    out_tags = [
        tags_to_ja(tag.strip().replace("_", " "), tags_to_ja_dict)
        for tag in tags
    ]
    return ", ".join(out_tags)
220
+
221
+
222
def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
    """Swap the recommended quality tags in a prompt pair for those of ``type``.

    Any Animagine or Pony recommended tags already present are removed from
    both prompts. The set selected by ``type`` (``"Animagine"`` or
    ``"Pony"``; anything else adds nothing) is then appended and duplicates
    are dropped, keeping first occurrences. When a prompt was empty before
    appending and ``type`` is not ``"None"``, an empty item is added at the
    end, leaving a trailing ``", "``.

    Returns:
        Tuple ``(prompt, neg_prompt)`` of comma-joined strings.
    """
    def split_tags(text):
        if text == "":
            return []
        return [item.strip() for item in text.split(",")]

    def unique(items):
        # dict keys keep first-seen order
        return list(dict.fromkeys(items))

    animagine_ps = split_tags("anime artwork, anime style, key visual, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
    animagine_nps = split_tags("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
    pony_ps = split_tags("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
    pony_nps = split_tags("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")

    known_ps = animagine_ps + pony_ps
    known_nps = animagine_nps + pony_nps
    positives = [tag for tag in split_tags(prompt) if tag not in known_ps]
    negatives = [tag for tag in split_tags(neg_prompt) if tag not in known_nps]

    adding = type != "None"
    positive_tail = [""] if adding and not positives else []
    negative_tail = [""] if adding and not negatives else []

    extra_ps, extra_nps = {
        "Animagine": (animagine_ps, animagine_nps),
        "Pony": (pony_ps, pony_nps),
    }.get(type, ([], []))
    positives = positives + extra_ps
    negatives = negatives + extra_nps

    prompt = ", ".join(unique(positives) + positive_tail)
    neg_prompt = ", ".join(unique(negatives) + negative_tail)

    return prompt, neg_prompt
256
+
257
+
258
+ tag_group_dict = load_dict_from_csv('tag_group.csv')  # tag -> group-name lookup, loaded once at import; read by remove_specific_prompt and sort_taglist
259
+
260
+
261
def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
    """Filter a comma-separated prompt down to the tag categories to keep.

    With ``keep_tags == "all"`` the prompt is returned untouched. Otherwise each
    tag is stripped and has underscores replaced by spaces, then:

    * tags found in ``PEOPLE_TAGS`` are always kept and placed first;
    * 'solo' and tags whose group (per ``tag_group_dict``) is outside the kept
      groups are dropped;
    * for ``keep_tags == "body"``, tags matching the clothing pattern are dropped;
    * tags matching the background pattern are dropped.

    An unrecognised ``keep_tags`` value uses the "body" group list.

    Returns:
        The surviving tags joined with ", ".
    """
    import re

    if keep_tags == "all":
        return input_prompt

    clothing_re = re.compile(r'dress|cloth|uniform|costume|vest|sweater|coat|shirt|jacket|blazer|apron|leotard|hood|sleeve|skirt|shorts|pant|loafer|ribbon|necktie|bow|collar|glove|sock|shoe|boots|wear|emblem')
    scenery_re = re.compile(r'background|outline|light|sky|build|day|screen|tree|city')

    always_dropped = ['solo']
    all_groups = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']
    kept_by_mode = {
        "body": ['groups', 'body_parts'],
        "dress": ['groups', 'body_parts', 'attire'],
        "all": all_groups,
    }
    kept_groups = kept_by_mode.get(keep_tags, kept_by_mode["body"])
    # Every kept list is a subset of all_groups, so this is the complement.
    dropped_groups = set(all_groups) - set(kept_groups)

    group_lookup = tag_group_dict

    def should_drop(tag):
        if tag in always_dropped or group_lookup.get(tag, "") in dropped_groups:
            return True
        if keep_tags == "body" and clothing_re.search(tag):
            return True
        return bool(scenery_re.search(tag))

    people: list[str] = []
    others: list[str] = []
    for raw in (input_prompt.split(",") if input_prompt else []):
        tag = raw.strip().replace("_", " ")
        if tag in PEOPLE_TAGS:
            people.append(tag)
        elif not should_drop(tag):
            others.append(tag)

    return ", ".join(people + others)
311
+
312
+
313
def sort_taglist(tags: list[str]):
    """Reorder tags into: characters, series, people, grouped tags, others, rating.

    Each tag is stripped and has underscores replaced by spaces, then is
    classified by the first matching rule: ``PEOPLE_TAGS``, rating names (keys
    or values of ``DANBOORU_TO_E621_RATING_MAP``), keys of ``tag_group_dict``,
    keys of ``anime_series_dict`` (characters), values of ``anime_series_dict``
    (series), otherwise "other". Grouped tags are emitted in the fixed group
    order below. Only the first rating tag is kept, and "explicit" becomes
    "explicit, nsfw".

    NOTE(review): a tag whose group is not in the fixed group order is not
    emitted at all - confirm that every group in tag_group.csv is listed.

    Returns:
        The reordered list of tags ([] for empty input).
    """
    if not tags:
        return []

    group_order = ['groups', 'body_parts', 'attire', 'posture', 'objects', 'creatures', 'locations', 'disambiguation_pages', 'commonly_misused_tags', 'phrases', 'verbs_and_gerunds', 'subjective', 'nudity', 'sex_objects', 'sex', 'sex_acts', 'image_composition', 'artistic_license', 'text', 'year_tags', 'metatags']

    tag_to_group = tag_group_dict
    series_names = set(anime_series_dict.values())
    rating_names = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())

    characters: list[str] = []
    series: list[str] = []
    people: list[str] = []
    others: list[str] = []
    ratings: list[str] = []
    by_group = {}

    for raw in tags:
        tag = raw.strip().replace("_", " ")
        if tag in PEOPLE_TAGS:
            people.append(tag)
        elif tag in rating_names:
            ratings.append(tag)
        elif tag in tag_to_group:
            by_group.setdefault(tag_to_group[tag], []).append(tag)
        elif tag in anime_series_dict:
            characters.append(tag)
        elif tag in series_names:
            series.append(tag)
        else:
            others.append(tag)

    grouped: list[str] = [tag for name in group_order for tag in by_group.get(name, [])]

    ratings = ratings[:1]
    if ratings and ratings[0] == "explicit":
        ratings = ["explicit, nsfw"]

    return characters + series + people + grouped + others + ratings
355
+
356
+
357
def sort_tags(tags: str):
    """Sort a comma-separated tag string with ``sort_taglist``.

    The string is split on commas, pieces are stripped, and empty pieces are
    discarded before sorting.

    Returns:
        The sorted tags joined with ", " ("" for empty input).
    """
    if not tags:
        return ""
    stripped = (piece.strip() for piece in tags.split(","))
    taglist: list[str] = [piece for piece in stripped if piece != ""]
    return ", ".join(sort_taglist(taglist))
364
+
365
+
366
def postprocess_results(results: dict[str, float], general_threshold: float, character_threshold: float):
    """Split tagger scores into rating, character and general dictionaries.

    Labels are visited in descending score order. Labels starting with
    "rating:" or "character:" have that leading prefix removed (only the
    prefix - later occurrences inside the name are preserved). Character and
    general entries below their threshold are dropped; rating entries are
    never thresholded.

    Args:
        results: Mapping of label to score.
        general_threshold: Minimum score for a general tag to be kept.
        character_threshold: Minimum score for a character tag to be kept.

    Returns:
        Tuple ``(rating, character, general)`` of dicts, each in descending
        score order.
    """
    rating = {}
    character = {}
    general = {}

    ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)
    for label, score in ranked:
        if label.startswith("rating:"):
            rating[label.removeprefix("rating:")] = score
        elif label.startswith("character:"):
            if score >= character_threshold:
                character[label.removeprefix("character:")] = score
        elif score >= general_threshold:
            general[label] = score

    return rating, character, general
389
+
390
+
391
def gen_prompt(rating: list[str], character: list[str], general: list[str]):
    """Build the general-tag prompt with people tags (e.g. "1girl") moved to the front.

    ``rating`` and ``character`` are accepted for interface compatibility but do
    not influence the result. The previous implementation looked up
    ``RATING_MAP[rating[0]]`` and discarded the value, which could only raise on
    an empty or unknown rating; that dead lookup has been removed.

    Args:
        rating: Rating names (unused).
        character: Character names (unused).
        general: General tags in the order they should appear.

    Returns:
        People tags followed by the remaining general tags, joined with ", ".
    """
    people_tags: list[str] = [tag for tag in general if tag in PEOPLE_TAGS]
    other_tags: list[str] = [tag for tag in general if tag not in PEOPLE_TAGS]
    return ", ".join(people_tags + other_tags)
405
+
406
+
407
@spaces.GPU()
def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
    """Tag an image with the WD tagger model and build prompt strings from the result.

    Args:
        image: Image to tag.
        general_threshold: Minimum score for general tags.
        character_threshold: Minimum score for character tags.

    Returns:
        A 4-tuple ``(series_tag, character_tags, prompt, button_update)``: the
        first series returned by ``character_list_to_series_list`` (or ""), the
        character names joined with ", ", the general-tag prompt from
        ``gen_prompt``, and ``gr.update(interactive=True)``.
    """
    inputs = wd_processor.preprocess(image, return_tensors="pt")

    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = wd_model(**inputs.to(wd_model.device, wd_model.dtype))
        probabilities = torch.sigmoid(outputs.logits[0])  # take the first logits

    # label -> probability
    results = {
        wd_model.config.id2label[i]: float(prob.float()) for i, prob in enumerate(probabilities)
    }

    rating, character, general = postprocess_results(
        results, general_threshold, character_threshold
    )

    prompt = gen_prompt(
        list(rating.keys()), list(character.keys()), list(general.keys())
    )

    output_series_list = character_list_to_series_list(character.keys())
    output_series_tag = output_series_list[0] if output_series_list else ""

    return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
436
+
437
+
438
def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3, character_threshold: float = 0.8):
    """Run ``predict_tags`` unless the tagger is deselected in ``algo``.

    The tagger runs when ``algo`` is empty or contains "Use WD Tagger".
    Otherwise the existing ``input_tags`` are passed through with empty series
    and character fields.

    Returns:
        The same 4-tuple shape as ``predict_tags``.
    """
    tagger_selected = (not algo) or ("Use WD Tagger" in algo)
    if tagger_selected:
        return predict_tags(image, general_threshold, character_threshold)
    return "", "", input_tags, gr.update(interactive=True)
442
+
443
+
444
def compose_prompt_to_copy(character: str, series: str, general: str):
    """Concatenate character, series and general tags into one comma-joined string.

    Each non-empty argument is split on "," and the pieces are re-joined with
    "," in the order character, series, general. Pieces are not stripped.

    Returns:
        The combined string ("" when every argument is empty).
    """
    pieces = []
    for field in (character, series, general):
        if field:
            pieces.extend(field.split(","))
    return ",".join(pieces)
tags.txt ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from dartrs.v2 import AspectRatioTag, LengthTag, RatingTag, IdentityTag
3
+
4
+
5
# Option lists for the condition tags, annotated with the tag types imported
# from dartrs.v2. Presumably used as UI choices - confirm at the call sites.
V2_ASPECT_RATIO_OPTIONS: list[AspectRatioTag] = ["ultra_wide", "wide", "square", "tall", "ultra_tall"]
V2_RATING_OPTIONS: list[RatingTag] = ["sfw", "general", "sensitive", "nsfw", "questionable", "explicit"]
V2_LENGTH_OPTIONS: list[LengthTag] = ["very_short", "short", "medium", "long", "very_long"]
V2_IDENTITY_OPTIONS: list[IdentityTag] = ["none", "lax", "strict"]
32
+
33
+
34
+ # ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
35
def gradio_copy_text(_text: None):
    """Show a "Copied!" info message through Gradio; ``_text`` is ignored."""
    gr.Info("Copied!")
37
+
38
+
39
+ COPY_ACTION_JS = """\
40
+ (inputs, _outputs) => {
41
+ // inputs is the string value of the input_text
42
+ if (inputs.trim() !== "") {
43
+ navigator.clipboard.writeText(inputs);
44
+ }
45
+ }"""  # JS snippet: writes `inputs` to the clipboard via navigator.clipboard.writeText when it is not blank after trim()
v2.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import os
3
+ import torch
4
+ from typing import Callable
5
+
6
+ from dartrs.v2 import (
7
+ V2Model,
8
+ MixtralModel,
9
+ MistralModel,
10
+ compose_prompt,
11
+ LengthTag,
12
+ AspectRatioTag,
13
+ RatingTag,
14
+ IdentityTag,
15
+ )
16
+ from dartrs.dartrs import DartTokenizer
17
+ from dartrs.utils import get_generation_config
18
+
19
+
20
+ import gradio as gr
21
+ from gradio.components import Component
22
+
23
try:
    import spaces
except ImportError:

    class spaces:
        """Minimal stand-in used when the `spaces` package is not installed."""

        @staticmethod
        def GPU(*args, **kwargs):
            """No-op replacement for the ``spaces.GPU`` decorator.

            Supports both ``@spaces.GPU`` and ``@spaces.GPU(...)``; in either
            form the decorated function is returned unchanged.
            """
            # Bare usage: the decorated callable is the sole positional argument.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]
            return lambda func: func
30
+
31
+
32
+ from output import UpsamplingOutput
33
+
34
+
35
# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Selectable models: name -> repository id, training type and the class used to load it.
V2_ALL_MODELS = {
    "dart-v2-moe-sft": {"repo": "p1atdev/dart-v2-moe-sft", "type": "sft", "class": MixtralModel},
    "dart-v2-sft": {"repo": "p1atdev/dart-v2-sft", "type": "sft", "class": MistralModel},
}
49
+
50
+
51
def prepare_models(model_config: dict):
    """Load the tokenizer and model described by one ``V2_ALL_MODELS`` entry.

    Args:
        model_config: Dict with a "repo" id and a "class" exposing ``from_pretrained``.

    Returns:
        ``{"tokenizer": ..., "model": ...}``.
    """
    repo_id = model_config["repo"]

    loaded = {}
    loaded["tokenizer"] = DartTokenizer.from_pretrained(repo_id, auth_token=HF_TOKEN)
    loaded["model"] = model_config["class"].from_pretrained(repo_id, auth_token=HF_TOKEN)
    return loaded
60
+
61
+
62
def normalize_tags(tokenizer: DartTokenizer, tags: str):
    """Tokenize ``tags`` and re-join the tokens with ", ", dropping "<|unk|>" tokens."""
    known_tokens = []
    for token in tokenizer.tokenize(tags):
        if token == "<|unk|>":
            continue
        known_tokens.append(token)
    return ", ".join(known_tokens)
65
+
66
+
67
@torch.no_grad()
def generate_tags(
    model: V2Model,
    tokenizer: DartTokenizer,
    prompt: str,
    ban_token_ids: list[int],
):
    """Generate from ``prompt`` with fixed sampling settings and return the model output.

    Sampling uses temperature 1, top_p 0.9, top_k 100 and at most 256 new
    tokens; ``ban_token_ids`` is forwarded to the generation config.
    """
    generation_config = get_generation_config(
        prompt,
        tokenizer=tokenizer,
        temperature=1,
        top_p=0.9,
        top_k=100,
        max_new_tokens=256,
        ban_token_ids=ban_token_ids,
    )
    return model.generate(generation_config)
87
+
88
+
89
+ def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
90
+ return (
91
+ [f"1{noun}"]
92
+ + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
93
+ + [f"{maximum+1}+{noun}s"]
94
+ )
95
+
96
+
97
# Count tags for girls, boys and others, plus "no humans"; used to pull
# people tags to the front of a prompt.
PEOPLE_TAGS = [
    tag for noun in ("girl", "boy", "other") for tag in _people_tag(noun)
] + ["no humans"]
100
+
101
+
102
def gen_prompt_text(output: UpsamplingOutput):
    """Assemble the final prompt text from an upsampling result.

    Order: people tags (e.g. 1girl) from the general tags, character tags,
    copyright tags, remaining general tags, upsampled tags, rating tag.
    Every part is stripped and blank parts are omitted.

    Returns:
        The parts joined with ", ".
    """
    general = [tag.strip() for tag in output.general_tags.split(",")]
    people = [tag for tag in general if tag in PEOPLE_TAGS]
    remaining = [tag for tag in general if tag not in PEOPLE_TAGS]

    ordered = (
        people
        + [output.character_tags, output.copyright_tags]
        + remaining
        + [output.upsampled_tags, output.rating_tag]
    )

    cleaned = (part.strip() for part in ordered)
    return ", ".join(part for part in cleaned if part != "")
128
+
129
+
130
def elapsed_time_format(elapsed_time: float) -> str:
    """Format a duration in seconds as "Elapsed: X.XX seconds" (two decimals)."""
    message = f"Elapsed: {elapsed_time:.2f} seconds"
    return message
132
+
133
+
134
def parse_upsampling_output(
    upsampler: Callable[..., UpsamplingOutput],
):
    """Wrap ``upsampler`` so its result is converted into UI output values.

    Returns:
        A function that forwards its positional arguments to ``upsampler`` and
        returns a 4-tuple: the prompt text from ``gen_prompt_text``, the
        formatted elapsed time, and two ``gr.update(interactive=True)`` values.
    """
    # The annotation now matches the four values actually returned.
    def _parse_upsampling_output(*args) -> tuple[str, str, dict, dict]:
        output = upsampler(*args)

        return (
            gen_prompt_text(output),
            elapsed_time_format(output.elapsed_time),
            gr.update(interactive=True),
            gr.update(interactive=True),
        )

    return _parse_upsampling_output
148
+
149
+
150
class V2UI:
    """Holds the currently loaded model/tokenizer and runs tag upsampling for the UI."""

    # Name of the model currently loaded; None until the first generation.
    model_name: str | None = None
    model: V2Model
    tokenizer: DartTokenizer

    # NOTE(review): class-level mutable list, shared by every instance unless
    # an instance rebinds it - confirm this is intended.
    input_components: list[Component] = []
    generate_btn: gr.Button

    def on_generate(
        self,
        model_name: str,
        copyright_tags: str,
        character_tags: str,
        general_tags: str,
        rating_tag: RatingTag,
        aspect_ratio_tag: AspectRatioTag,
        length_tag: LengthTag,
        identity_tag: IdentityTag,
        ban_tags: str,
        *args,
    ) -> UpsamplingOutput:
        """Generate upsampled tags for the given inputs.

        The model and tokenizer are (re)loaded via ``prepare_models`` whenever
        ``model_name`` differs from the one currently loaded. ``ban_tags`` is
        stripped and encoded into token ids that are banned during generation.
        Extra positional arguments are accepted and ignored.

        Returns:
            An ``UpsamplingOutput`` holding the generated tags, the input tags
            and options, and the generation time in seconds.

        Raises:
            KeyError: If ``model_name`` is not a key of ``V2_ALL_MODELS``.
        """
        if self.model_name is None or self.model_name != model_name:
            models = prepare_models(V2_ALL_MODELS[model_name])
            self.model = models["model"]
            self.tokenizer = models["tokenizer"]
            self.model_name = model_name

        # Input tags are passed through as given; normalize_tags() is not applied.
        ban_token_ids = self.tokenizer.encode(ban_tags.strip())

        prompt = compose_prompt(
            prompt=general_tags,
            copyright=copyright_tags,
            character=character_tags,
            rating=rating_tag,
            aspect_ratio=aspect_ratio_tag,
            length=length_tag,
            identity=identity_tag,
        )

        # perf_counter is monotonic, so the duration cannot be skewed (or go
        # negative) if the system clock is adjusted during generation.
        start = time.perf_counter()
        upsampled_tags = generate_tags(
            self.model,
            self.tokenizer,
            prompt,
            ban_token_ids,
        )
        elapsed_time = time.perf_counter() - start

        return UpsamplingOutput(
            upsampled_tags=upsampled_tags,
            copyright_tags=copyright_tags,
            character_tags=character_tags,
            general_tags=general_tags,
            rating_tag=rating_tag,
            aspect_ratio_tag=aspect_ratio_tag,
            length_tag=length_tag,
            identity_tag=identity_tag,
            elapsed_time=elapsed_time,
        )
214
+