"""Bootstrap helpers for the paper_qa Space.

Creates the source/request datasets on the Hugging Face Hub if needed,
loads them, and builds the in-memory lookup structures used by the app.
"""
import os
import copy
import datasets
import pandas as pd
from datasets import Dataset
from collections import defaultdict
from datetime import datetime, timedelta
from background import process_arxiv_ids
from utils import create_hf_hub
from apscheduler.schedulers.background import BackgroundScheduler


def _count_nans(row):
    """Count how many fields in a row are missing (None)."""
    count = 0
    for value in row.values():
        if value is None:
            count += 1
    return count


def _initialize_requested_arxiv_ids(request_ds):
    """Flatten the requested-arXiv-ID lists from the request dataset into a DataFrame."""
    requested_arxiv_ids = []

    for request_d in request_ds['train']:
        requested_arxiv_ids += request_d['Requested arXiv IDs']

    return pd.DataFrame({'Requested arXiv IDs': requested_arxiv_ids})


def _initialize_paper_info(source_ds):
    """Build lookup structures from the source dataset.

    Returns a tuple of:
      - titles: display strings of the form "[arxiv_id] title"
      - date_dict: nested mapping year -> month -> day -> list of papers
      - arxivid2data: mapping arxiv_id -> {"idx": int, "paper": dict}
    """
    title2qna, date2qna = {}, {}
    date_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    arxivid2data = {}
    count = 0

    # The source dataset is seeded with a single "dummy" row, so anything
    # beyond one row means real papers have been pushed.
    if len(source_ds["train"]) > 1:
        for data in source_ds["train"]:
            if data["title"] != "dummy":
                date = data["target_date"].strftime("%Y-%m-%d")

                if date in date2qna:
                    # If a paper with the same title already exists for this
                    # date and has more missing fields, drop it in favour of
                    # the newer record before appending.
                    papers = copy.deepcopy(date2qna[date])
                    for paper in papers:
                        if paper["title"] == data["title"]:
                            if _count_nans(paper) > _count_nans(data):
                                date2qna[date].remove(paper)

                    date2qna[date].append(data)
                    del papers
                else:
                    date2qna[date] = [data]

        for date in date2qna:
            year, month, day = date.split("-")

            for paper in date2qna[date]:
                title2qna[paper["title"]] = paper
                arxivid2data[paper["arxiv_id"]] = {"idx": count, "paper": paper}
                date_dict[year][month][day].append(paper)
                count += 1

        titles = [f"[{v['arxiv_id']}] {k}" for k, v in title2qna.items()]
        return titles, date_dict, arxivid2data
    else:
        return [], {}, {}


def initialize_data(source_data_repo_id, request_data_repo_id):
    """Load both Hub datasets and (re)build the module-level lookup state."""
    global date_dict, arxivid2data
    global requested_arxiv_ids_df

    source_ds = datasets.load_dataset(source_data_repo_id)
    request_ds = datasets.load_dataset(request_data_repo_id)

    titles, date_dict, arxivid2data = _initialize_paper_info(source_ds)
    requested_arxiv_ids_df = _initialize_requested_arxiv_ids(request_ds)

    return titles, date_dict, requested_arxiv_ids_df, arxivid2data


def update_dataframe():
    """Reload the request dataset and return a fresh DataFrame of requested arXiv IDs."""
    global request_arxiv_repo_id

    request_ds = datasets.load_dataset(request_arxiv_repo_id)
    return _initialize_requested_arxiv_ids(request_ds)


def initialize_repos(source_data_repo_id, request_data_repo_id, hf_token):
    """Create the source and request dataset repos on the Hub if they do not exist yet."""
    if create_hf_hub(source_data_repo_id, hf_token) is False:
        print(f"{source_data_repo_id} repository already exists")
    else:
        # Seed the new source dataset with a placeholder row.
        dummy_row = {"title": ["dummy"]}
        ds = Dataset.from_dict(dummy_row)
        ds.push_to_hub(source_data_repo_id, token=hf_token)

    if create_hf_hub(request_data_repo_id, hf_token) is False:
        print(f"{request_data_repo_id} repository already exists")
    else:
        # Seed the new request dataset with a placeholder entry.
        df = pd.DataFrame(data={"Requested arXiv IDs": [["top"]]})
        ds = Dataset.from_pandas(df)
        ds.push_to_hub(request_data_repo_id, token=hf_token)


def get_secrets():
    """Read API keys and repo IDs from environment variables (Space secrets)."""
    global gemini_api_key
    global hf_token
    global request_arxiv_repo_id
    global dataset_repo_id

    gemini_api_key = os.getenv("GEMINI_API_KEY")
    hf_token = os.getenv("HF_TOKEN")
    dataset_repo_id = os.getenv("SOURCE_DATA_REPO_ID")
    request_arxiv_repo_id = os.getenv("REQUEST_DATA_REPO_ID")
    restart_repo_id = os.getenv("RESTART_TARGET_SPACE_REPO_ID", "chansung/paper_qa")

    return (
        gemini_api_key,
        hf_token,
        dataset_repo_id,
        request_arxiv_repo_id,
        restart_repo_id,
    )
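

# A minimal wiring sketch, assuming this module acts as the Space entrypoint:
# read the secrets, make sure both Hub datasets exist, build the in-memory
# state, and keep the request queue fresh with a background job. The
# 30-minute interval and the choice of update_dataframe as the scheduled job
# are illustrative assumptions; the real app presumably also schedules
# process_arxiv_ids from background.py, whose signature is not shown here.
if __name__ == "__main__":
    (
        gemini_api_key,
        hf_token,
        dataset_repo_id,
        request_arxiv_repo_id,
        restart_repo_id,
    ) = get_secrets()

    # Create the Hub repos on first run, then load them into memory.
    initialize_repos(dataset_repo_id, request_arxiv_repo_id, hf_token)
    titles, date_dict, requested_arxiv_ids_df, arxivid2data = initialize_data(
        dataset_repo_id, request_arxiv_repo_id
    )

    # Periodically refresh the requested-ID queue in the background.
    scheduler = BackgroundScheduler()
    scheduler.add_job(update_dataframe, "interval", minutes=30)  # assumed interval
    scheduler.start()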