# subwalls-demo / app.py
import gradio as gr
from PIL import Image
import pandas as pd
import io
import json
import os
import requests
import threading
import time
from datetime import datetime
from collections import OrderedDict
# ==== CONFIG ==== #
PERSISTENT_DIR = "/data" if os.path.exists("/data") else "."  # use persistent storage when available (e.g. on HF Spaces)
RESULTS_FILE = os.path.join(PERSISTENT_DIR, "review_later.csv")  # one decision row per reviewed group
GROUPS_FILE = "review_later.json"  # list of image-name groups to review
MAPPING_FILE = "file_mapping.csv"  # maps image names to Google Drive file IDs
DRIVE_LINK_TEMPLATE = "https://drive.google.com/uc?id={}"
PRELOAD_AHEAD = 20  # number of groups fetched per preload batch
CACHE_LIMIT = 100   # cache up to 100 groups in memory
# ==== Globals ==== #
image_cache = OrderedDict()           # group-tuple -> list of PIL images; oldest entry evicted first
preloaded_until = PRELOAD_AHEAD - 1   # highest index already covered by a preload batch
loading_in_progress = set()           # group-tuples some thread is currently fetching
file_dict = pd.read_csv(MAPPING_FILE).set_index("name")["id"].to_dict()
with open(GROUPS_FILE) as f:
    sample_names = json.load(f)
# ==== Core Functions ==== #
def get_drive_image_url(file_name):
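    """Map an image name to its public Google Drive download URL, or None if unmapped."""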
file_id = file_dict.get(file_name)
return DRIVE_LINK_TEMPLATE.format(file_id) if file_id else None
def load_group_with_cache(group, resize=(256, 256)):
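    """Return the PIL images for `group`, fetching from Drive and caching on first use.

    Note: image_cache and loading_in_progress are shared across threads without a
    lock; the per-operation atomicity of CPython dicts and sets is enough for this
    demo, but an explicit threading.Lock would be the more robust choice.
    """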
key = tuple(group)
# βœ… If already cached, return it
if key in image_cache:
print(f"🟒 Getting key from cache: {key}")
return image_cache[key]
# ⏳ Wait if another thread is loading it
if key in loading_in_progress:
print(f"⏳ Load already in progress: {key}")
while key in loading_in_progress:
time.sleep(0.05)
return image_cache.get(key, [])
# πŸ”„ Start loading
print(f"πŸ”„ Key not in cache, loading from Drive: {key}")
loading_in_progress.add(key)
    imgs = []
    try:
        for file_name in group:
            url = None  # defined up front so the except block can always report it
            try:
                url = get_drive_image_url(file_name)
                if not url:
                    raise ValueError(f"No URL found for file: {file_name}")
                response = requests.get(url, timeout=10)
                response.raise_for_status()
                # Buffer the full body; PIL cannot reliably decode a streamed,
                # possibly content-encoded response via response.raw.
                img = Image.open(io.BytesIO(response.content)).convert("RGB").resize(resize)
                imgs.append(img)
            except Exception as e:
                print(f"❌ Error loading {file_name} (URL={url}): {e}")
                print(f"β›” Aborting group load: {key}")
                return []  # Immediately abort and skip this group
        # βœ… Only cache and return if all images loaded
        image_cache[key] = imgs
        if len(image_cache) > CACHE_LIMIT:
            image_cache.popitem(last=False)  # evict the least recently inserted group
        return imgs
finally:
loading_in_progress.discard(key)
def preload_ahead_async(remaining_groups, start_idx):
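    """Warm the cache for PRELOAD_AHEAD groups starting at start_idx, on a daemon thread.

    The modulo guard below makes this a no-op unless start_idx sits on a chunk boundary.
    """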
if start_idx % PRELOAD_AHEAD != 0:
return
print(f"πŸš€ Preloading batch starting at index {start_idx}")
def _preload():
for offset in range(PRELOAD_AHEAD):
idx = start_idx + offset
if idx < len(remaining_groups):
group = remaining_groups[idx]
load_group_with_cache(group)
threading.Thread(target=_preload, daemon=True).start()
def preload_ahead_sync(remaining_groups, start_idx):
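    """Blocking variant of the preloader, used only for the very first batch."""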
for offset in range(PRELOAD_AHEAD):
idx = start_idx + offset
if idx < len(remaining_groups):
group = remaining_groups[idx]
load_group_with_cache(group)
def load_reviewed_ids():
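    """Load prior decisions from RESULTS_FILE; returns (records, set of sorted group keys)."""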
try:
reviewed = pd.read_csv(RESULTS_FILE).to_dict(orient="records")
reviewed_ids = {tuple(sorted(json.loads(r["group"]))) for r in reviewed}
return reviewed, reviewed_ids
except FileNotFoundError:
return [], set()
def get_remaining_groups():
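    """Return the reviewed records plus the groups from sample_names not yet reviewed."""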
reviewed, reviewed_ids = load_reviewed_ids()
remaining = [g for g in sample_names if tuple(sorted(g)) not in reviewed_ids]
return reviewed, remaining
def review_group(decision, current_index, remaining_groups):
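    """Persist `decision` for the current group, then advance to and display the next one."""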
global preloaded_until
reviewed, _ = load_reviewed_ids()
current_group = remaining_groups[current_index]
reviewed.append({
"group": json.dumps(sorted(current_group)),
"decision": decision
})
pd.DataFrame(reviewed).to_csv(RESULTS_FILE, index=False)
    next_index = current_index + 1
    print(f"Next index = {next_index}, preloaded_until = {preloaded_until}")
if next_index < len(remaining_groups):
next_group = remaining_groups[next_index]
        # βœ… First time entering a new chunk: preload the following chunk in the background
        if next_index % PRELOAD_AHEAD == 0 and next_index > preloaded_until:
            preload_ahead_async(remaining_groups, next_index + PRELOAD_AHEAD)
            preloaded_until = next_index + PRELOAD_AHEAD - 1
return load_group_with_cache(next_group), next_index, remaining_groups, f"Group {next_index+1} / {len(remaining_groups)}"
else:
return [], next_index, remaining_groups, "βœ… All groups reviewed!"
def prepare_download():
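    """Snapshot the results CSV under a timestamped, row-counted name and return its path."""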
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
current_df = pd.read_csv(RESULTS_FILE)
filename = f"reviewed_groups_{timestamp}_{current_df.shape[0]}.csv"
dst = os.path.join(PERSISTENT_DIR, filename)
current_df.to_csv(dst, index=False)
return dst
def get_first_group():
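    """Start a session: preload the first batch synchronously and the next one in the background."""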
global preloaded_until
preloaded_until = PRELOAD_AHEAD - 1 # πŸ›  Reset preload window
    _, remaining = get_remaining_groups()  # the reviewed records are not needed here
current_index = 0
    # 🧠 Block until the first PRELOAD_AHEAD groups are cached
    print("🧠 Preloading first batch synchronously...")
    preload_ahead_sync(remaining, current_index)
    # βœ… Immediately kick off async preload of the following batch
    print("🧠 Preloading *next* batch asynchronously...")
    preload_ahead_async(remaining, PRELOAD_AHEAD)
group = remaining[current_index]
return load_group_with_cache(group), current_index, remaining, f"Group {current_index+1} / {len(remaining)}"
# ==== Gradio UI ==== #
with gr.Blocks() as demo:
current_index = gr.State(0)
remaining_groups = gr.State([])
gallery = gr.Gallery(label="Group", columns=4, height="auto")
progress_text = gr.Markdown()
with gr.Row():
like = gr.Button("πŸ‘ Like")
dislike = gr.Button("πŸ‘Ž Dislike")
review_later = gr.Button("πŸ” Review Later")
download_btn = gr.Button("⬇️ Download Results")
download_file = gr.File(label="Download CSV")
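    # All three decision buttons route through review_group(); only the decision label differs.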
like.click(
fn=lambda idx, groups: review_group("like", idx, groups),
inputs=[current_index, remaining_groups],
outputs=[gallery, current_index, remaining_groups, progress_text]
)
dislike.click(
fn=lambda idx, groups: review_group("dislike", idx, groups),
inputs=[current_index, remaining_groups],
outputs=[gallery, current_index, remaining_groups, progress_text]
)
review_later.click(
fn=lambda idx, groups: review_group("review_later", idx, groups),
inputs=[current_index, remaining_groups],
outputs=[gallery, current_index, remaining_groups, progress_text]
)
download_btn.click(fn=prepare_download, inputs=[], outputs=[download_file])
demo.load(fn=get_first_group, outputs=[gallery, current_index, remaining_groups, progress_text])
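# allowed_paths lets Gradio serve files (the snapshot CSVs) from /data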
if __name__ == "__main__":
demo.launch(allowed_paths=["/data"])