furusu's picture
Update app.py
3b0f531 verified
raw
history blame
2.66 kB
import torch
import json
import gradio as gr
# Page title; passed to gr.Interface as `title`.
TITLE: str = "Danboru Tag Similarity"
# User-facing description (Japanese); passed to gr.Interface as `description`.
# In English it says, roughly: computes the similarity between Danbooru tags;
# the supported tags are listed in the text files under "Files" (same as Dart);
# trained for 2 epochs on the isek-ai/danbooru-tags-2023 dataset with shuffled
# tags, following Dart; similarities are computed from the trained token
# embeddings.
# Each sentence ends with an explicit "\n" escape in addition to the physical
# line break, so the rendered string contains two newlines between sentences.
DESCRIPTION: str = """
与えられたダンボールタグの類似度を計算します。\n
対応するタグのリストはFilesからそれぞれのテキストファイルを参照してください。(Dartと同じです)。\n
Dartを参考に、isek-ai/danbooru-tags-2023データセットでタグをシャッフルして2エポック学習しました。\n
学習後のトークン埋め込みを元に計算しています。
"""
def _read_lines(path):
    """Return the lines of the UTF-8 text file at *path*, without line endings."""
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(path, "r", encoding="utf-8") as f:
        return f.read().splitlines()


# Vocabulary: JSON object whose keys are token ids (as strings) and whose
# values are the tag texts.
with open("num_to_token.json", "r", encoding="utf-8") as f:
    num_to_token = json.load(f)

# One tag per line. "Popular" presumably means count >= 1000, per the UI label
# used further down the file -- the files themselves are not visible here.
populars = _read_lines("popular.txt")
# Built once and reused for every per-category intersection below.
_popular_set = set(populars)

characters = _read_lines("character.txt")
characters_populars = list(set(characters) & _popular_set)

copyrights = _read_lines("copyright.txt")
copyrights_populars = list(set(copyrights) & _popular_set)

generals = _read_lines("general.txt")
generals_populars = list(set(generals) & _popular_set)

# Reverse lookup: tag text -> token id (still a string, as in the JSON keys).
token_to_num = {v: k for k, v in num_to_token.items()}

# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# map_location="cpu" keeps the load working on hosts without CUDA even if the
# tensor was saved from a GPU.
token_embeddings = torch.load("token_embeddings.pt", map_location="cpu")

# Alphabetically sorted list of every tag in the vocabulary.
tags = sorted(num_to_token.values())
def predict(target_tag, sort_by, category, popular):
    """Score vocabulary tags by cosine similarity to ``target_tag``.

    Args:
        target_tag: Tag text to compare against. It is looked up verbatim
            first, then with surrounding whitespace stripped.
        sort_by: ``"descending"`` keeps the similarities as they are; any
            other value negates them, so the least similar tags get the
            highest scores.
        category: ``"general"``, ``"character"`` or ``"copyright"`` restricts
            the result to that tag list; any other value uses every tag.
        popular: ``"all"`` uses the full list for the category; any other
            value uses the popular-only variant.

    Returns:
        A dict mapping tag text to its (possibly negated) cosine similarity
        with ``target_tag``.

    Raises:
        gr.Error: If ``target_tag`` is not in the vocabulary.
    """
    # Exact match first, so a token that really carries surrounding whitespace
    # still resolves; the trimmed form is only a fallback for stray spaces.
    tag_id = token_to_num.get(target_tag)
    if tag_id is None and isinstance(target_tag, str):
        tag_id = token_to_num.get(target_tag.strip())
    if tag_id is None:
        # gr.Error carries a readable message to the UI, unlike a bare KeyError.
        raise gr.Error(f"Unknown tag: {target_tag!r}")

    # The result is displayed by a top-N label widget, so "ascending" is
    # implemented by flipping the sign rather than by reordering.
    multiplier = 1 if sort_by == "descending" else -1

    target_embedding = token_embeddings[int(tag_id)].unsqueeze(0)
    sims = torch.cosine_similarity(target_embedding, token_embeddings, dim=1)
    # One bulk conversion instead of one .item() call per vocabulary entry.
    scores = sims.detach().tolist()
    results = {
        num_to_token[str(i)]: scores[i] * multiplier
        for i in range(len(num_to_token))
    }

    # category -> (full list, popular-only list)
    by_category = {
        "general": (generals, generals_populars),
        "character": (characters, characters_populars),
        "copyright": (copyrights, copyrights_populars),
    }
    every_tag, popular_only = by_category.get(category, (results.keys(), populars))
    tag_list = every_tag if popular == "all" else popular_only

    # The tag lists come from separate files; skip any entry that has no
    # vocabulary row instead of failing the whole request with a KeyError.
    return {k: results[k] for k in tag_list if k in results}
# Input widgets, listed in the positional order of predict()'s parameters:
# target_tag, sort_by, category, popular.
_demo_inputs = [
    gr.Textbox(value="otoko no ko", label="Target tag"),
    gr.Radio(
        choices=["descending", "ascending"],
        value="descending",
        label="Sort by",
    ),
    gr.Dropdown(
        choices=["all", "general", "character", "copyright"],
        value="all",
        label="category",
    ),
    gr.Radio(
        choices=["all", "only_popular"],
        value="all",
        label="Only popular tag (count>=1000)",
    ),
]

# Label output configured with num_top_classes=50 for the dict predict() returns.
_demo_output = gr.Label(num_top_classes=50)

demo = gr.Interface(
    fn=predict,
    inputs=_demo_inputs,
    outputs=_demo_output,
    title=TITLE,
    description=DESCRIPTION,
)

# Start the web app as soon as the module is executed.
demo.launch()