Deploy full application code
- .gitignore +39 -0
- Dockerfile +8 -5
- app.py +3 -13
- app/api/__init__.py +7 -0
- app/api/datasets.py +151 -0
- app/core/celery_app.py +98 -0
- app/core/config.py +48 -0
- app/main.py +46 -0
- app/schemas/dataset.py +81 -0
- app/schemas/dataset_common.py +17 -0
- app/services/hf_datasets.py +501 -0
- app/services/redis_client.py +302 -0
- app/tasks/dataset_tasks.py +73 -0
- migrations/20250620000000_create_combined_datasets_table.sql +57 -0
- setup.py +8 -0
- tests/test_datasets.py +78 -0
- tests/test_datasets_api.py +88 -0
.gitignore
ADDED
@@ -0,0 +1,39 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Environment
+.env
+.venv
+env/
+venv/
+ENV/
+
+# Logs
+*.log
+logs/
+celery_worker_*.log
+nohup.out
+
+# Database
+*.sqlite
+*.db
+celerybeat-schedule
Dockerfile
CHANGED
@@ -1,14 +1,17 @@
 # Use the official Python 3.10.9 image
 FROM python:3.10.9
 
-#
-
+# Set the working directory
+WORKDIR /app
 
-#
-
+# Copy the current directory contents into the container
+COPY . .
 
-# Install requirements.txt
+# Install requirements.txt
 RUN pip install --no-cache-dir --upgrade -r requirements.txt
 
+# Install the application in development mode
+RUN pip install -e .
+
 # Start the FastAPI app on port 7860, the default port expected by Spaces
 CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
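The new RUN pip install -e . step only resolves if the repository ships a setup.py (listed above as +8 lines but not shown in this excerpt). A minimal sketch of the kind of file that would make the editable install expose the app package; the real contents of the committed setup.py may differ:

# Hypothetical minimal setup.py; the actual file in this commit is not shown here.
from setuptools import setup, find_packages

setup(
    name="collinear-data-tool",  # assumed project name
    version="0.1.0",
    packages=find_packages(include=["app", "app.*"]),
)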
app.py
CHANGED
@@ -1,17 +1,7 @@
-from
-import uvicorn
+from app.main import app
 
-#
-app = FastAPI(title="Collinear API")
-
-@app.get("/")
-async def root():
-    return {"message": "Welcome to Collinear API"}
-
-@app.get("/health")
-async def health():
-    return {"status": "healthy"}
+# This file is needed for Hugging Face Spaces to find the app
 
 if __name__ == "__main__":
-
+    import uvicorn
     uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
app/api/__init__.py
ADDED
@@ -0,0 +1,7 @@
+from fastapi import APIRouter
+from app.api.datasets import router as datasets_router
+# from . import batch  # Removed batch import
+
+api_router = APIRouter()
+api_router.include_router(datasets_router, tags=["datasets"])
+# api_router.include_router(batch.router)  # Removed batch router
app/api/datasets.py
ADDED
@@ -0,0 +1,151 @@
+from fastapi import APIRouter, Query, HTTPException
+from typing import List, Optional, Dict, Any, Set
+from pydantic import BaseModel
+from fastapi.concurrency import run_in_threadpool
+from app.services.hf_datasets import (
+    get_dataset_commits,
+    get_dataset_files,
+    get_file_url,
+    get_datasets_page_from_zset,
+    get_dataset_commits_async,
+    get_dataset_files_async,
+    get_file_url_async,
+    get_datasets_page_from_cache,
+    fetch_and_cache_all_datasets,
+)
+from app.services.redis_client import cache_get
+import logging
+import time
+from fastapi.responses import JSONResponse
+import os
+
+router = APIRouter(prefix="/datasets", tags=["datasets"])
+log = logging.getLogger(__name__)
+
+SIZE_LOW = 100 * 1024 * 1024
+SIZE_MEDIUM = 1024 * 1024 * 1024
+
+class DatasetInfo(BaseModel):
+    id: str
+    name: Optional[str]
+    description: Optional[str]
+    size_bytes: Optional[int]
+    impact_level: Optional[str]
+    downloads: Optional[int]
+    likes: Optional[int]
+    tags: Optional[List[str]]
+    class Config:
+        extra = "ignore"
+
+class PaginatedDatasets(BaseModel):
+    total: int
+    items: List[DatasetInfo]
+
+class CommitInfo(BaseModel):
+    id: str
+    title: Optional[str]
+    message: Optional[str]
+    author: Optional[Dict[str, Any]]
+    date: Optional[str]
+
+class CacheStatus(BaseModel):
+    last_update: Optional[str]
+    total_items: int
+    warming_up: bool
+
+def deduplicate_by_id(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    seen: Set[str] = set()
+    unique_items = []
+    for item in items:
+        item_id = item.get("id")
+        if item_id and item_id not in seen:
+            seen.add(item_id)
+            unique_items.append(item)
+    return unique_items
+
+@router.get("/cache-status", response_model=CacheStatus)
+async def cache_status():
+    meta = await cache_get("hf:datasets:meta")
+    last_update = meta["last_update"] if meta and "last_update" in meta else None
+    total_items = meta["total_items"] if meta and "total_items" in meta else 0
+    warming_up = not bool(total_items)
+    return CacheStatus(last_update=last_update, total_items=total_items, warming_up=warming_up)
+
+@router.get("/", response_model=None)
+async def list_datasets(
+    limit: int = Query(10, ge=1, le=1000),
+    offset: int = Query(0, ge=0),
+    search: str = Query(None, description="Search term for dataset id or description"),
+    sort_by: str = Query(None, description="Field to sort by (e.g., 'downloads', 'likes', 'created_at')"),
+    sort_order: str = Query("desc", regex="^(asc|desc)$", description="Sort order: 'asc' or 'desc'"),
+):
+    # Fetch the full list from cache
+    result, status = get_datasets_page_from_cache(1000000, 0)  # get all for in-memory filtering
+    if status != 200:
+        return JSONResponse(result, status_code=status)
+    items = result["items"]
+    # Filtering
+    if search:
+        items = [d for d in items if search.lower() in (d.get("id", "") + " " + str(d.get("description", "")).lower())]
+    # Sorting
+    if sort_by:
+        items = sorted(items, key=lambda d: d.get(sort_by) or 0, reverse=(sort_order == "desc"))
+    # Pagination
+    total = len(items)
+    page = items[offset:offset+limit]
+    total_pages = (total + limit - 1) // limit
+    current_page = (offset // limit) + 1
+    next_page = current_page + 1 if offset + limit < total else None
+    prev_page = current_page - 1 if current_page > 1 else None
+    return {
+        "total": total,
+        "current_page": current_page,
+        "total_pages": total_pages,
+        "next_page": next_page,
+        "prev_page": prev_page,
+        "items": page
+    }
+
+@router.get("/{dataset_id:path}/commits", response_model=List[CommitInfo])
+async def get_commits(dataset_id: str):
+    """
+    Get commit history for a dataset.
+    """
+    try:
+        return await get_dataset_commits_async(dataset_id)
+    except Exception as e:
+        log.error(f"Error fetching commits for {dataset_id}: {e}")
+        raise HTTPException(status_code=404, detail=f"Could not fetch commits: {e}")
+
+@router.get("/{dataset_id:path}/files", response_model=List[str])
+async def list_files(dataset_id: str):
+    """
+    List files in a dataset.
+    """
+    try:
+        return await get_dataset_files_async(dataset_id)
+    except Exception as e:
+        log.error(f"Error listing files for {dataset_id}: {e}")
+        raise HTTPException(status_code=404, detail=f"Could not list files: {e}")
+
+@router.get("/{dataset_id:path}/file-url")
+async def get_file_url_endpoint(dataset_id: str, filename: str = Query(...), revision: Optional[str] = None):
+    """
+    Get download URL for a file in a dataset.
+    """
+    url = await get_file_url_async(dataset_id, filename, revision)
+    return {"download_url": url}
+
+@router.get("/meta")
+async def get_datasets_meta():
+    meta = await cache_get("hf:datasets:meta")
+    return meta if meta else {}
+
+# Endpoint to trigger cache refresh manually (for admin/testing)
+@router.post("/datasets/refresh-cache")
+def refresh_cache():
+    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+    if not token:
+        return JSONResponse({"error": "HUGGINGFACEHUB_API_TOKEN not set"}, status_code=500)
+    count = fetch_and_cache_all_datasets(token)
+    return {"status": "ok", "cached": count}
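For reference, a sketch (not part of the commit) of how these routes can be exercised once the app from app/main.py is running; it assumes the server listens on port 7860, that the router is mounted under the /api prefix, and that the compressed Redis cache has already been populated (otherwise the list endpoint returns the cache-miss error):

# Sketch: calling the datasets API with httpx; "squad" is just an example dataset id.
import httpx

BASE = "http://localhost:7860/api/datasets"

page = httpx.get(f"{BASE}/", params={"limit": 5, "sort_by": "downloads", "sort_order": "desc"}).json()
print(page["total"], [item["id"] for item in page["items"]])

commits = httpx.get(f"{BASE}/squad/commits").json()  # commit history for one dataset
files = httpx.get(f"{BASE}/squad/files").json()      # file listing for one dataset
print(len(commits), files[:3])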
app/core/celery_app.py
ADDED
@@ -0,0 +1,98 @@
+"""Celery configuration for task processing."""
+
+import logging
+from celery import Celery
+from celery.signals import task_failure, task_success, worker_ready, worker_shutdown
+
+from app.core.config import settings
+
+# Configure logging
+logger = logging.getLogger(__name__)
+
+# Celery configuration
+celery_app = Celery(
+    "dataset_impacts",
+    broker=settings.REDIS_URL,
+    backend=settings.REDIS_URL,
+)
+
+# Configure Celery settings
+celery_app.conf.update(
+    task_serializer="json",
+    accept_content=["json"],
+    result_serializer="json",
+    timezone="UTC",
+    enable_utc=True,
+    worker_concurrency=settings.WORKER_CONCURRENCY,
+    task_acks_late=True,  # Tasks are acknowledged after execution
+    task_reject_on_worker_lost=True,  # Tasks are rejected if worker is terminated during execution
+    task_time_limit=3600,  # 1 hour timeout per task
+    task_soft_time_limit=3000,  # Soft timeout (30 minutes) - allows for graceful shutdown
+    worker_prefetch_multiplier=1,  # Single prefetch - improves fair distribution of tasks
+    broker_connection_retry=True,
+    broker_connection_retry_on_startup=True,
+    broker_connection_max_retries=10,
+    broker_pool_limit=10,  # Connection pool size
+    result_expires=60 * 60 * 24,  # Results expire after 24 hours
+    task_track_started=True,  # Track when tasks are started
+)
+
+# Set up task routes for different task types
+celery_app.conf.task_routes = {
+    "app.tasks.dataset_tasks.*": {"queue": "dataset_impacts"},
+    "app.tasks.maintenance.*": {"queue": "maintenance"},
+}
+
+# Configure retry settings
+celery_app.conf.task_default_retry_delay = 30  # 30 seconds
+celery_app.conf.task_max_retries = 3
+
+# Setup beat schedule for periodic tasks if enabled
+celery_app.conf.beat_schedule = {
+    "cleanup-stale-tasks": {
+        "task": "app.tasks.maintenance.cleanup_stale_tasks",
+        "schedule": 3600.0,  # Run every hour
+    },
+    "health-check": {
+        "task": "app.tasks.maintenance.health_check",
+        "schedule": 300.0,  # Run every 5 minutes
+    },
+    "refresh-hf-datasets-cache": {
+        "task": "app.tasks.dataset_tasks.refresh_hf_datasets_cache",
+        "schedule": 3600.0,  # Run every hour
+    },
+}
+
+# Signal handlers for monitoring and logging
+@task_failure.connect
+def task_failure_handler(sender=None, task_id=None, exception=None, **kwargs):
+    """Log failed tasks."""
+    logger.error(f"Task {task_id} {sender.name} failed: {exception}")
+
+@task_success.connect
+def task_success_handler(sender=None, result=None, **kwargs):
+    """Log successful tasks."""
+    task_name = sender.name if sender else "Unknown"
+    logger.info(f"Task {task_name} completed successfully")
+
+@worker_ready.connect
+def worker_ready_handler(**kwargs):
+    """Log when worker is ready."""
+    logger.info(f"Celery worker ready: {kwargs.get('hostname')}")
+
+@worker_shutdown.connect
+def worker_shutdown_handler(**kwargs):
+    """Log when worker is shutting down."""
+    logger.info(f"Celery worker shutting down: {kwargs.get('hostname')}")
+
+def get_celery_app():
+    """Get the Celery app instance."""
+    # Import all tasks to ensure they're registered
+    try:
+        # Using the improved app.tasks module which properly imports all tasks
+        import app.tasks
+        logger.info(f"Tasks successfully imported; registered {len(celery_app.tasks)} tasks")
+    except ImportError as e:
+        logger.error(f"Error importing tasks: {e}")
+
+    return celery_app
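The task routes and beat schedule above reference app.tasks.maintenance, which is not part of this diff. A sketch, under that assumption, of what a matching task module and a local worker launch could look like; the task bodies are placeholders:

# Hypothetical app/tasks/maintenance.py matching the names used in the beat schedule.
from app.core.celery_app import celery_app

@celery_app.task(name="app.tasks.maintenance.health_check")
def health_check() -> str:
    # Trivial liveness probe; the real task is not included in this commit.
    return "ok"

@celery_app.task(name="app.tasks.maintenance.cleanup_stale_tasks")
def cleanup_stale_tasks() -> int:
    # Placeholder for whatever cleanup the deployment actually needs.
    return 0

if __name__ == "__main__":
    # Start a worker bound to the queues declared in task_routes.
    celery_app.worker_main(argv=["worker", "--loglevel=INFO", "-Q", "dataset_impacts,maintenance"])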
app/core/config.py
ADDED
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+from typing import Final, Optional
+
+from pydantic import SecretStr, HttpUrl, Field
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+class Settings(BaseSettings):
+    """
+    Core application settings.
+    Reads environment variables and .env file.
+    """
+    # Supabase Settings
+    SUPABASE_URL: HttpUrl
+    SUPABASE_SERVICE_KEY: SecretStr
+    SUPABASE_ANON_KEY: SecretStr
+    SUPABASE_JWT_SECRET: Optional[SecretStr] = None  # Optional for local dev
+
+    # Hugging Face API Token
+    HF_API_TOKEN: Optional[SecretStr] = None
+
+    # Redis settings
+    REDIS_URL: str = "redis://localhost:6379/0"
+    REDIS_PASSWORD: Optional[SecretStr] = None
+
+    # Toggle Redis cache layer
+    ENABLE_REDIS_CACHE: bool = True
+
+    # ──────────────────────────────── Security ────────────────────────────────
+    # JWT secret key. NEVER hard-code in source; override with env variable in production.
+    SECRET_KEY: SecretStr = Field("change-me", env="SECRET_KEY")
+    ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(60 * 24 * 7, env="ACCESS_TOKEN_EXPIRE_MINUTES")  # 1 week by default
+
+    # Worker settings
+    WORKER_CONCURRENCY: int = 10  # Increased from 5 for better parallel performance
+
+    # Batch processing chunk size for Celery dataset tasks
+    DATASET_BATCH_CHUNK_SIZE: int = 50
+
+    # Tell pydantic-settings to auto-load `.env` if present
+    model_config: Final = SettingsConfigDict(
+        env_file=".env",
+        case_sensitive=False,
+        extra="ignore"
+    )
+
+# Single, shared instance of settings
+settings = Settings()
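A small usage sketch (not in the commit): secrets stay wrapped in SecretStr until explicitly unwrapped, and constructing Settings() requires the mandatory Supabase variables to be present in the environment or .env:

# Sketch: consuming the shared settings instance.
from app.core.config import settings

print(settings.REDIS_URL)                      # plain string with a default
print(settings.SUPABASE_URL)                   # validated HttpUrl (required)
print(settings.SECRET_KEY)                     # prints masked: SecretStr('**********')
print(settings.SECRET_KEY.get_secret_value())  # unwrap only where the raw value is needed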
app/main.py
ADDED
@@ -0,0 +1,46 @@
+import logging
+import json
+from fastapi import FastAPI
+from app.api import api_router
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.middleware.gzip import GZipMiddleware
+
+class JsonFormatter(logging.Formatter):
+    def format(self, record):
+        log_record = {
+            "level": record.levelname,
+            "time": self.formatTime(record, self.datefmt),
+            "name": record.name,
+            "message": record.getMessage(),
+        }
+        if record.exc_info:
+            log_record["exc_info"] = self.formatException(record.exc_info)
+        return json.dumps(log_record)
+
+handler = logging.StreamHandler()
+handler.setFormatter(JsonFormatter())
+logging.basicConfig(level=logging.INFO, handlers=[handler])
+
+app = FastAPI(title="Collinear API")
+
+# Enable CORS for the frontend
+frontend_origin = "http://localhost:5173"
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=[frontend_origin],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+app.add_middleware(GZipMiddleware, minimum_size=1000)
+
+app.include_router(api_router, prefix="/api")
+
+@app.get("/")
+async def root():
+    return {"message": "Welcome to the Collinear Data Tool API"}
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
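A quick smoke test for this wiring (root route plus the /api prefix); a sketch using FastAPI's TestClient, not part of the commit, and the /api/datasets call additionally assumes Redis is reachable:

# Sketch: smoke-testing the app object defined above.
from fastapi.testclient import TestClient
from app.main import app

client = TestClient(app)

assert client.get("/").json() == {"message": "Welcome to the Collinear Data Tool API"}

# Routers are mounted under /api, so the datasets endpoints live at /api/datasets/...
resp = client.get("/api/datasets/cache-status")
print(resp.status_code, resp.json())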
app/schemas/dataset.py
ADDED
@@ -0,0 +1,81 @@
+import logging
+from typing import Dict, List, Optional, Any
+from datetime import datetime
+from pydantic import BaseModel, Field
+
+from app.schemas.dataset_common import ImpactLevel, DatasetMetrics
+
+# Log for this module
+log = logging.getLogger(__name__)
+
+# Supported strategies for dataset combination
+SUPPORTED_STRATEGIES = ["merge", "intersect", "filter"]
+
+class ImpactAssessment(BaseModel):
+    dataset_id: str = Field(..., description="The ID of the dataset being assessed")
+    impact_level: ImpactLevel = Field(..., description="The impact level: low, medium, or high")
+    assessment_method: str = Field(
+        "unknown",
+        description="Method used to determine impact level (e.g., size_based, downloads_and_likes_based)"
+    )
+    metrics: DatasetMetrics = Field(
+        ...,
+        description="Metrics used for impact assessment"
+    )
+    thresholds: Dict[str, Dict[str, str]] = Field(
+        {},
+        description="Thresholds used for determining impact levels (for reference)"
+    )
+
+class DatasetInfo(BaseModel):
+    id: str
+    impact_level: Optional[ImpactLevel] = None
+    impact_assessment: Optional[Dict] = None
+    # Add other fields as needed
+    class Config:
+        extra = "allow"  # Allow extra fields from the API
+
+class DatasetBase(BaseModel):
+    name: str
+    description: Optional[str] = None
+    tags: Optional[List[str]] = None
+
+class DatasetCreate(DatasetBase):
+    files: Optional[List[str]] = None
+
+class DatasetUpdate(DatasetBase):
+    name: Optional[str] = None  # Make fields optional for updates
+
+class Dataset(DatasetBase):
+    id: int  # or str depending on your ID format
+    owner_id: str  # Assuming user IDs are strings
+    created_at: Optional[str] = None
+    updated_at: Optional[str] = None
+    class Config:
+        pass  # Removed orm_mode = True since ORM is not used
+
+class DatasetCombineRequest(BaseModel):
+    source_datasets: List[str] = Field(..., description="List of dataset IDs to combine")
+    name: str = Field(..., description="Name for the combined dataset")
+    description: Optional[str] = Field(None, description="Description for the combined dataset")
+    combination_strategy: str = Field("merge", description="Strategy to use when combining datasets (e.g., 'merge', 'intersect', 'filter')")
+    filter_criteria: Optional[Dict[str, Any]] = Field(None, description="Criteria for filtering when combining datasets")
+
+class CombinedDataset(BaseModel):
+    id: str = Field(..., description="ID of the combined dataset")
+    name: str = Field(..., description="Name of the combined dataset")
+    description: Optional[str] = Field(None, description="Description of the combined dataset")
+    source_datasets: List[str] = Field(..., description="IDs of the source datasets")
+    created_at: datetime = Field(..., description="Creation timestamp")
+    created_by: str = Field(..., description="ID of the user who created this combined dataset")
+    impact_level: Optional[ImpactLevel] = Field(None, description="Calculated impact level of the combined dataset")
+    status: str = Field("processing", description="Status of the dataset combination process")
+    combination_strategy: str = Field(..., description="Strategy used when combining datasets")
+    metrics: Optional[DatasetMetrics] = Field(None, description="Metrics for the combined dataset")
+    storage_bucket_id: Optional[str] = Field(None, description="ID of the storage bucket containing dataset files")
+    storage_folder_path: Optional[str] = Field(None, description="Path to the dataset files within the bucket")
+    class Config:
+        extra = "allow"  # Allow extra fields for flexibility
+
+__all__ = ["ImpactLevel", "ImpactAssessment", "DatasetInfo", "DatasetMetrics",
+           "Dataset", "DatasetCreate", "DatasetUpdate", "DatasetCombineRequest", "CombinedDataset"]
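A short validation sketch (not in the commit) showing how the combine-request model is used; the dataset ids are illustrative only:

# Sketch: validating a combine request against the models above.
from app.schemas.dataset import DatasetCombineRequest

req = DatasetCombineRequest(
    source_datasets=["squad", "glue"],  # example ids
    name="combined-qa",
    combination_strategy="merge",
)
print(req.model_dump())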
app/schemas/dataset_common.py
ADDED
@@ -0,0 +1,17 @@
+from enum import Enum
+from pydantic import BaseModel, Field
+from typing import Optional
+
+# Define the impact level as an enum for better type safety
+class ImpactLevel(str, Enum):
+    NA = "not_available"  # New category for when size information is unavailable
+    LOW = "low"
+    MEDIUM = "medium"
+    HIGH = "high"
+
+# Define metrics model for impact assessment
+class DatasetMetrics(BaseModel):
+    size_bytes: Optional[int] = Field(None, description="Size of the dataset in bytes")
+    file_count: Optional[int] = Field(None, description="Number of files in the dataset")
+    downloads: Optional[int] = Field(None, description="Number of downloads (all time)")
+    likes: Optional[int] = Field(None, description="Number of likes")
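For reference, a tiny sketch of how the shared enum and metrics model behave (the enum inherits from str, so it serializes to its plain value):

# Sketch: the string-valued enum and the metrics model in isolation.
from app.schemas.dataset_common import ImpactLevel, DatasetMetrics

metrics = DatasetMetrics(size_bytes=250 * 1024 * 1024, downloads=1200, likes=15)
print(ImpactLevel("medium").value)  # "medium"
print(metrics.model_dump())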
app/services/hf_datasets.py
ADDED
@@ -0,0 +1,501 @@
+import logging
+import json
+from typing import Any, List, Optional, Dict, Tuple
+import requests
+from huggingface_hub import HfApi
+from app.core.config import settings
+from app.schemas.dataset_common import ImpactLevel
+from app.services.redis_client import sync_cache_set, sync_cache_get, generate_cache_key, get_redis_sync
+import time
+import asyncio
+import redis
+import gzip
+from datetime import datetime, timezone
+import os
+from app.schemas.dataset import ImpactAssessment
+from app.schemas.dataset_common import DatasetMetrics
+import httpx
+import redis.asyncio as aioredis
+
+log = logging.getLogger(__name__)
+api = HfApi()
+redis_client = redis.Redis(host="redis", port=6379, decode_responses=True)
+
+# Thresholds for impact categorization
+SIZE_THRESHOLD_LOW = 100 * 1024 * 1024  # 100 MB
+SIZE_THRESHOLD_MEDIUM = 1024 * 1024 * 1024  # 1 GB
+DOWNLOADS_THRESHOLD_LOW = 1000
+DOWNLOADS_THRESHOLD_MEDIUM = 10000
+LIKES_THRESHOLD_LOW = 10
+LIKES_THRESHOLD_MEDIUM = 100
+
+HF_API_URL = "https://huggingface.co/api/datasets"
+DATASET_CACHE_TTL = 60 * 60  # 1 hour
+
+# Redis and HuggingFace API setup
+REDIS_KEY = "hf:datasets:all:compressed"
+REDIS_META_KEY = "hf:datasets:meta"
+REDIS_TTL = 60 * 60  # 1 hour
+
+# Impact thresholds (in bytes)
+SIZE_LOW = 100 * 1024 * 1024
+SIZE_MEDIUM = 1024 * 1024 * 1024
+
+def get_hf_token():
+    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+    if not token:
+        raise RuntimeError("HUGGINGFACEHUB_API_TOKEN environment variable is not set. Please set it securely.")
+    return token
+
+def get_dataset_commits(dataset_id: str, limit: int = 20):
+    from huggingface_hub import HfApi
+    import logging
+    log = logging.getLogger(__name__)
+    api = HfApi()
+    log.info(f"[get_dataset_commits] Fetching commits for dataset_id={dataset_id}")
+    try:
+        commits = api.list_repo_commits(repo_id=dataset_id, repo_type="dataset")
+        log.info(f"[get_dataset_commits] Received {len(commits)} commits for {dataset_id}")
+    except Exception as e:
+        log.error(f"[get_dataset_commits] Error fetching commits for {dataset_id}: {e}", exc_info=True)
+        raise  # Let the API layer catch and handle this
+    result = []
+    for c in commits[:limit]:
+        try:
+            commit_id = getattr(c, "commit_id", "")
+            title = getattr(c, "title", "")
+            message = getattr(c, "message", title)
+            authors = getattr(c, "authors", [])
+            author_name = authors[0] if authors and isinstance(authors, list) else ""
+            created_at = getattr(c, "created_at", None)
+            if created_at:
+                if hasattr(created_at, "isoformat"):
+                    date = created_at.isoformat()
+                else:
+                    date = str(created_at)
+            else:
+                date = ""
+            result.append({
+                "id": commit_id or "",
+                "title": title or message or "",
+                "message": message or title or "",
+                "author": {"name": author_name, "email": ""},
+                "date": date,
+            })
+        except Exception as e:
+            log.error(f"[get_dataset_commits] Error parsing commit: {e} | Commit: {getattr(c, '__dict__', str(c))}", exc_info=True)
+    log.info(f"[get_dataset_commits] Returning {len(result)} parsed commits for {dataset_id}")
+    return result
+
+def get_dataset_files(dataset_id: str) -> List[str]:
+    return api.list_repo_files(repo_id=dataset_id, repo_type="dataset")
+
+def get_file_url(dataset_id: str, filename: str, revision: Optional[str] = None) -> str:
+    from huggingface_hub import hf_hub_url
+    return hf_hub_url(repo_id=dataset_id, filename=filename, repo_type="dataset", revision=revision)
+
+def get_datasets_page_from_zset(offset: int = 0, limit: int = 10, search: str = None) -> dict:
+    import redis
+    import json
+    redis_client = redis.Redis(host="redis", port=6379, db=0, decode_responses=True)
+    zset_key = "hf:datasets:all:zset"
+    hash_key = "hf:datasets:all:hash"
+    # Get total count
+    total = redis_client.zcard(zset_key)
+    # Get dataset IDs for the page
+    ids = redis_client.zrange(zset_key, offset, offset + limit - 1)
+    # Fetch metadata for those IDs
+    if not ids:
+        return {"items": [], "count": total}
+    items = redis_client.hmget(hash_key, ids)
+    # Parse JSON and filter/search if needed
+    parsed = []
+    for raw in items:
+        if not raw:
+            continue
+        try:
+            item = json.loads(raw)
+            parsed.append(item)
+        except Exception:
+            continue
+    if search:
+        parsed = [d for d in parsed if search.lower() in (d.get("id") or "").lower()]
+    return {"items": parsed, "count": total}
+
+async def _fetch_size(session: httpx.AsyncClient, dataset_id: str) -> Optional[int]:
+    """Fetch dataset size from the datasets server asynchronously."""
+    url = f"https://datasets-server.huggingface.co/size?dataset={dataset_id}"
+    try:
+        resp = await session.get(url, timeout=30)
+        if resp.status_code == 200:
+            data = resp.json()
+            return data.get("size", {}).get("dataset", {}).get("num_bytes_original_files")
+    except Exception as e:
+        log.warning(f"Could not fetch size for {dataset_id}: {e}")
+    return None
+
+async def _fetch_sizes(dataset_ids: List[str]) -> Dict[str, Optional[int]]:
+    """Fetch dataset sizes in parallel."""
+    results: Dict[str, Optional[int]] = {}
+    async with httpx.AsyncClient() as session:
+        tasks = {dataset_id: asyncio.create_task(_fetch_size(session, dataset_id)) for dataset_id in dataset_ids}
+        for dataset_id, task in tasks.items():
+            results[dataset_id] = await task
+    return results
+
+def process_datasets_page(offset, limit):
+    """
+    Fetch and process a single page of datasets from Hugging Face and cache them in Redis.
+    """
+    import redis
+    import os
+    import json
+    import asyncio
+    log = logging.getLogger(__name__)
+    log.info(f"[process_datasets_page] ENTRY: offset={offset}, limit={limit}")
+    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+    if not token:
+        log.error("[process_datasets_page] HUGGINGFACEHUB_API_TOKEN environment variable is not set.")
+        raise RuntimeError("HUGGINGFACEHUB_API_TOKEN environment variable is not set. Please set it securely.")
+    headers = {
+        "Authorization": f"Bearer {token}",
+        "User-Agent": "Mozilla/5.0 (compatible; CollinearTool/1.0; +https://yourdomain.com)"
+    }
+    params = {"limit": limit, "offset": offset, "full": "True"}
+    redis_client = redis.Redis(host="redis", port=6379, db=0, decode_responses=True)
+    stream_key = "hf:datasets:all:stream"
+    zset_key = "hf:datasets:all:zset"
+    hash_key = "hf:datasets:all:hash"
+    try:
+        log.info(f"[process_datasets_page] Requesting {HF_API_URL} with params={params}")
+        response = requests.get(HF_API_URL, headers=headers, params=params, timeout=120)
+        response.raise_for_status()
+
+        page_items = response.json()
+
+        log.info(f"[process_datasets_page] Received {len(page_items)} datasets at offset {offset}")
+
+        dataset_ids = [ds.get("id") for ds in page_items]
+        size_map = asyncio.run(_fetch_sizes(dataset_ids))
+
+        for ds in page_items:
+            dataset_id = ds.get("id")
+            size_bytes = size_map.get(dataset_id)
+            downloads = ds.get("downloads")
+            likes = ds.get("likes")
+            impact_level, assessment_method = determine_impact_level_by_criteria(size_bytes, downloads, likes)
+            metrics = DatasetMetrics(size_bytes=size_bytes, downloads=downloads, likes=likes)
+            thresholds = {
+                "size_bytes": {
+                    "low": str(100 * 1024 * 1024),
+                    "medium": str(1 * 1024 * 1024 * 1024),
+                    "high": str(10 * 1024 * 1024 * 1024)
+                }
+            }
+            impact_assessment = ImpactAssessment(
+                dataset_id=dataset_id,
+                impact_level=impact_level,
+                assessment_method=assessment_method,
+                metrics=metrics,
+                thresholds=thresholds
+            ).model_dump()
+            item = {
+                "id": dataset_id,
+                "name": ds.get("name"),
+                "description": ds.get("description"),
+                "size_bytes": size_bytes,
+                "impact_level": impact_level.value if isinstance(impact_level, ImpactLevel) else impact_level,
+                "downloads": downloads,
+                "likes": likes,
+                "tags": ds.get("tags", []),
+                "impact_assessment": json.dumps(impact_assessment)
+            }
+            final_item = {}
+            for k, v in item.items():
+                if isinstance(v, list) or isinstance(v, dict):
+                    final_item[k] = json.dumps(v)
+                elif v is None:
+                    final_item[k] = 'null'
+                else:
+                    final_item[k] = str(v)
+
+            redis_client.xadd(stream_key, final_item)
+            redis_client.zadd(zset_key, {dataset_id: offset})
+            redis_client.hset(hash_key, dataset_id, json.dumps(item))
+
+        log.info(f"[process_datasets_page] EXIT: Cached {len(page_items)} datasets at offset {offset}")
+        return len(page_items)
+    except Exception as exc:
+        log.error(f"[process_datasets_page] ERROR: offset={offset}, limit={limit}, exc={exc}", exc_info=True)
+        raise
+
+def refresh_datasets_cache():
+    """
+    Orchestrator: Enqueue Celery tasks to fetch all Hugging Face datasets in parallel.
+    Uses direct calls to HF API.
+    """
+    import requests
+    log.info("[refresh_datasets_cache] Orchestrating dataset fetch tasks using direct HF API calls.")
+    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+    if not token:
+        log.error("[refresh_datasets_cache] HUGGINGFACEHUB_API_TOKEN environment variable is not set.")
+        raise RuntimeError("HUGGINGFACEHUB_API_TOKEN environment variable is not set. Please set it securely.")
+
+    headers = {
+        "Authorization": f"Bearer {token}",
+        "User-Agent": "Mozilla/5.0 (compatible; CollinearTool/1.0; +https://yourdomain.com)"
+    }
+    limit = 500
+
+    params = {"limit": 1, "offset": 0}
+    try:
+        response = requests.get(HF_API_URL, headers=headers, params=params, timeout=120)
+        response.raise_for_status()
+        total_str = response.headers.get('X-Total-Count')
+        if not total_str:
+            log.error("[refresh_datasets_cache] 'X-Total-Count' header not found in HF API response.")
+            raise ValueError("'X-Total-Count' header missing from Hugging Face API response.")
+        total = int(total_str)
+        log.info(f"[refresh_datasets_cache] Total datasets reported by HF API: {total}")
+    except requests.RequestException as e:
+        log.error(f"[refresh_datasets_cache] Error fetching total dataset count from HF API: {e}")
+        raise
+    except ValueError as e:
+        log.error(f"[refresh_datasets_cache] Error parsing total dataset count: {e}")
+        raise
+
+    num_pages = (total + limit - 1) // limit
+    from app.tasks.dataset_tasks import fetch_datasets_page
+    from celery import group
+    tasks = []
+    for page_num in range(num_pages):
+        offset = page_num * limit
+        tasks.append(fetch_datasets_page.s(offset, limit))
+        log.info(f"[refresh_datasets_cache] Scheduled page at offset {offset}, limit {limit}.")
+    if tasks:
+        group(tasks).apply_async()
+        log.info(f"[refresh_datasets_cache] Enqueued {len(tasks)} fetch tasks.")
+    else:
+        log.warning("[refresh_datasets_cache] No dataset pages found to schedule.")
+
+def determine_impact_level_by_criteria(size_bytes, downloads=None, likes=None):
+    try:
+        size = int(size_bytes) if size_bytes not in (None, 'null') else 0
+    except Exception:
+        size = 0
+
+    # Prefer size_bytes if available
+    if size >= 10 * 1024 * 1024 * 1024:
+        return ("high", "large_size")
+    elif size >= 1 * 1024 * 1024 * 1024:
+        return ("medium", "medium_size")
+    elif size >= 100 * 1024 * 1024:
+        return ("low", "small_size")
+    # Fallback to downloads if size_bytes is missing or too small
+    if downloads is not None:
+        try:
+            downloads = int(downloads)
+            if downloads >= 100000:
+                return ("high", "downloads")
+            elif downloads >= 10000:
+                return ("medium", "downloads")
+            elif downloads >= 1000:
+                return ("low", "downloads")
+        except Exception:
+            pass
+    # Fallback to likes if downloads is missing
+    if likes is not None:
+        try:
+            likes = int(likes)
+            if likes >= 1000:
+                return ("high", "likes")
+            elif likes >= 100:
+                return ("medium", "likes")
+            elif likes >= 10:
+                return ("low", "likes")
+        except Exception:
+            pass
+    return ("not_available", "size_and_downloads_and_likes_unknown")
+
+def get_dataset_size(dataset: dict, dataset_id: str = None):
+    """
+    Extract the size in bytes from a dataset dictionary.
+    Tries multiple locations based on possible HuggingFace API responses.
+    """
+    # Try top-level key
+    size_bytes = dataset.get("size_bytes")
+    if size_bytes not in (None, 'null'):
+        return size_bytes
+    # Try nested structure from the size API
+    size_bytes = (
+        dataset.get("size", {})
+        .get("dataset", {})
+        .get("num_bytes_original_files")
+    )
+    if size_bytes not in (None, 'null'):
+        return size_bytes
+    # Try metrics or info sub-dictionaries if present
+    for key in ["metrics", "info"]:
+        sub = dataset.get(key, {})
+        if isinstance(sub, dict):
+            size_bytes = sub.get("size_bytes")
+            if size_bytes not in (None, 'null'):
+                return size_bytes
+    # Not found
+    return None
+
+async def get_datasets_page_from_zset_async(offset: int = 0, limit: int = 10, search: str = None) -> dict:
+    redis_client = aioredis.Redis(host="redis", port=6379, db=0, decode_responses=True)
+    zset_key = "hf:datasets:all:zset"
+    hash_key = "hf:datasets:all:hash"
+    total = await redis_client.zcard(zset_key)
+    ids = await redis_client.zrange(zset_key, offset, offset + limit - 1)
+    if not ids:
+        return {"items": [], "count": total}
+    items = await redis_client.hmget(hash_key, ids)
+    parsed = []
+    for raw in items:
+        if not raw:
+            continue
+        try:
+            item = json.loads(raw)
+            parsed.append(item)
+        except Exception:
+            continue
+    if search:
+        parsed = [d for d in parsed if search.lower() in (d.get("id") or "").lower()]
+    return {"items": parsed, "count": total}
+
+async def get_dataset_commits_async(dataset_id: str, limit: int = 20):
+    from huggingface_hub import HfApi
+    import logging
+    log = logging.getLogger(__name__)
+    api = HfApi()
+    log.info(f"[get_dataset_commits_async] Fetching commits for dataset_id={dataset_id}")
+    try:
+        # huggingface_hub is sync, so run in threadpool
+        import anyio
+        commits = await anyio.to_thread.run_sync(api.list_repo_commits, repo_id=dataset_id, repo_type="dataset")
+        log.info(f"[get_dataset_commits_async] Received {len(commits)} commits for {dataset_id}")
+    except Exception as e:
+        log.error(f"[get_dataset_commits_async] Error fetching commits for {dataset_id}: {e}", exc_info=True)
+        raise
+    result = []
+    for c in commits[:limit]:
+        try:
+            commit_id = getattr(c, "commit_id", "")
+            title = getattr(c, "title", "")
+            message = getattr(c, "message", title)
+            authors = getattr(c, "authors", [])
+            author_name = authors[0] if authors and isinstance(authors, list) else ""
+            created_at = getattr(c, "created_at", None)
+            if created_at:
+                if hasattr(created_at, "isoformat"):
+                    date = created_at.isoformat()
+                else:
+                    date = str(created_at)
+            else:
+                date = ""
+            result.append({
+                "id": commit_id or "",
+                "title": title or message or "",
+                "message": message or title or "",
+                "author": {"name": author_name, "email": ""},
+                "date": date,
+            })
+        except Exception as e:
+            log.error(f"[get_dataset_commits_async] Error parsing commit: {e} | Commit: {getattr(c, '__dict__', str(c))}", exc_info=True)
+    log.info(f"[get_dataset_commits_async] Returning {len(result)} parsed commits for {dataset_id}")
+    return result
+
+async def get_dataset_files_async(dataset_id: str) -> List[str]:
+    from huggingface_hub import HfApi
+    import anyio
+    api = HfApi()
+    # huggingface_hub is sync, so run in threadpool
+    return await anyio.to_thread.run_sync(api.list_repo_files, repo_id=dataset_id, repo_type="dataset")
+
+async def get_file_url_async(dataset_id: str, filename: str, revision: Optional[str] = None) -> str:
+    from huggingface_hub import hf_hub_url
+    import anyio
+    # huggingface_hub is sync, so run in threadpool
+    return await anyio.to_thread.run_sync(hf_hub_url, repo_id=dataset_id, filename=filename, repo_type="dataset", revision=revision)
+
+# Fetch and cache all datasets
+
+class EnhancedJSONEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, datetime):
+            return obj.isoformat()
+        return super().default(obj)
+
+async def fetch_size(session, dataset_id, token=None):
+    url = f"https://datasets-server.huggingface.co/size?dataset={dataset_id}"
+    headers = {"Authorization": f"Bearer {token}"} if token else {}
+    try:
+        resp = await session.get(url, headers=headers, timeout=30)
+        if resp.status_code == 200:
+            data = resp.json()
+            return dataset_id, data.get("size", {}).get("dataset", {}).get("num_bytes_original_files")
+    except Exception as e:
+        log.warning(f"Could not fetch size for {dataset_id}: {e}")
+    return dataset_id, None
+
+async def fetch_all_sizes(dataset_ids, token=None, batch_size=50):
+    results = {}
+    async with httpx.AsyncClient() as session:
+        for i in range(0, len(dataset_ids), batch_size):
+            batch = dataset_ids[i:i+batch_size]
+            tasks = [fetch_size(session, ds_id, token) for ds_id in batch]
+            batch_results = await asyncio.gather(*tasks)
+            for ds_id, size in batch_results:
+                results[ds_id] = size
+    return results
+
+def fetch_and_cache_all_datasets(token: str):
+    api = HfApi(token=token)
+    log.info("Fetching all datasets from Hugging Face Hub...")
+    all_datasets = list(api.list_datasets())
+    all_datasets_dicts = []
+    dataset_ids = [d.id for d in all_datasets]
+    # Fetch all sizes in batches
+    sizes = asyncio.run(fetch_all_sizes(dataset_ids, token=token, batch_size=50))
+    for d in all_datasets:
+        data = d.__dict__
+        size_bytes = sizes.get(d.id)
+        downloads = data.get("downloads")
+        likes = data.get("likes")
+        data["size_bytes"] = size_bytes
+        impact_level, _ = determine_impact_level_by_criteria(size_bytes, downloads, likes)
+        data["impact_level"] = impact_level
+        all_datasets_dicts.append(data)
+    compressed = gzip.compress(json.dumps(all_datasets_dicts, cls=EnhancedJSONEncoder).encode("utf-8"))
+    r = redis.Redis(host="redis", port=6379, decode_responses=False)
+    r.set(REDIS_KEY, compressed)
+    log.info(f"Cached {len(all_datasets_dicts)} datasets in Redis under {REDIS_KEY}")
+    return len(all_datasets_dicts)
+
+# Native pagination from cache
+
+def get_datasets_page_from_cache(limit: int, offset: int):
+    r = redis.Redis(host="redis", port=6379, decode_responses=False)
+    compressed = r.get(REDIS_KEY)
+    if not compressed:
+        return {"error": "Cache not found. Please refresh datasets."}, 404
+    all_datasets = json.loads(gzip.decompress(compressed).decode("utf-8"))
+    total = len(all_datasets)
+    if offset < 0 or offset >= total:
+        return {"error": "Offset out of range.", "total": total}, 400
+    page = all_datasets[offset:offset+limit]
+    total_pages = (total + limit - 1) // limit
+    current_page = (offset // limit) + 1
+    next_page = current_page + 1 if offset + limit < total else None
+    prev_page = current_page - 1 if current_page > 1 else None
+    return {
+        "total": total,
+        "current_page": current_page,
+        "total_pages": total_pages,
+        "next_page": next_page,
+        "prev_page": prev_page,
+        "items": page
+    }, 200
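The impact logic above prefers size, then falls back to downloads, then likes. A few worked calls (a sketch, assuming the Settings environment needed to import this module is configured) show how the fallback chain resolves:

# Sketch: how determine_impact_level_by_criteria maps inputs to (level, method).
from app.services.hf_datasets import determine_impact_level_by_criteria

print(determine_impact_level_by_criteria(15 * 1024**3))            # ('high', 'large_size'), >= 10 GiB
print(determine_impact_level_by_criteria(2 * 1024**3))             # ('medium', 'medium_size'), >= 1 GiB
print(determine_impact_level_by_criteria(None, downloads=25_000))  # ('medium', 'downloads')
print(determine_impact_level_by_criteria(None, None, 50))          # ('low', 'likes')
print(determine_impact_level_by_criteria(None))                    # ('not_available', ...)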
app/services/redis_client.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Redis client for caching and task queue management."""

import json
from typing import Any, Dict, Optional, TypeVar
from datetime import datetime
import logging
from time import time as _time

import redis.asyncio as redis_async
import redis as redis_sync  # Import synchronous Redis client
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential

from app.core.config import settings

# Type variable for cache
T = TypeVar('T')

# Configure logging
log = logging.getLogger(__name__)

# Redis connection pools for reusing connections
_redis_pool_async = None
_redis_pool_sync = None  # Synchronous pool

# Default cache expiration (12 hours)
DEFAULT_CACHE_EXPIRY = 60 * 60 * 12

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=1, max=10))
async def get_redis_pool() -> redis_async.Redis:
    """Get or create async Redis connection pool with retry logic."""
    global _redis_pool_async

    if _redis_pool_async is None:
        # Get Redis configuration from settings
        redis_url = settings.REDIS_URL or "redis://localhost:6379/0"

        try:
            # Create connection pool with reasonable defaults
            _redis_pool_async = redis_async.ConnectionPool.from_url(
                redis_url,
                max_connections=10,
                decode_responses=True,
                health_check_interval=5,
                socket_connect_timeout=5,
                socket_keepalive=True,
                retry_on_timeout=True
            )
            log.info(f"Created async Redis connection pool with URL: {redis_url}")
        except Exception as e:
            log.error(f"Error creating async Redis connection pool: {e}")
            raise

    return redis_async.Redis(connection_pool=_redis_pool_async)

def get_redis_pool_sync() -> redis_sync.Redis:
    """Get or create synchronous Redis connection pool."""
    global _redis_pool_sync

    if _redis_pool_sync is None:
        # Get Redis configuration from settings
        redis_url = settings.REDIS_URL or "redis://localhost:6379/0"

        try:
            # Create connection pool with reasonable defaults
            _redis_pool_sync = redis_sync.ConnectionPool.from_url(
                redis_url,
                max_connections=10,
                decode_responses=True,
                socket_connect_timeout=5,
                socket_keepalive=True,
                retry_on_timeout=True
            )
            log.info(f"Created sync Redis connection pool with URL: {redis_url}")
        except Exception as e:
            log.error(f"Error creating sync Redis connection pool: {e}")
            raise

    return redis_sync.Redis(connection_pool=_redis_pool_sync)

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=5))
async def get_redis() -> redis_async.Redis:
    """Get Redis client from pool with retry logic."""
    try:
        redis_client = await get_redis_pool()
        return redis_client
    except Exception as e:
        log.error(f"Error getting Redis client: {e}")
        raise

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=5))
def get_redis_sync() -> redis_sync.Redis:
    """Get synchronous Redis client from pool with retry logic."""
    try:
        return get_redis_pool_sync()
    except Exception as e:
        log.error(f"Error getting synchronous Redis client: {e}")
        raise

# Cache key generation
def generate_cache_key(prefix: str, *args: Any) -> str:
    """Generate cache key with prefix and args."""
    key_parts = [prefix] + [str(arg) for arg in args if arg]
    return ":".join(key_parts)

# JSON serialization helpers
def _json_serialize(obj: Any) -> str:
    """Serialize object to JSON with datetime support."""
    def _serialize_datetime(o: Any) -> str:
        if isinstance(o, datetime):
            return o.isoformat()
        if isinstance(o, BaseModel):
            return o.dict()
        return str(o)

    return json.dumps(obj, default=_serialize_datetime)

def _json_deserialize(data: str, model_class: Optional[type] = None) -> Any:
    """Deserialize JSON string to object with datetime support."""
    result = json.loads(data)

    if model_class and issubclass(model_class, BaseModel):
        return model_class.parse_obj(result)

    return result

# Async cache operations
async def cache_set(key: str, value: Any, expire: int = DEFAULT_CACHE_EXPIRY) -> bool:
    """Set cache value with expiration (async version)."""
    redis_client = await get_redis()
    serialized = _json_serialize(value)

    try:
        await redis_client.set(key, serialized, ex=expire)
        log.debug(f"Cached data at key: {key}, expires in {expire}s")
        return True
    except Exception as e:
        log.error(f"Error caching data at key {key}: {e}")
        return False

async def cache_get(key: str, model_class: Optional[type] = None) -> Optional[Any]:
    """Get cache value with optional model deserialization (async version)."""
    redis_client = await get_redis()

    try:
        data = await redis_client.get(key)
        if not data:
            return None

        log.debug(f"Cache hit for key: {key}")
        return _json_deserialize(data, model_class)
    except Exception as e:
        log.error(f"Error retrieving cache for key {key}: {e}")
        return None

# Synchronous cache operations for Celery tasks
def sync_cache_set(key: str, value: Any, expire: int = DEFAULT_CACHE_EXPIRY) -> bool:
    """Set cache value with expiration (synchronous version for Celery tasks). Logs slow operations."""
    redis_client = get_redis_sync()
    serialized = _json_serialize(value)
    start = _time()
    try:
        redis_client.set(key, serialized, ex=expire)
        elapsed = _time() - start
        if elapsed > 2:
            log.warning(f"Slow sync_cache_set for key {key}: {elapsed:.2f}s")
        log.debug(f"Cached data at key: {key}, expires in {expire}s (sync)")
        return True
    except Exception as e:
        log.error(f"Error caching data at key {key}: {e}")
        return False

def sync_cache_get(key: str, model_class: Optional[type] = None) -> Optional[Any]:
    """Get cache value with optional model deserialization (synchronous version for Celery tasks). Logs slow operations."""
    redis_client = get_redis_sync()
    start = _time()
    try:
        data = redis_client.get(key)
        elapsed = _time() - start
        if elapsed > 2:
            log.warning(f"Slow sync_cache_get for key {key}: {elapsed:.2f}s")
        if not data:
            return None
        log.debug(f"Cache hit for key: {key} (sync)")
        return _json_deserialize(data, model_class)
    except Exception as e:
        log.error(f"Error retrieving cache for key {key}: {e}")
        return None

async def cache_invalidate(key: str) -> bool:
    """Invalidate cache for key."""
    redis_client = await get_redis()

    try:
        await redis_client.delete(key)
        log.debug(f"Invalidated cache for key: {key}")
        return True
    except Exception as e:
        log.error(f"Error invalidating cache for key {key}: {e}")
        return False

async def cache_invalidate_pattern(pattern: str) -> int:
    """Invalidate all cache keys matching pattern."""
    redis_client = await get_redis()

    try:
        keys = await redis_client.keys(pattern)
        if not keys:
            return 0

        count = await redis_client.delete(*keys)
        log.debug(f"Invalidated {count} keys matching pattern: {pattern}")
        return count
    except Exception as e:
        log.error(f"Error invalidating keys with pattern {pattern}: {e}")
        return 0

# Task queue operations
async def enqueue_task(queue_name: str, task_id: str, payload: Dict[str, Any]) -> bool:
    """Add task to queue."""
    redis_client = await get_redis()

    try:
        serialized = _json_serialize(payload)
        await redis_client.lpush(f"queue:{queue_name}", serialized)
        await redis_client.hset(f"tasks:{queue_name}", task_id, "pending")
        log.info(f"Enqueued task {task_id} to queue {queue_name}")
        return True
    except Exception as e:
        log.error(f"Error enqueueing task {task_id} to {queue_name}: {e}")
        return False

async def mark_task_complete(queue_name: str, task_id: str, result: Optional[Dict[str, Any]] = None) -> bool:
    """Mark task as complete with optional result."""
    redis_client = await get_redis()

    try:
        # Store result if provided
        if result:
            await redis_client.hset(
                f"results:{queue_name}",
                task_id,
                _json_serialize(result)
            )

        # Mark task as complete
        await redis_client.hset(f"tasks:{queue_name}", task_id, "complete")
        await redis_client.expire(f"tasks:{queue_name}", 86400)  # Expire after 24 hours

        log.info(f"Marked task {task_id} as complete in queue {queue_name}")
        return True
    except Exception as e:
        log.error(f"Error marking task {task_id} as complete: {e}")
        return False

async def get_task_status(queue_name: str, task_id: str) -> Optional[str]:
    """Get status of a task."""
    redis_client = await get_redis()

    try:
        status = await redis_client.hget(f"tasks:{queue_name}", task_id)
        return status
    except Exception as e:
        log.error(f"Error getting status for task {task_id}: {e}")
        return None

async def get_task_result(queue_name: str, task_id: str) -> Optional[Dict[str, Any]]:
    """Get result of a completed task."""
    redis_client = await get_redis()

    try:
        data = await redis_client.hget(f"results:{queue_name}", task_id)
        if not data:
            return None

        return _json_deserialize(data)
    except Exception as e:
        log.error(f"Error getting result for task {task_id}: {e}")
        return None

# Stream processing for real-time updates
async def add_to_stream(stream: str, data: Dict[str, Any], max_len: int = 1000) -> str:
    """Add event to Redis stream."""
    redis_client = await get_redis()

    try:
        # Convert dict values to strings (Redis streams requirement)
        entry = {k: _json_serialize(v) for k, v in data.items()}

        # Add to stream with automatic ID generation
        event_id = await redis_client.xadd(
            stream,
            entry,
            maxlen=max_len,
            approximate=True
        )

        log.debug(f"Added event {event_id} to stream {stream}")
        return event_id
    except Exception as e:
        log.error(f"Error adding to stream {stream}: {e}")
        raise
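
For orientation, a minimal usage sketch of the cache helpers above (not part of this commit). It assumes settings.REDIS_URL points at a reachable Redis instance; the "datasets" key prefix and the payload are made-up examples.

    import asyncio

    from app.services.redis_client import cache_get, cache_set, generate_cache_key

    async def demo() -> None:
        # Builds a namespaced key, e.g. "datasets:openbmb/Ultra-FineWeb"
        key = generate_cache_key("datasets", "openbmb/Ultra-FineWeb")
        await cache_set(key, {"likes": 42}, expire=300)  # cache for five minutes
        print(await cache_get(key))  # {"likes": 42} on a hit, None on a miss or error

    asyncio.run(demo())

Note that both cache_set and cache_get swallow Redis errors and fall back to False/None, so callers treat the cache as best effort.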
app/tasks/dataset_tasks.py
ADDED
@@ -0,0 +1,73 @@
import logging
import time
import asyncio
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional, Tuple
from celery import Task, shared_task
from app.core.celery_app import get_celery_app
from app.services.hf_datasets import (
    determine_impact_level_by_criteria,
    get_hf_token,
    get_dataset_size,
    refresh_datasets_cache,
    fetch_and_cache_all_datasets,
)
from app.services.redis_client import sync_cache_set, sync_cache_get, generate_cache_key
from app.core.config import settings
import requests
import os

# Configure logging
logger = logging.getLogger(__name__)

# Get Celery app instance
celery_app = get_celery_app()

# Constants
DATASET_CACHE_TTL = 60 * 60 * 24 * 30  # 30 days
BATCH_PROGRESS_CACHE_TTL = 60 * 60 * 24 * 7  # 7 days for batch progress
DATASET_SIZE_CACHE_TTL = 60 * 60 * 24 * 30  # 30 days for size info

@celery_app.task(name="app.tasks.dataset_tasks.refresh_hf_datasets_cache")
def refresh_hf_datasets_cache():
    """Celery task to refresh the HuggingFace datasets cache in Redis."""
    logger.info("Starting refresh of HuggingFace datasets cache via Celery task.")
    try:
        refresh_datasets_cache()
        logger.info("Successfully refreshed HuggingFace datasets cache.")
        return {"status": "success"}
    except Exception as e:
        logger.error(f"Failed to refresh HuggingFace datasets cache: {e}")
        return {"status": "error", "error": str(e)}

@shared_task(bind=True, max_retries=3, default_retry_delay=10)
def fetch_datasets_page(self, offset, limit):
    """
    Celery task to fetch and cache a single page of datasets from Hugging Face.
    Retries on failure.
    """
    logger.info(f"[fetch_datasets_page] ENTRY: offset={offset}, limit={limit}")
    try:
        from app.services.hf_datasets import process_datasets_page
        logger.info(f"[fetch_datasets_page] Calling process_datasets_page with offset={offset}, limit={limit}")
        result = process_datasets_page(offset, limit)
        logger.info(f"[fetch_datasets_page] SUCCESS: offset={offset}, limit={limit}, result={result}")
        return result
    except Exception as exc:
        logger.error(f"[fetch_datasets_page] ERROR: offset={offset}, limit={limit}, exc={exc}", exc_info=True)
        raise self.retry(exc=exc)

@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def refresh_hf_datasets_full_cache(self):
    logger.info("[refresh_hf_datasets_full_cache] Starting full Hugging Face datasets cache refresh.")
    try:
        token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
        if not token:
            logger.error("[refresh_hf_datasets_full_cache] HUGGINGFACEHUB_API_TOKEN not set.")
            return {"status": "error", "error": "HUGGINGFACEHUB_API_TOKEN not set"}
        count = fetch_and_cache_all_datasets(token)
        logger.info(f"[refresh_hf_datasets_full_cache] Cached {count} datasets.")
        return {"status": "ok", "cached": count}
    except Exception as exc:
        logger.error(f"[refresh_hf_datasets_full_cache] ERROR: {exc}", exc_info=True)
        raise self.retry(exc=exc)
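
A hedged sketch of how these tasks might be driven (illustrative only; the real scheduling may already live in app/core/celery_app.py, and the beat entry name and cadence below are assumptions):

    from app.core.celery_app import get_celery_app

    celery_app = get_celery_app()

    # Fire the lightweight refresh task by its registered name.
    celery_app.send_task("app.tasks.dataset_tasks.refresh_hf_datasets_cache")

    # Hypothetical beat entry for a periodic full refresh (name and interval assumed).
    celery_app.conf.beat_schedule = {
        "refresh-hf-datasets-full-cache": {
            "task": "app.tasks.dataset_tasks.refresh_hf_datasets_full_cache",
            "schedule": 60 * 60 * 6,  # every six hours
        },
    }

Whatever the scheduling mechanism, refresh_hf_datasets_full_cache reads HUGGINGFACEHUB_API_TOKEN from the environment, so the worker process needs that variable set.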
migrations/20250620000000_create_combined_datasets_table.sql
ADDED
@@ -0,0 +1,57 @@
-- Create combined_datasets table
CREATE TABLE IF NOT EXISTS public.combined_datasets (
    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
    name TEXT NOT NULL,
    description TEXT,
    source_datasets TEXT[] NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    created_by UUID REFERENCES auth.users(id),
    impact_level TEXT CHECK (impact_level = ANY (ARRAY['low', 'medium', 'high']::text[])),
    status TEXT NOT NULL DEFAULT 'processing',
    combination_strategy TEXT NOT NULL DEFAULT 'merge',
    size_bytes BIGINT,
    file_count INTEGER,
    downloads INTEGER,
    likes INTEGER
);

-- Add indexes for faster querying
CREATE INDEX IF NOT EXISTS idx_combined_datasets_created_by ON public.combined_datasets(created_by);
CREATE INDEX IF NOT EXISTS idx_combined_datasets_impact_level ON public.combined_datasets(impact_level);
CREATE INDEX IF NOT EXISTS idx_combined_datasets_status ON public.combined_datasets(status);

-- Add Row Level Security (RLS) policies
ALTER TABLE public.combined_datasets ENABLE ROW LEVEL SECURITY;

-- Policy to allow users to see all combined datasets
CREATE POLICY "Anyone can view combined datasets"
    ON public.combined_datasets
    FOR SELECT USING (true);

-- Policy to allow users to create their own combined datasets
CREATE POLICY "Users can create their own combined datasets"
    ON public.combined_datasets
    FOR INSERT
    WITH CHECK (auth.uid() = created_by);

-- Policy to allow users to update only their own combined datasets
CREATE POLICY "Users can update their own combined datasets"
    ON public.combined_datasets
    FOR UPDATE
    USING (auth.uid() = created_by);

-- Function to automatically update updated_at timestamp
CREATE OR REPLACE FUNCTION update_combined_datasets_updated_at()
RETURNS TRIGGER AS $$
BEGIN
    NEW.updated_at = now();
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;

-- Trigger to automatically update updated_at timestamp
CREATE TRIGGER update_combined_datasets_updated_at_trigger
    BEFORE UPDATE ON public.combined_datasets
    FOR EACH ROW
    EXECUTE FUNCTION update_combined_datasets_updated_at();
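
Two hedged notes on this migration: uuid_generate_v4() comes from the uuid-ossp extension, so the target database must already have it enabled (or the default could be switched to gen_random_uuid(), built into PostgreSQL 13+); and the references to auth.users and auth.uid() assume a Supabase-style Postgres where that schema and helper function exist.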
setup.py
ADDED
@@ -0,0 +1,8 @@
from setuptools import setup, find_packages

setup(
    name="collinear-tool",
    version="0.1.0",
    packages=find_packages(),
    include_package_data=True,
)
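
This minimal packaging stub lets the repository be installed as a package (for example with an editable install, pip install -e .), which is presumably what makes the app.* imports used by the tasks and tests resolvable.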
tests/test_datasets.py
ADDED
@@ -0,0 +1,78 @@
import os
import pytest
import requests

BASE_URL = os.environ.get("BASE_URL", "http://127.0.0.1:8000/api")

# --- /datasets ---
def test_list_datasets_http():
    resp = requests.get(f"{BASE_URL}/datasets")
    assert resp.status_code == 200
    data = resp.json()
    assert "items" in data
    assert "total" in data
    assert "warming_up" in data

def test_list_datasets_offset_limit_http():
    resp = requests.get(f"{BASE_URL}/datasets?offset=0&limit=3")
    assert resp.status_code == 200
    data = resp.json()
    assert isinstance(data["items"], list)
    assert len(data["items"]) <= 3

def test_list_datasets_large_offset_http():
    resp = requests.get(f"{BASE_URL}/datasets?offset=99999&limit=2")
    assert resp.status_code == 200
    data = resp.json()
    assert data["items"] == []
    assert "warming_up" in data

def test_list_datasets_invalid_limit_http():
    resp = requests.get(f"{BASE_URL}/datasets?limit=-5")
    assert resp.status_code == 422

# --- /datasets/cache-status ---
def test_cache_status_http():
    resp = requests.get(f"{BASE_URL}/datasets/cache-status")
    assert resp.status_code == 200
    data = resp.json()
    assert "warming_up" in data
    assert "total_items" in data
    assert "last_update" in data

# --- /datasets/{dataset_id}/commits ---
def test_commits_valid_http():
    resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/commits")
    assert resp.status_code in (200, 404)
    if resp.status_code == 200:
        assert isinstance(resp.json(), list)

def test_commits_invalid_http():
    resp = requests.get(f"{BASE_URL}/datasets/invalid-dataset-id/commits")
    assert resp.status_code in (404, 422)

# --- /datasets/{dataset_id}/files ---
def test_files_valid_http():
    resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/files")
    assert resp.status_code in (200, 404)
    if resp.status_code == 200:
        assert isinstance(resp.json(), list)

def test_files_invalid_http():
    resp = requests.get(f"{BASE_URL}/datasets/invalid-dataset-id/files")
    assert resp.status_code in (404, 422)

# --- /datasets/{dataset_id}/file-url ---
def test_file_url_valid_http():
    resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/file-url", params={"filename": "README.md"})
    assert resp.status_code in (200, 404)
    if resp.status_code == 200:
        assert "download_url" in resp.json()

def test_file_url_invalid_file_http():
    resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/file-url", params={"filename": "not_a_real_file.txt"})
    assert resp.status_code in (404, 200)

def test_file_url_missing_filename_http():
    resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/file-url")
    assert resp.status_code in (404, 422)
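
These tests talk to a separately running server over HTTP, addressed via the BASE_URL environment variable (default http://127.0.0.1:8000/api). A hedged way to run them, assuming an instance is already listening at that local address:

    # Illustrative only: point the tests at a locally running instance and invoke pytest.
    import os
    import pytest

    os.environ.setdefault("BASE_URL", "http://127.0.0.1:8000/api")
    raise SystemExit(pytest.main(["tests/test_datasets.py", "-q"]))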
tests/test_datasets_api.py
ADDED
@@ -0,0 +1,88 @@
import pytest
from fastapi.testclient import TestClient
from app.main import app

client = TestClient(app)

# --- /api/datasets ---
def test_list_datasets_default():
    resp = client.get("/api/datasets")
    assert resp.status_code == 200
    data = resp.json()
    assert "items" in data
    assert isinstance(data["items"], list)
    assert "total" in data
    assert "warming_up" in data

def test_list_datasets_offset_limit():
    resp = client.get("/api/datasets?offset=0&limit=2")
    assert resp.status_code == 200
    data = resp.json()
    assert isinstance(data["items"], list)
    assert len(data["items"]) <= 2

def test_list_datasets_large_offset():
    resp = client.get("/api/datasets?offset=100000&limit=2")
    assert resp.status_code == 200
    data = resp.json()
    assert data["items"] == []
    assert data["warming_up"] in (True, False)

def test_list_datasets_negative_limit():
    resp = client.get("/api/datasets?limit=-1")
    assert resp.status_code == 422

def test_list_datasets_missing_params():
    resp = client.get("/api/datasets")
    assert resp.status_code == 200
    data = resp.json()
    assert "items" in data
    assert "total" in data
    assert "warming_up" in data

# --- /api/datasets/cache-status ---
def test_cache_status():
    resp = client.get("/api/datasets/cache-status")
    assert resp.status_code == 200
    data = resp.json()
    assert "warming_up" in data
    assert "total_items" in data
    assert "last_update" in data

# --- /api/datasets/{dataset_id}/commits ---
def test_get_commits_valid():
    resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/commits")
    # Accept 200 (found) or 404 (not found)
    assert resp.status_code in (200, 404)
    if resp.status_code == 200:
        assert isinstance(resp.json(), list)

def test_get_commits_invalid():
    resp = client.get("/api/datasets/invalid-dataset-id/commits")
    assert resp.status_code in (404, 422)

# --- /api/datasets/{dataset_id}/files ---
def test_list_files_valid():
    resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/files")
    assert resp.status_code in (200, 404)
    if resp.status_code == 200:
        assert isinstance(resp.json(), list)

def test_list_files_invalid():
    resp = client.get("/api/datasets/invalid-dataset-id/files")
    assert resp.status_code in (404, 422)

# --- /api/datasets/{dataset_id}/file-url ---
def test_get_file_url_valid():
    resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/file-url", params={"filename": "README.md"})
    assert resp.status_code in (200, 404)
    if resp.status_code == 200:
        assert "download_url" in resp.json()

def test_get_file_url_invalid_file():
    resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/file-url", params={"filename": "not_a_real_file.txt"})
    assert resp.status_code in (404, 200)

def test_get_file_url_missing_filename():
    resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/file-url")
    assert resp.status_code in (404, 422)
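
Unlike tests/test_datasets.py above, these tests run in-process against app.main:app through FastAPI's TestClient, so no server or BASE_URL is needed; running pytest tests/test_datasets_api.py is enough. The dataset-specific cases still accept 404 because the openbmb/Ultra-FineWeb lookups may not resolve in a test environment.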