import gradio as gr
from PIL import Image
import pandas as pd
import json
import os
import requests
import threading
import time
from datetime import datetime
from collections import OrderedDict

# ==== CONFIG ==== #
PERSISTENT_DIR = "/data" if os.path.exists("/data") else "."  # Use Spaces persistent storage if mounted
RESULTS_FILE = os.path.join(PERSISTENT_DIR, "review_later.csv")
GROUPS_FILE = "review_later.json"
MAPPING_FILE = "file_mapping.csv"
DRIVE_LINK_TEMPLATE = "https://drive.google.com/uc?id={}"
PRELOAD_AHEAD = 20  # How many groups to fetch per preload batch
CACHE_LIMIT = 100   # Cache up to 100 groups in memory

# ==== Globals ==== #
image_cache = OrderedDict()          # group tuple -> list of PIL images
preloaded_until = PRELOAD_AHEAD - 1  # Highest index whose preload batch has been triggered
loading_in_progress = set()          # Group tuples currently being fetched by some thread

file_dict = pd.read_csv(MAPPING_FILE).set_index("name")["id"].to_dict()
with open(GROUPS_FILE) as f:
    sample_names = json.load(f)
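
# The two input files are assumed to look roughly like this (schema inferred
# from the reads above; the filenames and IDs are hypothetical):
#
# file_mapping.csv -- one row per image, mapping filename to a Drive file ID:
#   name,id
#   img_001.jpg,1AbCdEfG...
#
# review_later.json -- a list of groups, each a list of filenames:
#   [["img_001.jpg", "img_002.jpg"], ["img_003.jpg", "img_004.jpg"]]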

# ==== Core Functions ==== #
def get_drive_image_url(file_name):
    file_id = file_dict.get(file_name)
    return DRIVE_LINK_TEMPLATE.format(file_id) if file_id else None
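
# Note: the uc?id= endpoint only serves raw bytes for publicly shared files,
# and for large files Google Drive may return an HTML confirmation page
# instead of the image, which would make PIL fail below. This app assumes
# small, publicly shared images.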

def load_group_with_cache(group, resize=(256, 256)):
    key = tuple(group)
    # ✅ If already cached, return it
    if key in image_cache:
        print(f"🟢 Getting key from cache: {key}")
        return image_cache[key]
    # ⏳ Wait if another thread is loading it
    if key in loading_in_progress:
        print(f"⏳ Load already in progress: {key}")
        while key in loading_in_progress:
            time.sleep(0.05)
        return image_cache.get(key, [])
    # 🚀 Start loading
    print(f"🚀 Key not in cache, loading from Drive: {key}")
    loading_in_progress.add(key)
    imgs = []
    try:
        for file_name in group:
            url = None  # Keep url defined even if the lookup itself fails
            try:
                url = get_drive_image_url(file_name)
                if not url:
                    raise ValueError(f"No URL found for file: {file_name}")
                response = requests.get(url, stream=True, timeout=10)
                response.raise_for_status()  # Surface HTTP errors before handing bytes to PIL
                img = Image.open(response.raw).convert("RGB").resize(resize)
                imgs.append(img)
            except Exception as e:
                print(f"❌ Error loading {file_name} (URL={url}): {e}")
                print(f"❌ Aborting group load: {key}")
                return []  # Immediately abort and skip this group
        # ✅ Only cache and return if all images loaded
        image_cache[key] = imgs
        if len(image_cache) > CACHE_LIMIT:
            image_cache.popitem(last=False)  # Evict the oldest cached group (FIFO)
        return imgs
    finally:
        loading_in_progress.discard(key)
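
# Thread-safety here is a deliberately lightweight assumption: the sleep-poll
# loop plus the plain set/OrderedDict rely on CPython's GIL making individual
# container operations atomic. A stricter version would guard image_cache and
# loading_in_progress with a threading.Lock (or a per-key threading.Event)
# instead of polling; this sketch keeps the original approach.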

def preload_ahead_async(remaining_groups, start_idx):
    if start_idx % PRELOAD_AHEAD != 0:
        return
    print(f"🚀 Preloading batch starting at index {start_idx}")

    def _preload():
        for offset in range(PRELOAD_AHEAD):
            idx = start_idx + offset
            if idx < len(remaining_groups):
                group = remaining_groups[idx]
                load_group_with_cache(group)

    threading.Thread(target=_preload, daemon=True).start()

def preload_ahead_sync(remaining_groups, start_idx):
    for offset in range(PRELOAD_AHEAD):
        idx = start_idx + offset
        if idx < len(remaining_groups):
            group = remaining_groups[idx]
            load_group_with_cache(group)
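
# The two variants split one job: the sync version blocks its caller so the
# very first gallery render is guaranteed to hit the cache, while the async
# version runs in a daemon thread so later batches download while the
# reviewer is still looking at the current group.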

def load_reviewed_ids():
    try:
        reviewed = pd.read_csv(RESULTS_FILE).to_dict(orient="records")
        reviewed_ids = {tuple(sorted(json.loads(r["group"]))) for r in reviewed}
        return reviewed, reviewed_ids
    except FileNotFoundError:
        return [], set()

def get_remaining_groups():
    reviewed, reviewed_ids = load_reviewed_ids()
    remaining = [g for g in sample_names if tuple(sorted(g)) not in reviewed_ids]
    return reviewed, remaining

def review_group(decision, current_index, remaining_groups):
    global preloaded_until
    reviewed, _ = load_reviewed_ids()
    current_group = remaining_groups[current_index]
    reviewed.append({
        "group": json.dumps(sorted(current_group)),
        "decision": decision
    })
    pd.DataFrame(reviewed).to_csv(RESULTS_FILE, index=False)
    next_index = current_index + 1
    print(f"Next index = {next_index}, preloaded_until = {preloaded_until}")
    if next_index < len(remaining_groups):
        next_group = remaining_groups[next_index]
        # ✅ Preload the next chunk if it's the first time we reach it
        if next_index % PRELOAD_AHEAD == 0 and next_index > preloaded_until:
            preload_ahead_async(remaining_groups, next_index + PRELOAD_AHEAD)
            preloaded_until = next_index + PRELOAD_AHEAD - 1
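        # Worked example with PRELOAD_AHEAD = 20: when next_index first hits
        # 20 (batch 20-39 was already fetched at startup), this kicks off an
        # async load of groups 40-59 and sets preloaded_until = 39, so indices
        # 21-39 fall through without re-triggering the download.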
        return load_group_with_cache(next_group), next_index, remaining_groups, f"Group {next_index + 1} / {len(remaining_groups)}"
    else:
        return [], next_index, remaining_groups, "✅ All groups reviewed!"

def prepare_download():
    if not os.path.exists(RESULTS_FILE):
        return None  # Nothing reviewed yet, so nothing to download
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    current_df = pd.read_csv(RESULTS_FILE)
    filename = f"reviewed_groups_{timestamp}_{current_df.shape[0]}.csv"
    dst = os.path.join(PERSISTENT_DIR, filename)
    current_df.to_csv(dst, index=False)
    return dst

def get_first_group():
    global preloaded_until
    preloaded_until = PRELOAD_AHEAD - 1  # 🔄 Reset preload window
    reviewed, remaining = get_remaining_groups()
    if not remaining:
        return [], 0, [], "✅ All groups reviewed!"
    current_index = 0
    # 🧠 Block until groups 0–19 are cached
    print("🧠 Preloading first batch synchronously...")
    preload_ahead_sync(remaining, current_index)
    # 🚀 Immediately start async preload for 20–39
    print("🧠 Preloading *next* batch asynchronously...")
    preload_ahead_async(remaining, PRELOAD_AHEAD)
    group = remaining[current_index]
    return load_group_with_cache(group), current_index, remaining, f"Group {current_index + 1} / {len(remaining)}"

# ==== Gradio UI ==== #
with gr.Blocks() as demo:
    current_index = gr.State(0)
    remaining_groups = gr.State([])
    gallery = gr.Gallery(label="Group", columns=4, height="auto")
    progress_text = gr.Markdown()
    with gr.Row():
        like = gr.Button("👍 Like")
        dislike = gr.Button("👎 Dislike")
        review_later = gr.Button("🔁 Review Later")
    download_btn = gr.Button("⬇️ Download Results")
    download_file = gr.File(label="Download CSV")
    like.click(
        fn=lambda idx, groups: review_group("like", idx, groups),
        inputs=[current_index, remaining_groups],
        outputs=[gallery, current_index, remaining_groups, progress_text]
    )
    dislike.click(
        fn=lambda idx, groups: review_group("dislike", idx, groups),
        inputs=[current_index, remaining_groups],
        outputs=[gallery, current_index, remaining_groups, progress_text]
    )
    review_later.click(
        fn=lambda idx, groups: review_group("review_later", idx, groups),
        inputs=[current_index, remaining_groups],
        outputs=[gallery, current_index, remaining_groups, progress_text]
    )
    download_btn.click(fn=prepare_download, inputs=[], outputs=[download_file])
    demo.load(fn=get_first_group, outputs=[gallery, current_index, remaining_groups, progress_text])

if __name__ == "__main__":
    demo.launch(allowed_paths=["/data"])
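
# Assumed dependencies for this Space, inferred from the imports above rather
# than from a requirements file: gradio, pandas, pillow, requests. One caveat:
# gr.State is per-session, but image_cache, preloaded_until, and RESULTS_FILE
# are process-wide, so concurrent reviewers would share one results file and
# one preload window.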