John6666 committed on
Commit
8f7570a
·
verified ·
1 Parent(s): 5f83820

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +16 -15
  3. multit2i.py +40 -41
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🖼️📅
4
  colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.9.1
8
  app_file: app.py
9
  pinned: false
10
  short_description: Text-to-Image
 
4
  colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.13.1
8
  app_file: app.py
9
  pinned: false
10
  short_description: Text-to-Image
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  from model import models
3
  from multit2i import (load_models, infer_fn, infer_rand_fn, save_gallery,
4
- change_model, warm_model, get_model_info_md, loaded_models,
5
  get_positive_prefix, get_positive_suffix, get_negative_prefix, get_negative_suffix,
6
  get_recom_prompt_type, set_recom_prompt_preset, get_tag_type, randomize_seed, translate_to_en)
7
  from tagger.tagger import (predict_tags_wd, remove_specific_prompt, convert_danbooru_to_e621_prompt,
@@ -14,8 +14,11 @@ from tagger.utils import (V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS,
14
  max_images = 6
15
  MAX_SEED = 2**32-1
16
  load_models(models)
 
17
 
18
  css = """
 
 
19
  .model_info { text-align: center; }
20
  .output { width=112px; height=112px; max_width=112px; max_height=112px; !important; }
21
  .gallery { min_width=512px; min_height=512px; max_height=1024px; !important; }
@@ -29,24 +32,24 @@ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
29
  with gr.Accordion("Prompt from Image File", open=False):
30
  tagger_image = gr.Image(label="Input image", type="pil", format="png", sources=["upload", "clipboard"], height=256)
31
  with gr.Accordion(label="Advanced options", open=False):
32
- with gr.Row():
33
  tagger_general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
34
  tagger_character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
35
  tagger_tag_type = gr.Radio(label="Convert tags to", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
36
- with gr.Row():
37
  tagger_recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
38
  tagger_keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
39
  tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger"])
40
  tagger_generate_from_image = gr.Button(value="Generate Tags from Image", variant="secondary")
41
  with gr.Accordion("Prompt Transformer", open=False):
42
- with gr.Row():
43
  v2_character = gr.Textbox(label="Character", placeholder="hatsune miku", scale=2)
44
  v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2)
45
- with gr.Row():
46
  v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
47
  v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
48
  v2_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="long")
49
- with gr.Row():
50
  v2_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
51
  v2_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
52
  v2_tag_type = gr.Radio(label="Tag Type", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru", visible=False)
@@ -56,26 +59,26 @@ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
56
  prompt = gr.Text(label="Prompt", lines=2, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
57
  with gr.Accordion("Advanced options", open=False):
58
  neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="")
59
- with gr.Row():
60
  width = gr.Slider(label="Width", info="If 0, the default value is used.", maximum=1216, step=32, value=0)
61
  height = gr.Slider(label="Height", info="If 0, the default value is used.", maximum=1216, step=32, value=0)
62
  steps = gr.Slider(label="Number of inference steps", info="If 0, the default value is used.", maximum=100, step=1, value=0)
63
- with gr.Row():
64
  cfg = gr.Slider(label="Guidance scale", info="If 0, the default value is used.", maximum=30.0, step=0.1, value=0)
65
  seed = gr.Slider(label="Seed", info="Randomize Seed if -1.", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
66
  seed_rand = gr.Button("Randomize Seed 🎲", size="sm", variant="secondary")
67
  recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
68
- with gr.Row():
69
  positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
70
  positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
71
  negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[])
72
  negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"])
73
- with gr.Row():
74
  image_num = gr.Slider(label="Number of images", minimum=1, maximum=max_images, value=1, step=1, interactive=True, scale=2)
75
  trans_prompt = gr.Button(value="Translate 📝", variant="secondary", size="sm", scale=2)
76
  clear_prompt = gr.Button(value="Clear 🗑️", variant="secondary", size="sm", scale=1)
77
 
78
- with gr.Row():
79
  run_button = gr.Button("Generate Image", variant="primary", scale=6)
80
  random_button = gr.Button("Random Model 🎲", variant="secondary", scale=3)
81
  #stop_button = gr.Button('Stop', variant="stop", interactive=False, scale=1)
@@ -121,7 +124,6 @@ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
121
  image_metadata = gr.Image(label="Image with metadata", type="pil", sources=["upload"])
122
  with gr.Column():
123
  result_metadata = gr.Textbox(label="Metadata", show_label=True, show_copy_button=True, interactive=False, container=True, max_lines=99)
124
-
125
  image_metadata.change(
126
  fn=extract_exif_data,
127
  inputs=[image_metadata],
@@ -132,10 +134,9 @@ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
132
  [Nymbo/Flood](https://huggingface.co/spaces/Nymbo/Flood),
133
  [Yntec/ToyWorldXL](https://huggingface.co/spaces/Yntec/ToyWorldXL),
134
  [Yntec/Diffusion80XX](https://huggingface.co/spaces/Yntec/Diffusion80XX).
135
- """
136
- )
137
  gr.DuplicateButton(value="Duplicate Space")
138
- gr.Markdown(f"Just a few edits to *model.py* are all it takes to complete your own collection.")
139
 
140
  #gr.on(triggers=[run_button.click, prompt.submit, random_button.click], fn=lambda: gr.update(interactive=True), inputs=None, outputs=stop_button, show_api=False)
141
  model_name.change(change_model, [model_name], [model_info], queue=False, show_api=False)\
 
1
  import gradio as gr
2
  from model import models
3
  from multit2i import (load_models, infer_fn, infer_rand_fn, save_gallery,
4
+ change_model, warm_model, get_model_info_md, loaded_models, warm_models,
5
  get_positive_prefix, get_positive_suffix, get_negative_prefix, get_negative_suffix,
6
  get_recom_prompt_type, set_recom_prompt_preset, get_tag_type, randomize_seed, translate_to_en)
7
  from tagger.tagger import (predict_tags_wd, remove_specific_prompt, convert_danbooru_to_e621_prompt,
 
14
  max_images = 6
15
  MAX_SEED = 2**32-1
16
  load_models(models)
17
+ warm_models(models[0:max_images])
18
 
19
  css = """
20
+ .title { font-size: 3em; align-items: center; text-align: center; }
21
+ .info { align-items: center; text-align: center; }
22
  .model_info { text-align: center; }
23
  .output { width=112px; height=112px; max_width=112px; max_height=112px; !important; }
24
  .gallery { min_width=512px; min_height=512px; max_height=1024px; !important; }
 
32
  with gr.Accordion("Prompt from Image File", open=False):
33
  tagger_image = gr.Image(label="Input image", type="pil", format="png", sources=["upload", "clipboard"], height=256)
34
  with gr.Accordion(label="Advanced options", open=False):
35
+ with gr.Row(equal_height=True):
36
  tagger_general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
37
  tagger_character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
38
  tagger_tag_type = gr.Radio(label="Convert tags to", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
39
+ with gr.Row(equal_height=True):
40
  tagger_recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
41
  tagger_keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
42
  tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger"])
43
  tagger_generate_from_image = gr.Button(value="Generate Tags from Image", variant="secondary")
44
  with gr.Accordion("Prompt Transformer", open=False):
45
+ with gr.Row(equal_height=True):
46
  v2_character = gr.Textbox(label="Character", placeholder="hatsune miku", scale=2)
47
  v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2)
48
+ with gr.Row(equal_height=True):
49
  v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
50
  v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
51
  v2_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="long")
52
+ with gr.Row(equal_height=True):
53
  v2_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
54
  v2_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
55
  v2_tag_type = gr.Radio(label="Tag Type", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru", visible=False)
 
59
  prompt = gr.Text(label="Prompt", lines=2, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
60
  with gr.Accordion("Advanced options", open=False):
61
  neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="")
62
+ with gr.Row(equal_height=True):
63
  width = gr.Slider(label="Width", info="If 0, the default value is used.", maximum=1216, step=32, value=0)
64
  height = gr.Slider(label="Height", info="If 0, the default value is used.", maximum=1216, step=32, value=0)
65
  steps = gr.Slider(label="Number of inference steps", info="If 0, the default value is used.", maximum=100, step=1, value=0)
66
+ with gr.Row(equal_height=True):
67
  cfg = gr.Slider(label="Guidance scale", info="If 0, the default value is used.", maximum=30.0, step=0.1, value=0)
68
  seed = gr.Slider(label="Seed", info="Randomize Seed if -1.", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
69
  seed_rand = gr.Button("Randomize Seed 🎲", size="sm", variant="secondary")
70
  recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
71
+ with gr.Row(equal_height=True):
72
  positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
73
  positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
74
  negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[])
75
  negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"])
76
+ with gr.Row(equal_height=True):
77
  image_num = gr.Slider(label="Number of images", minimum=1, maximum=max_images, value=1, step=1, interactive=True, scale=2)
78
  trans_prompt = gr.Button(value="Translate 📝", variant="secondary", size="sm", scale=2)
79
  clear_prompt = gr.Button(value="Clear 🗑️", variant="secondary", size="sm", scale=1)
80
 
81
+ with gr.Row(equal_height=True):
82
  run_button = gr.Button("Generate Image", variant="primary", scale=6)
83
  random_button = gr.Button("Random Model 🎲", variant="secondary", scale=3)
84
  #stop_button = gr.Button('Stop', variant="stop", interactive=False, scale=1)
 
124
  image_metadata = gr.Image(label="Image with metadata", type="pil", sources=["upload"])
125
  with gr.Column():
126
  result_metadata = gr.Textbox(label="Metadata", show_label=True, show_copy_button=True, interactive=False, container=True, max_lines=99)
 
127
  image_metadata.change(
128
  fn=extract_exif_data,
129
  inputs=[image_metadata],
 
134
  [Nymbo/Flood](https://huggingface.co/spaces/Nymbo/Flood),
135
  [Yntec/ToyWorldXL](https://huggingface.co/spaces/Yntec/ToyWorldXL),
136
  [Yntec/Diffusion80XX](https://huggingface.co/spaces/Yntec/Diffusion80XX).
137
+ """, elem_classes="info")
 
138
  gr.DuplicateButton(value="Duplicate Space")
139
+ gr.Markdown(f"Just a few edits to *model.py* are all it takes to complete your own collection.", elem_classes="info")
140
 
141
  #gr.on(triggers=[run_button.click, prompt.submit, random_button.click], fn=lambda: gr.update(interactive=True), inputs=None, outputs=stop_button, show_api=False)
142
  model_name.change(change_model, [model_name], [model_info], queue=False, show_api=False)\
multit2i.py CHANGED
@@ -8,7 +8,7 @@ import os
8
 
9
  HF_TOKEN = os.environ.get("HF_TOKEN") if os.environ.get("HF_TOKEN") else None # If private or gated models aren't used, ENV setting is unnecessary.
10
  server_timeout = 600
11
- inference_timeout = 300
12
 
13
 
14
  lock = RLock()
@@ -52,7 +52,7 @@ def is_loadable(model_name: str, force_gpu: bool = False):
52
  return status is not None and status.state in ["Loadable", "Loaded"] and (not force_gpu or gpu_state)
53
 
54
 
55
- def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="last_modified", limit: int=30, force_gpu=False, check_status=False):
56
  from huggingface_hub import HfApi
