davanstrien HF staff commited on
Commit
cd41390
·
1 Parent(s): 5cc903d

async version with more caching

Browse files
Files changed (2) hide show
  1. app.py +21 -9
  2. requirements.txt +2 -0
app.py CHANGED
@@ -3,23 +3,27 @@ import copy
3
  import os
4
  from dataclasses import asdict, dataclass
5
  from datetime import datetime, timedelta
 
6
  from json import JSONDecodeError
7
  from typing import Any, Dict, List, Optional, Union
 
8
  import gradio as gr
9
  import httpx
10
  import orjson
 
11
  from cashews import NOT_NONE, cache
12
- from httpx import AsyncClient
13
  from huggingface_hub import hf_hub_url, logging
14
  from huggingface_hub.utils import disable_progress_bars
15
  from rich import print
16
  from tqdm.auto import tqdm
17
- from httpx import Client
18
- from datetime import datetime, timedelta
19
 
20
- cache.setup(
21
- "mem://"
22
- )
 
 
 
23
 
24
 
25
  disable_progress_bars()
@@ -58,6 +62,7 @@ async def _try_load_model_card(hub_id, client=None):
58
  length = None
59
  return card_text, length
60
 
 
61
  def _try_parse_card_data(hub_json_data):
62
  data = {}
63
  keys = ["license", "language", "datasets"]
@@ -72,7 +77,7 @@ def _try_parse_card_data(hub_json_data):
72
  return data
73
 
74
 
75
- @dataclass
76
  class ModelMetadata:
77
  hub_id: str
78
  tags: Optional[List[str]]
@@ -89,7 +94,7 @@ class ModelMetadata:
89
  created_at: Optional[datetime] = None
90
 
91
  @classmethod
92
- @cache(ttl=timedelta(hours=3), condition=NOT_NONE)
93
  async def from_hub(cls, hub_id, client=None):
94
  try:
95
  if not client:
@@ -224,6 +229,7 @@ ALL_PIPELINES = {
224
  }
225
 
226
 
 
227
  def generate_task_scores_dict():
228
  task_scores = {}
229
  for task in ALL_PIPELINES:
@@ -262,6 +268,7 @@ def generate_task_scores_dict():
262
  return task_scores
263
 
264
 
 
265
  def generate_common_scores():
266
  GENERIC_SCORES = copy.deepcopy(COMMON_SCORES)
