Update app.py
app.py
CHANGED
@@ -8,8 +8,6 @@ import pickle
 from tqdm import tqdm
 from datetime import datetime
 from collections import OrderedDict
-import threading
-
 
 # ==== CONFIG ==== #
 PERSISTENT_DIR = "/data" if os.path.exists("/data") else "."
@@ -17,31 +15,25 @@ RESULTS_FILE = os.path.join(PERSISTENT_DIR, "review_results.csv")
 GROUPS_FILE = "subgroups_4876.json"
 MAPPING_FILE = "file_mapping.csv"
 DRIVE_LINK_TEMPLATE = "https://drive.google.com/uc?id={}"
-
-
-CACHE_LIMIT = 30  # Feel free to tweak this
+BATCH_SIZE = 50
+CACHE_LIMIT = 100  # Cache up to 100 groups in memory
 
+# ==== Globals ==== #
 image_cache = OrderedDict()
+file_dict = pd.read_csv(MAPPING_FILE).set_index("name")["id"].to_dict()
+with open(GROUPS_FILE) as f:
+    sample_names = json.load(f)
 
-# ====
-
-
-
-    if os.path.exists(path):
-        os.remove(path)
-        print(f"🗑️ Deleted {path}")
-
+# ==== Core Functions ==== #
+def get_drive_image_url(file_name):
+    file_id = file_dict.get(file_name)
+    return DRIVE_LINK_TEMPLATE.format(file_id) if file_id else None
 
-def preload_next_group(remaining_groups):
-    if len(remaining_groups) >= 2:
-        next_group = remaining_groups[1]  # next after the one being shown
-        load_group_with_cache(next_group)  # this fills the cache
-
 def load_group_with_cache(group, resize=(256, 256)):
     key = tuple(group)
     if key in image_cache:
         return image_cache[key]
-
+
     imgs = []
     for file_name in group:
         try:
@@ -52,16 +44,20 @@ def load_group_with_cache(group, resize=(256, 256)):
         except Exception as e:
             print(f"❌ Error loading {file_name}: {e}")
             imgs.append(None)
-
+
     image_cache[key] = imgs
     if len(image_cache) > CACHE_LIMIT:
-        image_cache.popitem(last=False)
+        image_cache.popitem(last=False)
 
     return imgs
-
-def
-
-
+
+def preload_batch(start_idx, batch_size=BATCH_SIZE):
+    end_idx = min(start_idx + batch_size, len(sample_names))
+    batch_groups = sample_names[start_idx:end_idx]
+    preloaded = []
+    for group in tqdm(batch_groups, desc="Preloading batch"):
+        preloaded.append(load_group_with_cache(group))
+    return preloaded
 
 def load_reviewed_ids():
     try:
@@ -73,69 +69,44 @@ def load_reviewed_ids():
 def get_remaining_groups():
     reviewed, reviewed_ids = load_reviewed_ids()
     remaining = [g for g in sample_names if tuple(g) not in reviewed_ids]
-    return reviewed,
-
-def review_group(decision,
-    reviewed,
-
-    reviewed.append({
-
-
-
-
-
-
-
-
-
-
-
-    _, _, remaining = get_remaining_groups()
-    if remaining:
-        current_group = remaining[0]
-        next_images = load_group_with_cache(current_group)
-        threading.Thread(target=preload_next_group, args=(remaining,)).start()
-        return next_images, current_group, f"Group {len(reviewed)+1} / {len(sample_names)}"
-
+    return reviewed, remaining
+
+def review_group(decision, current_index, preloaded_batch):
+    reviewed, _ = load_reviewed_ids()
+    current_group = sample_names[current_index]
+    reviewed.append({"group": json.dumps(current_group), "decision": decision})
+    pd.DataFrame(reviewed).to_csv(RESULTS_FILE, index=False)
+
+    next_index = current_index + 1
+    if next_index < len(sample_names):
+        batch_start = (next_index // BATCH_SIZE) * BATCH_SIZE
+        if next_index % BATCH_SIZE == 0:
+            return [], next_index, [], f"⏳ Preloading next batch..."
+        else:
+            next_group = sample_names[next_index]
+            return load_group_with_cache(next_group), next_index, preloaded_batch, f"Group {next_index+1} / {len(sample_names)}"
     else:
-        return [],
-
+        return [], next_index, preloaded_batch, "✅ All groups reviewed!"
 
 def prepare_download():
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     filename = f"review_results_{timestamp}.csv"
-
-    src = RESULTS_FILE
     dst = os.path.join(PERSISTENT_DIR, filename)
-
-
-
-
-
-
-
-
-
-
-
-
-        return load_group_with_cache(group), group, f"Group {len(reviewed)+1} / {len(sample_names)}"
-    else:
-        return [], None, "✅ All groups reviewed!"
-
-# ==== Load Data ====
-file_dict = pd.read_csv(MAPPING_FILE).set_index("name")["id"].to_dict()
-with open(GROUPS_FILE) as f:
-    sample_names = json.load(f)
-
-# ==== Gradio UI ====
+    pd.read_csv(RESULTS_FILE).to_csv(dst, index=False)
+    return dst
+
+def get_first_batch():
+    reviewed, remaining = get_remaining_groups()
+    current_index = len(reviewed)
+    batch_start = (current_index // BATCH_SIZE) * BATCH_SIZE
+    preloaded_batch = preload_batch(batch_start)
+    group = sample_names[current_index]
+    return load_group_with_cache(group), current_index, preloaded_batch, f"Group {current_index+1} / {len(sample_names)}"
+
+# ==== Gradio UI ==== #
 with gr.Blocks() as demo:
-
-
-    progress_text = gr.Markdown()
-
-with gr.Blocks() as demo:
-    current_group = gr.State(value=None)
+    current_index = gr.State(0)
+    preloaded_batch = gr.State([])
     gallery = gr.Gallery(label="Group", columns=4, height="auto")
     progress_text = gr.Markdown()
 
@@ -143,34 +114,21 @@ with gr.Blocks() as demo:
     like = gr.Button("👍 Like")
    dislike = gr.Button("👎 Dislike")
     download_btn = gr.Button("⬇️ Download Results")
-
     download_file = gr.File(label="Download CSV")
-
-
 
     like.click(
-        fn=lambda
-        inputs=[
-        outputs=[gallery,
+        fn=lambda idx, batch: review_group("like", idx, batch),
+        inputs=[current_index, preloaded_batch],
+        outputs=[gallery, current_index, preloaded_batch, progress_text]
     )
-
     dislike.click(
-        fn=lambda
-        inputs=[
-        outputs=[gallery,
-    )
-
-    download_btn.click(
-        fn=prepare_download,
-        inputs=[],
-        outputs=[download_file]
-    )
-    demo.load(
-        fn=get_first_group,
-        outputs=[gallery, current_group, progress_text]
+        fn=lambda idx, batch: review_group("dislike", idx, batch),
+        inputs=[current_index, preloaded_batch],
+        outputs=[gallery, current_index, preloaded_batch, progress_text]
     )
+    download_btn.click(fn=prepare_download, inputs=[], outputs=[download_file])
 
+    demo.load(fn=get_first_batch, outputs=[gallery, current_index, preloaded_batch, progress_text])
 
 if __name__ == "__main__":
     demo.launch(allowed_paths=["/data"])
-
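
The caching scheme that load_group_with_cache keeps after this change is a small FIFO-evicting cache built on collections.OrderedDict: the group's file list becomes the key, and once more than CACHE_LIMIT groups are stored, the oldest insertion is dropped via popitem(last=False). Below is a minimal self-contained sketch of that pattern; fetch_image is a hypothetical stand-in for the Drive download and resize step, not a function from app.py.

from collections import OrderedDict

CACHE_LIMIT = 100  # same role as CACHE_LIMIT in app.py

image_cache = OrderedDict()

def fetch_image(file_name):
    # Hypothetical placeholder for downloading and resizing one image.
    return f"<pixels of {file_name}>"

def load_group_with_cache(group):
    key = tuple(group)                   # lists are unhashable, so the key is a tuple
    if key in image_cache:
        return image_cache[key]          # cache hit: reuse the already-loaded images
    imgs = [fetch_image(name) for name in group]
    image_cache[key] = imgs
    if len(image_cache) > CACHE_LIMIT:
        image_cache.popitem(last=False)  # evict the oldest inserted group (FIFO, not true LRU)
    return imgs

Because hits are not moved to the end, eviction is by insertion order; a true LRU variant would also call image_cache.move_to_end(key) on each hit.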
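
The UI wiring relies on gr.State to carry current_index and preloaded_batch between events: each handler returns a tuple, and Gradio writes its elements back to the components listed in outputs, in order, which is how review_group's four return values land in [gallery, current_index, preloaded_batch, progress_text]. A reduced sketch of that State round-trip, using a hypothetical step handler rather than the app's review_group:

import gradio as gr

def step(idx):
    idx += 1
    return idx, f"Clicked {idx} times"  # values map onto outputs=[counter, status]

with gr.Blocks() as demo:
    counter = gr.State(0)               # per-session value carried between events
    status = gr.Markdown()
    btn = gr.Button("Next")
    btn.click(fn=step, inputs=[counter], outputs=[counter, status])
    demo.load(fn=lambda: (0, "Ready"), outputs=[counter, status])

if __name__ == "__main__":
    demo.launch()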