Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os | |
from pathlib import Path | |
import gc | |
import re | |
import shutil | |
from utils import set_token, get_download_file, list_uniq | |
from stkey import read_safetensors_key, read_safetensors_metadata, validate_keys, write_safetensors_key | |
TEMP_DIR = "."  # directory that download_file() passes to get_download_file()
KEYS_DIR = "keys"  # directory scanned for reference key lists (*.txt)
DEFAULT_KEYS_FILE = f"{KEYS_DIR}/sdxl_keys.txt"
# Available reference key files; rebuilt by update_keys_files().
KEYS_FILES = [DEFAULT_KEYS_FILE]


def update_keys_files():
    """Rescan KEYS_DIR for ``*.txt`` reference-key files and store them in KEYS_FILES.

    The list is sorted so that the order does not depend on filesystem
    enumeration order. This matters for the UI choices and for any
    ``KEYS_FILES[0]`` default captured later in this module.
    DEFAULT_KEYS_FILE, when it is among the results, is placed first.

    If no key files are found (empty or missing KEYS_DIR), KEYS_FILES falls
    back to ``[DEFAULT_KEYS_FILE]`` so that ``KEYS_FILES[0]`` never raises
    IndexError.
    """
    global KEYS_FILES
    default = Path(DEFAULT_KEYS_FILE)
    found = sorted(
        (str(file) for file in Path(KEYS_DIR).glob("*.txt")),
        # False sorts before True, so the default file comes first; ties by name.
        key=lambda name: (Path(name) != default, name),
    )
    KEYS_FILES = found if found else [DEFAULT_KEYS_FILE]


# Populate KEYS_FILES at import time. This must run before any function
# definition that uses KEYS_FILES[0] as a default argument value.
update_keys_files()
def upload_keys_file(path: str):
    """Copy an uploaded key list into KEYS_DIR and refresh the key-file choices.

    The file is stored as ``<KEYS_DIR>/<stem>.txt``. If a file with that name
    already exists, it is kept and the upload is not copied over it.

    Args:
        path: Local path of the uploaded file. A falsy value (nothing
            uploaded) skips the copy but still refreshes the choices.

    Returns:
        A ``gr.update`` whose ``choices`` are the refreshed KEYS_FILES.
    """
    if path:
        newpath = Path(KEYS_DIR, Path(path).stem + ".txt")
        if not newpath.exists():
            # KEYS_DIR may not exist yet; without this, shutil.copy raises
            # FileNotFoundError.
            newpath.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(str(Path(path)), str(newpath))
    update_keys_files()
    # KEYS_FILES is only read here, so no `global` declaration is needed;
    # the lookup happens after update_keys_files() has rebound it.
    return gr.update(choices=KEYS_FILES)