57
  api = HfApi(token=HF_TOKEN)
58
  default_tags = ["diffusers"]
@@ -61,13 +61,13 @@ def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="l
61
  models = []
62
  try:
63
  model_infos = api.list_models(author=author, #task="text-to-image",
64
- tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit)
65
  except Exception as e:
66
  print(f"Error: Failed to list models.")
67
  print(e)
68
  return models
69
  for model in model_infos:
70
- if not model.private and not model.gated or HF_TOKEN is not None:
71
  loadable = is_loadable(model.id, force_gpu) if check_status else True
72
  if not_tag and not_tag in model.tags or not loadable or "not-for-all-audiences" in model.tags: continue
73
  models.append(model.id)
@@ -104,8 +104,8 @@ def get_t2i_model_info_dict(repo_id: str):
104
  info["likes"] = model.likes
105
  info["last_modified"] = model.last_modified.strftime("lastmod: %Y-%m-%d")
106
  un_tags = ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']
107
- descs = [info["ver"]] + list_sub(info["tags"], un_tags) + [f'DLs: {info["downloads"]}'] + [f'❤: {info["likes"]}'] + [info["last_modified"]]
108
- info["md"] = f'Model Info: {", ".join(descs)} [Model Repo]({info["url"]})'
109
  return info
110
 
