Spaces:
Runtime error
Runtime error
Commit
·
cd41390
1
Parent(s):
5cc903d
async version with more caching
Browse files- app.py +21 -9
- requirements.txt +2 -0
app.py
CHANGED
@@ -3,23 +3,27 @@ import copy
|
|
3 |
import os
|
4 |
from dataclasses import asdict, dataclass
|
5 |
from datetime import datetime, timedelta
|
|
|
6 |
from json import JSONDecodeError
|
7 |
from typing import Any, Dict, List, Optional, Union
|
|
|
8 |
import gradio as gr
|
9 |
import httpx
|
10 |
import orjson
|
|
|
11 |
from cashews import NOT_NONE, cache
|
12 |
-
from httpx import AsyncClient
|
13 |
from huggingface_hub import hf_hub_url, logging
|
14 |
from huggingface_hub.utils import disable_progress_bars
|
15 |
from rich import print
|
16 |
from tqdm.auto import tqdm
|
17 |
-
from httpx import Client
|
18 |
-
from datetime import datetime, timedelta
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
)
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
disable_progress_bars()
|
@@ -58,6 +62,7 @@ async def _try_load_model_card(hub_id, client=None):
|
|
58 |
length = None
|
59 |
return card_text, length
|
60 |
|
|
|
61 |
def _try_parse_card_data(hub_json_data):
|
62 |
data = {}
|
63 |
keys = ["license", "language", "datasets"]
|
@@ -72,7 +77,7 @@ def _try_parse_card_data(hub_json_data):
|
|
72 |
return data
|
73 |
|
74 |
|
75 |
-
@dataclass
|
76 |
class ModelMetadata:
|
77 |
hub_id: str
|
78 |
tags: Optional[List[str]]
|
@@ -89,7 +94,7 @@ class ModelMetadata:
|
|
89 |
created_at: Optional[datetime] = None
|
90 |
|
91 |
@classmethod
|
92 |
-
@cache(ttl=
|
93 |
async def from_hub(cls, hub_id, client=None):
|
94 |
try:
|
95 |
if not client:
|
@@ -224,6 +229,7 @@ ALL_PIPELINES = {
|
|
224 |
}
|
225 |
|
226 |
|
|
|
227 |
def generate_task_scores_dict():
|
228 |
task_scores = {}
|
229 |
for task in ALL_PIPELINES:
|
@@ -262,6 +268,7 @@ def generate_task_scores_dict():
|
|
262 |
return task_scores
|
263 |
|
264 |
|
|
|
265 |
def generate_common_scores():
|
266 |
GENERIC_SCORES = copy.deepcopy(COMMON_SCORES)
|
267 |
GENERIC_SCORES["_max_score"] = sum(
|
@@ -274,6 +281,7 @@ SCORES = generate_task_scores_dict()
|
|
274 |
GENERIC_SCORES = generate_common_scores()
|
275 |
|
276 |
|
|
|
277 |
def _basic_check(data: Optional[ModelMetadata]):
|
278 |
score = 0
|
279 |
if data is None:
|
@@ -334,7 +342,7 @@ def create_query_url(query, skip=0):
|
|
334 |
return f"https://huggingface.co/api/search/full-text?q={query}&limit=100&skip={skip}&type=model"
|
335 |
|
336 |
|
337 |
-
def get_results(query,sync_client=None) -> Dict[Any, Any]:
|
338 |
if not sync_client:
|
339 |
sync_client = Client(http2=True, headers=headers)
|
340 |
url = create_query_url(query)
|
@@ -461,6 +469,7 @@ def create_markdown(results): # TODO move to separate file
|
|
461 |
rows.append(row)
|
462 |
return "\n".join(rows)
|
463 |
|
|
|
464 |
async def get_result_card_snippet(result, query=None, client=None):
|
465 |
if not client:
|
466 |
client = AsyncClient(http2=True, headers=headers)
|
@@ -472,6 +481,7 @@ async def get_result_card_snippet(result, query=None, client=None):
|
|
472 |
result["text"] = "Could not load model card"
|
473 |
return result
|
474 |
|
|
|
475 |
@cache(ttl=timedelta(hours=3), condition=NOT_NONE)
|
476 |
async def get_result_card_snippets(results, query=None, client=None):
|
477 |
if not client:
|
@@ -483,8 +493,10 @@ async def get_result_card_snippets(results, query=None, client=None):
|
|
483 |
results = await asyncio.gather(*result_snippets)
|
484 |
return results
|
485 |
|
|
|
486 |
sync_client = Client(http2=True, headers=headers)
|
487 |
|
|
|
488 |
def _search_hub(
|
489 |
query: str,
|
490 |
min_score: Optional[int] = None,
|
|
|
3 |
import os
|
4 |
from dataclasses import asdict, dataclass
|
5 |
from datetime import datetime, timedelta
|
6 |
+
from functools import lru_cache
|
7 |
from json import JSONDecodeError
|
8 |
from typing import Any, Dict, List, Optional, Union
|
9 |
+
|
10 |
import gradio as gr
|
11 |
import httpx
|
12 |
import orjson
|
13 |
+
from cachetools import TTLCache, cached
|
14 |
from cashews import NOT_NONE, cache
|
15 |
+
from httpx import AsyncClient, Client
|
16 |
from huggingface_hub import hf_hub_url, logging
|
17 |
from huggingface_hub.utils import disable_progress_bars
|
18 |
from rich import print
|
19 |
from tqdm.auto import tqdm
|
|
|
|
|
20 |
|
21 |
+
CACHE_EXPIRY_TIME = timedelta(hours=3)
|
22 |
+
|
23 |
+
sync_cache = TTLCache(maxsize=200_000, ttl=CACHE_EXPIRY_TIME, timer=datetime.now)
|
24 |
+
|
25 |
+
|
26 |
+
cache.setup("mem://")
|
27 |
|
28 |
|
29 |
disable_progress_bars()
|
|
|
62 |
length = None
|
63 |
return card_text, length
|
64 |
|
65 |
+
|
66 |
def _try_parse_card_data(hub_json_data):
|
67 |
data = {}
|
68 |
keys = ["license", "language", "datasets"]
|
|
|
77 |
return data
|
78 |
|
79 |
|
80 |
+
@dataclass(eq=False)
|
81 |
class ModelMetadata:
|
82 |
hub_id: str
|
83 |
tags: Optional[List[str]]
|
|
|
94 |
created_at: Optional[datetime] = None
|
95 |
|
96 |
@classmethod
|
97 |
+
@cache(ttl=CACHE_EXPIRY_TIME, condition=NOT_NONE)
|
98 |
async def from_hub(cls, hub_id, client=None):
|
99 |
try:
|
100 |
if not client:
|
|
|
229 |
}
|
230 |
|
231 |
|
232 |
+
@lru_cache()
|
233 |
def generate_task_scores_dict():
|
234 |
task_scores = {}
|
235 |
for task in ALL_PIPELINES:
|
|
|
268 |
return task_scores
|
269 |
|
270 |
|
271 |
+
@lru_cache()
|
272 |
def generate_common_scores():
|
273 |
GENERIC_SCORES = copy.deepcopy(COMMON_SCORES)
|
274 |
GENERIC_SCORES["_max_score"] = sum(
|
|
|
281 |
GENERIC_SCORES = generate_common_scores()
|
282 |
|
283 |
|
284 |
+
@cached(sync_cache)
|
285 |
def _basic_check(data: Optional[ModelMetadata]):
|
286 |
score = 0
|
287 |
if data is None:
|
|
|
342 |
return f"https://huggingface.co/api/search/full-text?q={query}&limit=100&skip={skip}&type=model"
|
343 |
|
344 |
|
345 |
+
def get_results(query, sync_client=None) -> Dict[Any, Any]:
|
346 |
if not sync_client:
|
347 |
sync_client = Client(http2=True, headers=headers)
|
348 |
url = create_query_url(query)
|
|
|
469 |
rows.append(row)
|
470 |
return "\n".join(rows)
|
471 |
|
472 |
+
|
473 |
async def get_result_card_snippet(result, query=None, client=None):
|
474 |
if not client:
|
475 |
client = AsyncClient(http2=True, headers=headers)
|
|
|
481 |
result["text"] = "Could not load model card"
|
482 |
return result
|
483 |
|
484 |
+
|
485 |
@cache(ttl=timedelta(hours=3), condition=NOT_NONE)
|
486 |
async def get_result_card_snippets(results, query=None, client=None):
|
487 |
if not client:
|
|
|
493 |
results = await asyncio.gather(*result_snippets)
|
494 |
return results
|
495 |
|
496 |
+
|
497 |
sync_client = Client(http2=True, headers=headers)
|
498 |
|
499 |
+
|
500 |
def _search_hub(
|
501 |
query: str,
|
502 |
min_score: Optional[int] = None,
|
requirements.txt
CHANGED
@@ -30,6 +30,8 @@ attrs==23.1.0
|
|
30 |
# jsonschema
|
31 |
backcall==0.2.0
|
32 |
# via ipython
|
|
|
|
|
33 |
cashews==6.2.0
|
34 |
# via -r requirements.in
|
35 |
certifi==2023.5.7
|
|
|
30 |
# jsonschema
|
31 |
backcall==0.2.0
|
32 |
# via ipython
|
33 |
+
cachetools==5.3.1
|
34 |
+
# via -r requirements.in
|
35 |
cashews==6.2.0
|
36 |
# via -r requirements.in
|
37 |
certifi==2023.5.7
|