# Compiled once at import: http(s) URLs made of word chars and common URL punctuation.
_URL_PATTERN = re.compile(r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+")


def parse_urls(s):
    """Return every http(s) URL found in *s*, in order of appearance.

    Args:
        s: Free text that may contain URLs (e.g. one per line).

    Returns:
        A list of URL strings. Empty if *s* contains no URLs or is not text
        (``None``, bytes, numbers).
    """
    try:
        return _URL_PATTERN.findall(s)
    except TypeError:
        # Raised by findall for non-str input; such input has no URLs.
        return []
def to_urls(l: list[str]):
    """Render a sequence of URLs as one string, one URL per line."""
    newline = "\n"
    return newline.join(l)
def uniq_urls(s):
    """Normalize free text into a newline-separated string of distinct URLs.

    URLs are extracted with ``parse_urls``, de-duplicated with ``list_uniq``
    and joined back together with ``to_urls``.
    """
    found = parse_urls(s)
    distinct = list_uniq(found)
    return to_urls(distinct)
def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Download *dl_url* into TEMP_DIR via ``get_download_file``.

    Reports the start of the download on the Gradio progress tracker, then
    returns whatever ``get_download_file`` returns. Presumably that is the
    local path of the downloaded file; TODO confirm in utils.
    ``civitai_key`` is forwarded unchanged to ``get_download_file``.
    """
    progress(0, desc=f"Start downloading... {dl_url}")
    return get_download_file(TEMP_DIR, dl_url, civitai_key)
def get_stkey(filename: str, is_validate: bool=True, rfile: str=KEYS_FILES[0], progress=gr.Progress(track_tqdm=True)):
    """Read, write out and check the tensor keys of a safetensors file.

    The input file is always deleted afterwards, whether or not the check
    succeeded.

    Args:
        filename: Path to the (downloaded) safetensors file.
        is_validate: Forwarded to ``write_safetensors_key``.
        rfile: Reference key-list file passed to ``write_safetensors_key``
            and ``validate_keys``.
        progress: Gradio progress tracker.

    Returns:
        ``(paths, metadata, keys, missing, added)``. ``paths`` lists the
        ``<stem>.txt``, ``<stem>_missing.txt`` and ``<stem>_added.txt``
        report files, and is empty unless ``write_safetensors_key`` returned
        a truthy value. On any error the values gathered so far are
        returned, and the error is printed and shown as a Gradio warning.
    """
    stem = Path(filename).stem
    keys_path = stem + ".txt"
    paths = []
    metadata = {}
    keys = []
    missing = []
    added = []
    try:
        progress(0, desc=f"Loading keys... {filename}")
        keys = read_safetensors_key(filename)
        if not keys:
            raise ValueError("No keys found.")
        progress(0.5, desc=f"Checking keys... {filename}")
        # NOTE(review): source indentation was lost. Only the report paths are
        # reconstructed as depending on the write result; confirm upstream.
        if write_safetensors_key(keys, keys_path, is_validate, rfile):
            paths.extend([keys_path, stem + "_missing.txt", stem + "_added.txt"])
        missing, added = validate_keys(keys, rfile)
        metadata = read_safetensors_metadata(filename)
    except Exception as e:
        print(f"Error: Failed check {filename}. {e}")
        gr.Warning(f"Error: Failed check {filename}. {e}")
    finally:
        # missing_ok: an exception raised inside `finally` would escape the
        # function and replace the normal return value.
        Path(filename).unlink(missing_ok=True)
    return paths, metadata, keys, missing, added
def stkey_gr(dl_url: str, civitai_key: str, hf_token: str, urls: list[str], files: list[str],
             is_validate=True, rfile=KEYS_FILES[0], progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: download every URL in *dl_url* and check its safetensors keys.

    Args:
        dl_url: Free text containing one or more http(s) URLs.
        civitai_key: Civitai API key. Falls back to ``$CIVITAI_API_KEY``.
        hf_token: Hugging Face token. Falls back to ``$HF_TOKEN``.
        urls: Current URL choices, returned unchanged (``[]`` if falsy).
        files: Current output file list. The report paths produced for each
            URL are appended to a copy of it.
        is_validate: Forwarded to ``get_stkey``.
        rfile: Reference key-list file forwarded to ``get_stkey``.
        progress: Gradio progress tracker.

    Returns:
        A 7-tuple: an update for the URL dropdown, an update for the file
        list, an update hiding a component, then ``metadata``, ``keys``,
        ``missing`` and ``added`` from the last URL that was processed.
    """
    if not hf_token:
        hf_token = os.environ.get("HF_TOKEN")  # default huggingface token
    set_token(hf_token)
    if not civitai_key:
        civitai_key = os.environ.get("CIVITAI_API_KEY")  # default Civitai API key
    dl_urls = parse_urls(dl_url)
    if not urls:
        urls = []
    # Copy so the caller's list is not mutated by the extend() below.
    files = list(files) if files else []
    metadata = {}
    keys = []
    missing = []
    added = []
    for u in dl_urls:
        file = download_file(u, civitai_key)
        # Skip failed downloads. A None result would make Path() raise
        # TypeError; is_file() also covers the "does not exist" case.
        if not file or not Path(file).is_file():
            continue
        paths, metadata, keys, missing, added = get_stkey(file, is_validate, rfile)
        files.extend(paths)
    # NOTE(review): indentation was lost in the source; these two calls are
    # reconstructed as running once after the loop. Confirm upstream.
    progress(1, desc="Processing...")
    gc.collect()
    return (
        gr.update(value=urls, choices=urls),
        gr.update(value=files),
        gr.update(visible=False),
        metadata,
        keys,
        missing,
        added,
    )