267
  GENERIC_SCORES["_max_score"] = sum(
@@ -274,6 +281,7 @@ SCORES = generate_task_scores_dict()
274
  GENERIC_SCORES = generate_common_scores()
275
 
276
 
 
277
  def _basic_check(data: Optional[ModelMetadata]):
278
  score = 0
279
  if data is None:
@@ -334,7 +342,7 @@ def create_query_url(query, skip=0):
334
  return f"https://huggingface.co/api/search/full-text?q={query}&limit=100&skip={skip}&type=model"
335
 
336
 
337
- def get_results(query,sync_client=None) -> Dict[Any, Any]:
338
  if not sync_client:
339
  sync_client = Client(http2=True, headers=headers)
340
  url = create_query_url(query)
@@ -461,6 +469,7 @@ def create_markdown(results): # TODO move to separate file
461
  rows.append(row)
462
  return "\n".join(rows)
463
 
 
464
  async def get_result_card_snippet(result, query=None, client=None):
465
  if not client:
466
  client = AsyncClient(http2=True, headers=headers)
@@ -472,6 +481,7 @@ async def get_result_card_snippet(result, query=None, client=None):
472
  result["text"] = "Could not load model card"
473
  return result
474
 
 
475
  @cache(ttl=timedelta(hours=3), condition=NOT_NONE)
476
  async def get_result_card_snippets(results, query=None, client=None):
477
  if not client:
@@ -483,8 +493,10 @@ async def get_result_card_snippets(results, query=None, client=None):
483
  results = await asyncio.gather(*result_snippets)
484
  return results
485
 
 
486
  sync_client = Client(http2=True, headers=headers)
487
 
 
488
  def _search_hub(
489
  query: str,
490
  min_score: Optional[int] = None,
 
3
  import os
4
  from dataclasses import asdict, dataclass
5
  from datetime import datetime, timedelta
6
+ from functools import lru_cache
7
  from json import JSONDecodeError
8
  from typing import Any, Dict, List, Optional, Union
9
+
10
  import gradio as gr
11
  import httpx
12
  import orjson
13
+ from cachetools import TTLCache, cached
14
  from cashews import NOT_NONE, cache
15
+ from httpx import AsyncClient, Client
16
  from huggingface_hub import hf_hub_url, logging
17
  from huggingface_hub.utils import disable_progress_bars
18
  from rich import print
19
  from tqdm.auto import tqdm
 
 
20
 
21
+ CACHE_EXPIRY_TIME = timedelta(hours=3)
22
+
23
+ sync_cache = TTLCache(maxsize=200_000, ttl=CACHE_EXPIRY_TIME, timer=datetime.now)
24
+
25
+
26
+ cache.setup("mem://")
27
 
28
 
29
  disable_progress_bars()
 
62
  length = None
63
  return card_text, length
64
 
65
+
66
  def _try_parse_card_data(hub_json_data):
67
  data = {}
68
  keys = ["license", "language", "datasets"]
 
77
  return data
78
 
79
 
80
+ @dataclass(eq=False)
81
  class ModelMetadata:
82
  hub_id: str
83
  tags: Optional[List[str]]
 
94
  created_at: Optional[datetime] = None
95
 
96
  @classmethod
97
+ @cache(ttl=CACHE_EXPIRY_TIME, condition=NOT_NONE)
98
  async def from_hub(cls, hub_id, client=None):
99
  try:
100
  if not client:
 
229
  }
230
 
231
 
232
+ @lru_cache()
233
  def generate_task_scores_dict():
234
  task_scores = {}
235
  for task in ALL_PIPELINES:
 
268
  return task_scores
269
 
270
 
271
+ @lru_cache()
272
  def generate_common_scores():
273
  GENERIC_SCORES = copy.deepcopy(COMMON_SCORES)
274
  GENERIC_SCORES["_max_score"] = sum(
 
281
  GENERIC_SCORES = generate_common_scores()
282
 
283
 
284
+ @cached(sync_cache)
285
  def _basic_check(data: Optional[ModelMetadata]):
286
  score = 0
287
  if data is None:
 
342
  return f"https://huggingface.co/api/search/full-text?q={query}&limit=100&skip={skip}&type=model"
343
 
344
 
345
+ def get_results(query, sync_client=None) -> Dict[Any, Any]:
346
  if not sync_client:
347
  sync_client = Client(http2=True, headers=headers)
348
  url = create_query_url(query)
 
469
  rows.append(row)
470
  return "\n".join(rows)
471
 
472
+
473
  async def get_result_card_snippet(result, query=None, client=None):
474
  if not client:
475
  client = AsyncClient(http2=True, headers=headers)
 
481
  result["text"] = "Could not load model card"
482
  return result
483
 
484
+
485
  @cache(ttl=timedelta(hours=3), condition=NOT_NONE)
486
  async def get_result_card_snippets(results, query=None, client=None):
487
  if not client:
 
493
  results = await asyncio.gather(*result_snippets)
494
  return results
495
 
496
+
497
  sync_client = Client(http2=True, headers=headers)
498
 
499
+
500
  def _search_hub(
501
  query: str,
502
  min_score: Optional[int] = None,
requirements.txt CHANGED
@@ -30,6 +30,8 @@ attrs==23.1.0
30
  # jsonschema
31
  backcall==0.2.0
32
  # via ipython
 
 
33
  cashews==6.2.0
34
  # via -r requirements.in
35
  certifi==2023.5.7
 
30
  # jsonschema
31
  backcall==0.2.0
32
  # via ipython
33
+ cachetools==5.3.1
34
+ # via -r requirements.in
35
  cashews==6.2.0
36
  # via -r requirements.in
37
  certifi==2023.5.7