Spaces:
Running
Running
| import torch | |
| import json | |
| import gradio as gr | |
TITLE = "Danboru Tag Similarity"
# DESCRIPTION is rendered in the UI. English summary: computes similarity
# between the given Danbooru tags; the supported tag lists are the text files
# under "Files" (same as Dart); following Dart, the model was trained for two
# epochs on the isek-ai/danbooru-tags-2023 dataset with shuffled tags; the
# similarity is computed from the trained token embeddings.
DESCRIPTION = """
与えられたダンボールタグの類似度を計算します。\n
対応するタグのリストはFilesからそれぞれのテキストファイルを参照してください。(Dartと同じです)。\n
Dartを参考に、isek-ai/danbooru-tags-2023データセットでタグをシャッフルして2エポック学習しました。\n
学習後のトークン埋め込みを元に計算しています。
"""


def _read_lines(path):
    """Return the lines of the UTF-8 text file at *path*, without line endings."""
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(path, "r", encoding="utf-8") as f:
        return f.read().splitlines()


def _popular_subset(tag_list):
    """Return the tags of *tag_list* that are also popular, deduplicated, in file order."""
    # dict.fromkeys dedups like the previous set intersection did, but keeps
    # a stable order (set order varies between runs under hash randomization).
    return list(dict.fromkeys(t for t in tag_list if t in _popular_set))


# Vocabulary mapping: id (JSON object keys are always strings) -> tag.
with open("num_to_token.json", "r", encoding="utf-8") as f:
    num_to_token = json.load(f)

populars = _read_lines("popular.txt")
_popular_set = set(populars)

characters = _read_lines("character.txt")
characters_populars = _popular_subset(characters)
copyrights = _read_lines("copyright.txt")
copyrights_populars = _popular_subset(copyrights)
generals = _read_lines("general.txt")
generals_populars = _popular_subset(generals)

# Reverse lookup: tag -> id (still a string, as in num_to_token's keys).
token_to_num = {v: k for k, v in num_to_token.items()}
# map_location="cpu" lets the file load on a CPU-only host even if the tensor
# was saved from a GPU device.
token_embeddings = torch.load("token_embeddings.pt", map_location="cpu")
# NOTE(review): not referenced by the code visible in this file; kept so the
# module-level name stays available.
tags = sorted(num_to_token.values())
def predict(target_tag, sort_by, category, popular):
    """Score tags by cosine similarity to ``target_tag``'s token embedding.

    The embedding row of ``target_tag`` is compared with every row of
    ``token_embeddings``.

    Args:
        target_tag: Tag to compare against. Surrounding whitespace is ignored.
        sort_by: ``"descending"`` keeps similarities as-is. Any other value
            (the UI sends ``"ascending"``) negates them. Presumably this is so
            that ``gr.Label``, which orders by value, lists the least similar
            tags first.
        category: ``"general"``, ``"character"`` or ``"copyright"`` selects
            that tag file. Any other value selects the whole vocabulary.
        popular: ``"all"`` keeps every tag of the category. Any other value
            restricts the result to the popular subset.

    Returns:
        dict mapping tag -> (possibly negated) cosine similarity for the
        selected tag list. Listed tags without an embedding are skipped.

    Raises:
        gr.Error: if ``target_tag`` is not in the vocabulary.
    """
    tag = target_tag.strip()
    if tag not in token_to_num:
        # A bare KeyError surfaces in the UI as a generic error; gr.Error
        # tells the user which tag was not found.
        raise gr.Error(f"Unknown tag: {tag!r}")
    multiplier = 1 if sort_by == "descending" else -1

    target_embedding = token_embeddings[int(token_to_num[tag])].unsqueeze(0)
    # One .tolist() converts all similarities to Python floats, instead of
    # issuing a separate .item() call per vocabulary entry.
    sims = torch.cosine_similarity(target_embedding, token_embeddings, dim=1).tolist()
    # Walk the vocabulary mapping itself rather than assuming ids are 0..N-1.
    results = {token: sims[int(num)] * multiplier for num, token in num_to_token.items()}

    # category -> (all tags, popular tags); unknown categories mean "everything".
    tag_lists = {
        "general": (generals, generals_populars),
        "character": (characters, characters_populars),
        "copyright": (copyrights, copyrights_populars),
    }
    all_tags, popular_tags = tag_lists.get(category, (results.keys(), populars))
    tag_list = all_tags if popular == "all" else popular_tags
    # Tag files may contain entries absent from the embedding vocabulary (or
    # blank lines); skip them instead of failing the whole request.
    return {k: results[k] for k in tag_list if k in results}
# Gradio UI. The inputs are passed to predict positionally, in the order
# (target_tag, sort_by, category, popular).
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Target tag", value="otoko no ko"),
        gr.Radio(choices=["descending", "ascending"], label="Sort by", value="descending"),
        gr.Dropdown(choices=["all", "general", "character", "copyright"], value="all", label="category"),
        gr.Radio(choices=["all", "only_popular"], label="Only popular tag (count>=1000)", value="all"),
    ],
    # Show the 50 top-valued entries of predict's returned dict.
    outputs=gr.Label(num_top_classes=50),
    title=TITLE,
    description=DESCRIPTION,
)

# Start the server only when run as a script, so importing the module (e.g.
# to reuse `demo` or `predict`) has no side effect.
if __name__ == "__main__":
    demo.launch()