111
 
@@ -160,8 +160,9 @@ def load_from_model(model_name: str, hf_token: str | Literal[False] | None = Non
160
  p = response.json().get("pipeline_tag")
161
  if p != "text-to-image": raise ModelNotFoundError(f"This model isn't for text-to-image or unsupported: {model_name}.")
162
  headers["X-Wait-For-Model"] = "true"
163
- client = huggingface_hub.InferenceClient(model=model_name, headers=headers,
164
- token=hf_token, timeout=server_timeout)
 
165
  inputs = gr.components.Textbox(label="Input")
166
  outputs = gr.components.Image(label="Output")
167
  fn = client.text_to_image
@@ -170,9 +171,10 @@ def load_from_model(model_name: str, hf_token: str | Literal[False] | None = Non
170
  try:
171
  data = fn(*data, **kwargs) # type: ignore
172
  except huggingface_hub.utils.HfHubHTTPError as e:
173
- if "429" in str(e):
174
- raise TooManyRequestsError() from e
175
  except Exception as e:
 
176
  raise Exception() from e
177
  return data
178
 
@@ -210,29 +212,29 @@ def load_model(model_name: str):
210
  def load_model_api(model_name: str):
211
  global loaded_models
212
  global model_info_dict
213
- if model_name in loaded_models.keys(): return loaded_models[model_name]
214
  try:
215
- client = InferenceClient(timeout=5)
216
- status = client.get_model_status(model_name, token=HF_TOKEN)
217
- if status is None or status.framework != "diffusers" or status.state not in ["Loadable", "Loaded"]:
 
218
  print(f"Failed to load by API: {model_name}")
219
- return None
220
  else:
221
- loaded_models[model_name] = InferenceClient(model_name, token=HF_TOKEN, timeout=server_timeout)
 
222
  print(f"Loaded by API: {model_name}")
223
  except Exception as e:
224
- if model_name in loaded_models.keys(): del loaded_models[model_name]
225
  print(f"Failed to load by API: {model_name}")
226
  print(e)
227
- return None
228
  try:
229
- model_info_dict[model_name] = get_t2i_model_info_dict(model_name)
230
- print(f"Assigned by API: {model_name}")
 
231
  except Exception as e:
232
  if model_name in model_info_dict.keys(): del model_info_dict[model_name]
233
  print(f"Failed to assigned by API: {model_name}")
234
  print(e)
235
- return loaded_models[model_name]
236
 
237
 
238
  def load_models(models: list):
@@ -270,8 +272,8 @@ positive_all = list_uniq(positive_all)
270
  def recom_prompt(prompt: str = "", neg_prompt: str = "", pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = []):
271
  def flatten(src):
272
  return [item for row in src for item in row]
273
- prompts = to_list(prompt)
274
- neg_prompts = to_list(neg_prompt)
275
  prompts = list_sub(prompts, positive_all)
276
  neg_prompts = list_sub(neg_prompts, negative_all)
277
  last_empty_p = [""] if not prompts and type != "None" else []
@@ -287,7 +289,6 @@ def recom_prompt(prompt: str = "", neg_prompt: str = "", pos_pre: list = [], pos
287
 
288
  recom_prompt_type = {
289
  "None": ([], [], [], []),
290
- "Auto": ([], [], [], []),
291
  "Common": ([], ["Common"], [], ["Common"]),
292
  "Animagine": ([], ["Common", "Anime"], [], ["Common"]),
293
  "Pony": (["Pony"], ["Common"], ["Pony"], ["Common"]),
@@ -296,11 +297,7 @@ recom_prompt_type = {
296
  }
297
 
298
 
299
- enable_auto_recom_prompt = False
300
  def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
301
- global enable_auto_recom_prompt
302
- if type == "Auto": enable_auto_recom_prompt = True
303
- else: enable_auto_recom_prompt = False
304
  pos_pre, pos_suf, neg_pre, neg_suf = recom_prompt_type.get(type, ([], [], [], []))
305
  return recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
306
 
@@ -311,9 +308,7 @@ def set_recom_prompt_preset(type: str = "None"):
311
 
312
 
313
  def get_recom_prompt_type():
314
- type = list(recom_prompt_type.keys())
315
- type.remove("Auto")
316
- return type
317
 
318
 
319
  def get_positive_prefix():
@@ -356,11 +351,16 @@ def warm_model(model_name: str):
356
  if model:
357
  try:
358
  print(f"Warming model: {model_name}")
359
- infer_body(model, " ")
360
  except Exception as e:
361
  print(e)
362
 
363
 
 
 
 
 
 
364
  # https://huggingface.co/docs/api-inference/detailed_parameters
365
  # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
366
  def infer_body(client: InferenceClient | gr.Interface | object, model_str: str, prompt: str, neg_prompt: str = "",
@@ -375,21 +375,22 @@ def infer_body(client: InferenceClient | gr.Interface | object, model_str: str,
375
  else: kwargs["seed"] = seed
376
  try:
377
  if isinstance(client, InferenceClient):
378
- image = client.text_to_image(prompt=prompt, negative_prompt=neg_prompt, **kwargs, token=HF_TOKEN)
379
  elif isinstance(client, gr.Interface):
380
- image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs, token=HF_TOKEN)
 
381
  else: return None
382
  if isinstance(image, tuple): return None
383
  return save_image(image, png_path, model_str, prompt, neg_prompt, height, width, steps, cfg, seed)
384
  except Exception as e:
385
  print(e)
386
- raise Exception() from e
387
 
388
 
389
  async def infer(model_name: str, prompt: str, neg_prompt: str ="", height: int = 0, width: int = 0,
390
  steps: int = 0, cfg: int = 0, seed: int = -1,
391
  save_path: str | None = None, timeout: float = inference_timeout):
392
- model = load_model(model_name)
393
  if not model: return None
394
  task = asyncio.create_task(asyncio.to_thread(infer_body, model, model_name, prompt, neg_prompt,
395
  height, width, steps, cfg, seed))
@@ -406,7 +407,7 @@ async def infer(model_name: str, prompt: str, neg_prompt: str ="", height: int =
406
  print(e)
407
  if not task.done(): task.cancel()
408
  result = None
409
- raise Exception() from e
410
  if task.done() and result is not None:
411
  with lock:
412
  image = rename_image(result, model_name, save_path)
@@ -418,8 +419,7 @@ async def infer(model_name: str, prompt: str, neg_prompt: str ="", height: int =
418
  def infer_fn(model_name: str, prompt: str, neg_prompt: str = "", height: int = 0, width: int = 0,
419
  steps: int = 0, cfg: int = 0, seed: int = -1,
420
  pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], save_path: str | None = None):
421
- if model_name == 'NA':
422
- return None
423
  try:
424
  loop = asyncio.get_running_loop()
425
  except Exception:
@@ -442,8 +442,7 @@ def infer_rand_fn(model_name_dummy: str, prompt: str, neg_prompt: str = "", heig
442
  steps: int = 0, cfg: int = 0, seed: int = -1,
443
  pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], save_path: str | None = None):
444
  import random
445
- if model_name_dummy == 'NA':
446
- return None
447
  random.seed()
448
  model_name = random.choice(list(loaded_models.keys()))
449
  try:
 
8
 
9
  HF_TOKEN = os.environ.get("HF_TOKEN") if os.environ.get("HF_TOKEN") else None # If private or gated models aren't used, ENV setting is unnecessary.
10
  server_timeout = 600
11
+ inference_timeout = 600
12
 
13
 
14
  lock = RLock()
 
52
  return status is not None and status.state in ["Loadable", "Loaded"] and (not force_gpu or gpu_state)
53
 
54
 
55
+ def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="last_modified", limit: int=30, force_gpu=False, check_status=False, public=False):
56
  from huggingface_hub import HfApi
57
  api = HfApi(token=HF_TOKEN)
58
  default_tags = ["diffusers"]
 
61
  models = []
62
  try:
63
  model_infos = api.list_models(author=author, #task="text-to-image",
64
+ tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit)
65
  except Exception as e:
66
  print(f"Error: Failed to list models.")
67
  print(e)
68
  return models
69
  for model in model_infos:
70
+ if not model.private and not model.gated or (HF_TOKEN is not None and not public):
71
  loadable = is_loadable(model.id, force_gpu) if check_status else True
72
  if not_tag and not_tag in model.tags or not loadable or "not-for-all-audiences" in model.tags: continue
73
  models.append(model.id)
 
104
  info["likes"] = model.likes
105
  info["last_modified"] = model.last_modified.strftime("lastmod: %Y-%m-%d")
106
  un_tags = ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']
107
+ descs = [info["ver"]] + list_sub(info["tags"], un_tags) + [f'DLs: {info["downloads"]}'] + [f'💕: {info["likes"]}'] + [info["last_modified"]]
108
+ info["md"] = f'{", ".join(descs)} [Model Repo]({info["url"]})'
109
  return info
110
 
111
 
 
160
  p = response.json().get("pipeline_tag")
161
  if p != "text-to-image": raise ModelNotFoundError(f"This model isn't for text-to-image or unsupported: {model_name}.")
162
  headers["X-Wait-For-Model"] = "true"
163
+ kwargs = {}
164
+ if hf_token is not None: kwargs["token"] = hf_token
165
+ client = huggingface_hub.InferenceClient(model=model_name, headers=headers, timeout=server_timeout, **kwargs)
166
  inputs = gr.components.Textbox(label="Input")
167
  outputs = gr.components.Image(label="Output")
168
  fn = client.text_to_image
 
171
  try:
172
  data = fn(*data, **kwargs) # type: ignore
173
  except huggingface_hub.utils.HfHubHTTPError as e:
174
+ print(e)
175
+ if "429" in str(e): raise TooManyRequestsError() from e
176
  except Exception as e:
177
+ print(e)
178
  raise Exception() from e
179
  return data
180
 
 
212
  def load_model_api(model_name: str):
213
  global loaded_models
214
  global model_info_dict
 
215
  try:
216
+ loaded = False
217
+ client = InferenceClient(timeout=5, token=HF_TOKEN)
218
+ status = client.get_model_status(model_name)
219
+ if status is None or status.framework != "diffusers" or not status.loaded or status.state not in ["Loadable", "Loaded"]:
220
  print(f"Failed to load by API: {model_name}")
 
221
  else:
222
+ loaded_models[model_name] = InferenceClient(model_name, timeout=server_timeout)
223
+ loaded = True
224
  print(f"Loaded by API: {model_name}")
225
  except Exception as e:
 
226
  print(f"Failed to load by API: {model_name}")
227
  print(e)
228
+ loaded = False
229
  try:
230
+ if loaded:
231
+ model_info_dict[model_name] = get_t2i_model_info_dict(model_name)
232
+ print(f"Assigned by API: {model_name}")
233
  except Exception as e:
234
  if model_name in model_info_dict.keys(): del model_info_dict[model_name]
235
  print(f"Failed to assigned by API: {model_name}")
236
  print(e)
237
+ return loaded_models[model_name] if model_name in loaded_models.keys() else None
238
 
239
 
240
  def load_models(models: list):
 
272
  def recom_prompt(prompt: str = "", neg_prompt: str = "", pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = []):
273
  def flatten(src):
274
  return [item for row in src for item in row]
275
+ prompts = to_list(prompt) if prompt else []
276
+ neg_prompts = to_list(neg_prompt) if neg_prompt else []
277
  prompts = list_sub(prompts, positive_all)
278
  neg_prompts = list_sub(neg_prompts, negative_all)
279
  last_empty_p = [""] if not prompts and type != "None" else []
 
289
 
290
  recom_prompt_type = {
291
  "None": ([], [], [], []),
 
292
  "Common": ([], ["Common"], [], ["Common"]),
293
  "Animagine": ([], ["Common", "Anime"], [], ["Common"]),
294
  "Pony": (["Pony"], ["Common"], ["Pony"], ["Common"]),
 
297
  }
298
 
299
 
 
300
  def insert_recom_prompt(prompt: str = "", neg_prompt: str = "", type: str = "None"):
 
 
 
301
  pos_pre, pos_suf, neg_pre, neg_suf = recom_prompt_type.get(type, ([], [], [], []))
302
  return recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
303
 
 
308
 
309
 
310
  def get_recom_prompt_type():
311
+ return list(recom_prompt_type.keys())
 
 
312
 
313
 
314
  def get_positive_prefix():
 
351
  if model:
352
  try:
353
  print(f"Warming model: {model_name}")
354
+ infer_body(model, model_name, " ")
355
  except Exception as e:
356
  print(e)
357
 
358
 
359
+ def warm_models(models: list[str]):
360
+ for model in models:
361
+ asyncio.new_event_loop().run_in_executor(None, warm_model, model)
362
+
363
+
364
  # https://huggingface.co/docs/api-inference/detailed_parameters
365
  # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
366
  def infer_body(client: InferenceClient | gr.Interface | object, model_str: str, prompt: str, neg_prompt: str = "",
 
375
  else: kwargs["seed"] = seed
376
  try:
377
  if isinstance(client, InferenceClient):
378
+ image = client.text_to_image(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
379
  elif isinstance(client, gr.Interface):
380
+ if HF_TOKEN is not None: kwargs["token"] = HF_TOKEN
381
+ image = client.fn(prompt=prompt, negative_prompt=neg_prompt, **kwargs)
382
  else: return None
383
  if isinstance(image, tuple): return None
384
  return save_image(image, png_path, model_str, prompt, neg_prompt, height, width, steps, cfg, seed)
385
  except Exception as e:
386
  print(e)
387
+ raise Exception(e) from e
388
 
389
 
390
  async def infer(model_name: str, prompt: str, neg_prompt: str ="", height: int = 0, width: int = 0,
391
  steps: int = 0, cfg: int = 0, seed: int = -1,
392
  save_path: str | None = None, timeout: float = inference_timeout):
393
+ model = load_model_api(model_name)
394
  if not model: return None
395
  task = asyncio.create_task(asyncio.to_thread(infer_body, model, model_name, prompt, neg_prompt,
396
  height, width, steps, cfg, seed))
 
407
  print(e)
408
  if not task.done(): task.cancel()
409
  result = None
410
+ raise Exception(e) from e
411
  if task.done() and result is not None:
412
  with lock:
413
  image = rename_image(result, model_name, save_path)
 
419
  def infer_fn(model_name: str, prompt: str, neg_prompt: str = "", height: int = 0, width: int = 0,
420
  steps: int = 0, cfg: int = 0, seed: int = -1,
421
  pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], save_path: str | None = None):
422
+ if model_name in ["NA", ""]: return gr.update()
 
423
  try:
424
  loop = asyncio.get_running_loop()
425
  except Exception:
 
442
  steps: int = 0, cfg: int = 0, seed: int = -1,
443
  pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], save_path: str | None = None):
444
  import random
445
+ if model_name_dummy in ["NA", ""]: return gr.update()
 
446
  random.seed()
447
  model_name = random.choice(list(loaded_models.keys()))
448
  try: