import gradio as gr
from PIL import Image
import pandas as pd
import json
import os
import requests
import threading
import time
from datetime import datetime
from collections import OrderedDict

# ==== CONFIG ==== #
PERSISTENT_DIR = "/data" if os.path.exists("/data") else "."
RESULTS_FILE = os.path.join(PERSISTENT_DIR, "review_later.csv")
GROUPS_FILE = "review_later.json"
MAPPING_FILE = "file_mapping.csv"
DRIVE_LINK_TEMPLATE = "https://drive.google.com/uc?id={}"
PRELOAD_AHEAD = 20
CACHE_LIMIT = 100  # Cache up to 100 groups in memory

# ==== Globals ==== #
image_cache = OrderedDict()  # LRU cache: insertion order tracks age
preloaded_until = PRELOAD_AHEAD - 1
loading_in_progress = set()  # Group keys currently being fetched

file_dict = pd.read_csv(MAPPING_FILE).set_index("name")["id"].to_dict()
with open(GROUPS_FILE) as f:
    sample_names = json.load(f)


# ==== Core Functions ==== #
def get_drive_image_url(file_name):
    file_id = file_dict.get(file_name)
    return DRIVE_LINK_TEMPLATE.format(file_id) if file_id else None


def load_group_with_cache(group, resize=(256, 256)):
    key = tuple(group)

    # ✅ If already cached, return it
    if key in image_cache:
        print(f"🟢 Getting key from cache: {key}")
        return image_cache[key]

    # ⏳ Wait if another thread is loading it
    if key in loading_in_progress:
        print(f"⏳ Load already in progress: {key}")
        while key in loading_in_progress:
            time.sleep(0.05)
        return image_cache.get(key, [])

    # 🔄 Start loading
    print(f"🔄 Key not in cache, loading from Drive: {key}")
    loading_in_progress.add(key)
    imgs = []
    try:
        for file_name in group:
            url = None  # Defined up front so the error message below never hits a NameError
            try:
                url = get_drive_image_url(file_name)
                if not url:
                    raise ValueError(f"No URL found for file: {file_name}")
                response = requests.get(url, stream=True, timeout=10)
                response.raise_for_status()  # Fail fast on HTTP errors instead of letting PIL choke on an error page
                img = Image.open(response.raw).convert("RGB").resize(resize)
                imgs.append(img)
            except Exception as e:
                print(f"❌ Error loading {file_name} (URL={url}): {e}")
                print(f"⛔ Aborting group load: {key}")
                return []  # Immediately abort and skip this group

        # ✅ Only cache and return if all images loaded
        image_cache[key] = imgs
        if len(image_cache) > CACHE_LIMIT:
            image_cache.popitem(last=False)  # Evict the oldest cached group
        return imgs
    finally:
        loading_in_progress.discard(key)


def preload_ahead_async(remaining_groups, start_idx):
    # Only fire on batch boundaries so each chunk is preloaded at most once
    if start_idx % PRELOAD_AHEAD != 0:
        return
    print(f"🚀 Preloading batch starting at index {start_idx}")

    def _preload():
        for offset in range(PRELOAD_AHEAD):
            idx = start_idx + offset
            if idx < len(remaining_groups):
                group = remaining_groups[idx]
                load_group_with_cache(group)

    threading.Thread(target=_preload, daemon=True).start()
def preload_ahead_sync(remaining_groups, start_idx):
    for offset in range(PRELOAD_AHEAD):
        idx = start_idx + offset
        if idx < len(remaining_groups):
            group = remaining_groups[idx]
            load_group_with_cache(group)


def load_reviewed_ids():
    try:
        reviewed = pd.read_csv(RESULTS_FILE).to_dict(orient="records")
        reviewed_ids = {tuple(sorted(json.loads(r["group"]))) for r in reviewed}
        return reviewed, reviewed_ids
    except FileNotFoundError:
        return [], set()


def get_remaining_groups():
    reviewed, reviewed_ids = load_reviewed_ids()
    remaining = [g for g in sample_names if tuple(sorted(g)) not in reviewed_ids]
    return reviewed, remaining


def review_group(decision, current_index, remaining_groups):
    global preloaded_until
    reviewed, _ = load_reviewed_ids()
    current_group = remaining_groups[current_index]
    reviewed.append({
        "group": json.dumps(sorted(current_group)),
        "decision": decision
    })
    pd.DataFrame(reviewed).to_csv(RESULTS_FILE, index=False)

    next_index = current_index + 1
    print(f"Next index= {next_index}, preloaded_until= {preloaded_until}")
    if next_index < len(remaining_groups):
        next_group = remaining_groups[next_index]
        # ✅ Preload next chunk if it's the first time we reach it
        if next_index % PRELOAD_AHEAD == 0 and next_index > preloaded_until:
            preload_ahead_async(remaining_groups, next_index + PRELOAD_AHEAD)
            preloaded_until = next_index + PRELOAD_AHEAD - 1
        return load_group_with_cache(next_group), next_index, remaining_groups, f"Group {next_index+1} / {len(remaining_groups)}"
    else:
        return [], next_index, remaining_groups, "✅ All groups reviewed!"


def prepare_download():
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    current_df = pd.read_csv(RESULTS_FILE)
    filename = f"reviewed_groups_{timestamp}_{current_df.shape[0]}.csv"
    dst = os.path.join(PERSISTENT_DIR, filename)
    current_df.to_csv(dst, index=False)
    return dst


def get_first_group():
    global preloaded_until
    preloaded_until = PRELOAD_AHEAD - 1  # 🛠 Reset preload window
    reviewed, remaining = get_remaining_groups()
    current_index = 0

    # 🧠 Block until groups 0–19 are cached
    print("🧠 Preloading first batch synchronously...")
    preload_ahead_sync(remaining, current_index)

    # ✅ Immediately start async preload for 20–39
    print("🧠 Preloading *next* batch asynchronously...")
    preload_ahead_async(remaining, PRELOAD_AHEAD)

    group = remaining[current_index]
    return load_group_with_cache(group), current_index, remaining, f"Group {current_index+1} / {len(remaining)}"


# ==== Gradio UI ==== #
with gr.Blocks() as demo:
    current_index = gr.State(0)
    remaining_groups = gr.State([])

    gallery = gr.Gallery(label="Group", columns=4, height="auto")
    progress_text = gr.Markdown()

    with gr.Row():
        like = gr.Button("👍 Like")
        dislike = gr.Button("👎 Dislike")
        review_later = gr.Button("🔁 Review Later")

    download_btn = gr.Button("⬇️ Download Results")
    download_file = gr.File(label="Download CSV")

    like.click(
        fn=lambda idx, groups: review_group("like", idx, groups),
        inputs=[current_index, remaining_groups],
        outputs=[gallery, current_index, remaining_groups, progress_text]
    )
    dislike.click(
        fn=lambda idx, groups: review_group("dislike", idx, groups),
        inputs=[current_index, remaining_groups],
        outputs=[gallery, current_index, remaining_groups, progress_text]
    )
    review_later.click(
        fn=lambda idx, groups: review_group("review_later", idx, groups),
        inputs=[current_index, remaining_groups],
        outputs=[gallery, current_index, remaining_groups, progress_text]
    )
    download_btn.click(fn=prepare_download, inputs=[], outputs=[download_file])

    demo.load(fn=get_first_group, outputs=[gallery, current_index, remaining_groups, progress_text])

if __name__ == "__main__":
    demo.launch(allowed_paths=["/data"])
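
# --- Expected on-disk inputs (a sketch inferred from the loading code above;
# the file names and values shown are illustrative, not shipped with the app) ---
#
# file_mapping.csv — one row per image, mapping a file name to its Drive file id:
#     name,id
#     img_001.jpg,1AbCdEfGhIjKlMnOp
#
# review_later.json — a JSON list of groups, each group a list of file names
# that must all appear in the "name" column of file_mapping.csv:
#     [["img_001.jpg", "img_002.jpg"], ["img_003.jpg", "img_004.jpg"]]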