naamaslomi committed
Commit 5c4dda7 · verified · 1 parent: 979de1b

Update app.py

Files changed (1): app.py (+60, -102)

app.py CHANGED
@@ -8,8 +8,6 @@ import pickle
 from tqdm import tqdm
 from datetime import datetime
 from collections import OrderedDict
-import threading
-

 # ==== CONFIG ==== #
 PERSISTENT_DIR = "/data" if os.path.exists("/data") else "."
@@ -17,31 +15,25 @@ RESULTS_FILE = os.path.join(PERSISTENT_DIR, "review_results.csv")
 GROUPS_FILE = "subgroups_4876.json"
 MAPPING_FILE = "file_mapping.csv"
 DRIVE_LINK_TEMPLATE = "https://drive.google.com/uc?id={}"
-CACHE_FILE = os.path.join(PERSISTENT_DIR, "groups_cache.pkl")
-RESET = False  # Set to True to clear previous results and cache
-CACHE_LIMIT = 30  # Feel free to tweak this
+BATCH_SIZE = 50
+CACHE_LIMIT = 100  # Cache up to 100 groups in memory

+# ==== Globals ==== #
 image_cache = OrderedDict()
+file_dict = pd.read_csv(MAPPING_FILE).set_index("name")["id"].to_dict()
+with open(GROUPS_FILE) as f:
+    sample_names = json.load(f)

-# ==== Optional Reset ====
-if RESET:
-    for filename in [RESULTS_FILE, CACHE_FILE]:
-        path = os.path.join(PERSISTENT_DIR, filename)
-        if os.path.exists(path):
-            os.remove(path)
-            print(f"🗑️ Deleted {path}")
-
+# ==== Core Functions ==== #
+def get_drive_image_url(file_name):
+    file_id = file_dict.get(file_name)
+    return DRIVE_LINK_TEMPLATE.format(file_id) if file_id else None

-def preload_next_group(remaining_groups):
-    if len(remaining_groups) >= 2:
-        next_group = remaining_groups[1]  # next after the one being shown
-        load_group_with_cache(next_group)  # this fills the cache
-
 def load_group_with_cache(group, resize=(256, 256)):
     key = tuple(group)
     if key in image_cache:
         return image_cache[key]
-
+
     imgs = []
     for file_name in group:
         try:
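
Reviewer note on the new constants: with BATCH_SIZE = 50 and CACHE_LIMIT = 100, two consecutive batches fit in the cache at once, so preloading batch N+1 cannot evict groups from batch N while they are still on screen (assuming groups are reviewed in order). If either constant is tuned later, a one-line guard would preserve that invariant; this assert is a suggestion, not part of the commit:

    assert 2 * BATCH_SIZE <= CACHE_LIMIT, "preloading could evict the batch under review"
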
@@ -52,16 +44,20 @@ def load_group_with_cache(group, resize=(256, 256)):
         except Exception as e:
             print(f"❌ Error loading {file_name}: {e}")
             imgs.append(None)
-
+
     image_cache[key] = imgs
     if len(image_cache) > CACHE_LIMIT:
-        image_cache.popitem(last=False)  # Remove oldest group
+        image_cache.popitem(last=False)

     return imgs
-# ==== Helpers ====
-def get_drive_image_url(file_name):
-    file_id = file_dict.get(file_name)
-    return DRIVE_LINK_TEMPLATE.format(file_id) if file_id else None
+
+def preload_batch(start_idx, batch_size=BATCH_SIZE):
+    end_idx = min(start_idx + batch_size, len(sample_names))
+    batch_groups = sample_names[start_idx:end_idx]
+    preloaded = []
+    for group in tqdm(batch_groups, desc="Preloading batch"):
+        preloaded.append(load_group_with_cache(group))
+    return preloaded

 def load_reviewed_ids():
     try:
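
Reviewer note on the eviction path: a cache hit in load_group_with_cache returns without refreshing the entry's position, so popitem(last=False) always removes the oldest inserted group; the OrderedDict behaves as a FIFO cache, not a true LRU. That is fine for strictly sequential review, but if least-recently-used semantics are ever wanted, hits should call move_to_end. A minimal sketch; cache_get and cache_put are illustrative names, not from this commit:

    from collections import OrderedDict

    CACHE_LIMIT = 100
    image_cache = OrderedDict()

    def cache_get(key):
        if key not in image_cache:
            return None
        image_cache.move_to_end(key)  # refresh recency on every hit
        return image_cache[key]

    def cache_put(key, value):
        image_cache[key] = value
        if len(image_cache) > CACHE_LIMIT:
            image_cache.popitem(last=False)  # evicts the least recently used
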
@@ -73,69 +69,44 @@ def load_reviewed_ids():
 def get_remaining_groups():
     reviewed, reviewed_ids = load_reviewed_ids()
     remaining = [g for g in sample_names if tuple(g) not in reviewed_ids]
-    return reviewed, reviewed_ids, remaining
-
-def review_group(decision, group):
-    reviewed, reviewed_ids = load_reviewed_ids()
-
-    reviewed.append({
-        "group": json.dumps(group),
-        "decision": decision
-    })
-
-    try:
-        os.makedirs(os.path.dirname(RESULTS_FILE), exist_ok=True)
-        pd.DataFrame(reviewed).to_csv(RESULTS_FILE, index=False)
-        print(f"✅ Saved to {RESULTS_FILE}")
-    except Exception as e:
-        print(f"❌ Error saving results: {e}")
-
-    _, _, remaining = get_remaining_groups()
-    if remaining:
-        current_group = remaining[0]
-        next_images = load_group_with_cache(current_group)
-        threading.Thread(target=preload_next_group, args=(remaining,)).start()
-        return next_images, current_group, f"Group {len(reviewed)+1} / {len(sample_names)}"
-
+    return reviewed, remaining
+
+def review_group(decision, current_index, preloaded_batch):
+    reviewed, _ = load_reviewed_ids()
+    current_group = sample_names[current_index]
+    reviewed.append({"group": json.dumps(current_group), "decision": decision})
+    pd.DataFrame(reviewed).to_csv(RESULTS_FILE, index=False)
+
+    next_index = current_index + 1
+    if next_index < len(sample_names):
+        batch_start = (next_index // BATCH_SIZE) * BATCH_SIZE
+        if next_index % BATCH_SIZE == 0:
+            return [], next_index, [], f"⏳ Preloading next batch..."
+        else:
+            next_group = sample_names[next_index]
+            return load_group_with_cache(next_group), next_index, preloaded_batch, f"Group {next_index+1} / {len(sample_names)}"
     else:
-        return [], None, "✅ All groups reviewed!"
-
+        return [], next_index, preloaded_batch, "✅ All groups reviewed!"

 def prepare_download():
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     filename = f"review_results_{timestamp}.csv"
-
-    src = RESULTS_FILE
     dst = os.path.join(PERSISTENT_DIR, filename)
-    try:
-        pd.read_csv(src).to_csv(dst, index=False)
-        print(f"📁 Prepared file: {dst}")
-        return dst
-    except Exception as e:
-        print(f"⚠️ Error preparing download: {e}")
-        return None
-
-def get_first_group():
-    reviewed, _, remaining = get_remaining_groups()
-    if remaining:
-        group = remaining[0]
-        return load_group_with_cache(group), group, f"Group {len(reviewed)+1} / {len(sample_names)}"
-    else:
-        return [], None, "✅ All groups reviewed!"
-
-# ==== Load Data ====
-file_dict = pd.read_csv(MAPPING_FILE).set_index("name")["id"].to_dict()
-with open(GROUPS_FILE) as f:
-    sample_names = json.load(f)
-
-# ==== Gradio UI ====
+    pd.read_csv(RESULTS_FILE).to_csv(dst, index=False)
+    return dst
+
+def get_first_batch():
+    reviewed, remaining = get_remaining_groups()
+    current_index = len(reviewed)
+    batch_start = (current_index // BATCH_SIZE) * BATCH_SIZE
+    preloaded_batch = preload_batch(batch_start)
+    group = sample_names[current_index]
+    return load_group_with_cache(group), current_index, preloaded_batch, f"Group {current_index+1} / {len(sample_names)}"
+
+# ==== Gradio UI ==== #
 with gr.Blocks() as demo:
-    current_group = gr.State(value=None)
-    gallery = gr.Gallery(label="Group", columns=4, height="auto")
-    progress_text = gr.Markdown()
-
-with gr.Blocks() as demo:
-    current_group = gr.State(value=None)
+    current_index = gr.State(0)
+    preloaded_batch = gr.State([])
     gallery = gr.Gallery(label="Group", columns=4, height="auto")
     progress_text = gr.Markdown()

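
Reviewer note on review_group: at a batch boundary (next_index % BATCH_SIZE == 0) the handler returns an empty gallery and a "⏳ Preloading next batch..." message, but nothing ever calls preload_batch again after startup, and batch_start is computed without being used; the reviewer lands on a blank screen, and the next click records a decision for a group that was never shown. One way to keep the flow seamless is to preload synchronously inside the handler. The following is a sketch built only from names in this commit, not the committed behavior:

    def review_group(decision, current_index, preloaded_batch):
        reviewed, _ = load_reviewed_ids()
        current_group = sample_names[current_index]
        reviewed.append({"group": json.dumps(current_group), "decision": decision})
        pd.DataFrame(reviewed).to_csv(RESULTS_FILE, index=False)

        next_index = current_index + 1
        if next_index >= len(sample_names):
            return [], next_index, preloaded_batch, "✅ All groups reviewed!"
        if next_index % BATCH_SIZE == 0:
            # Blocks the click until the next batch is cached, then shows
            # the next group immediately instead of a blank gallery.
            preloaded_batch = preload_batch(next_index)
        next_group = sample_names[next_index]
        return (load_group_with_cache(next_group), next_index, preloaded_batch,
                f"Group {next_index+1} / {len(sample_names)}")
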
@@ -143,34 +114,21 @@ with gr.Blocks() as demo:
     like = gr.Button("👍 Like")
     dislike = gr.Button("👎 Dislike")
     download_btn = gr.Button("⬇️ Download Results")
-
     download_file = gr.File(label="Download CSV")
-
-

     like.click(
-        fn=lambda group: review_group("like", group),
-        inputs=[current_group],
-        outputs=[gallery, current_group, progress_text]
+        fn=lambda idx, batch: review_group("like", idx, batch),
+        inputs=[current_index, preloaded_batch],
+        outputs=[gallery, current_index, preloaded_batch, progress_text]
     )
-
     dislike.click(
-        fn=lambda group: review_group("dislike", group),
-        inputs=[current_group],
-        outputs=[gallery, current_group, progress_text]
-    )
-
-    download_btn.click(
-        fn=prepare_download,
-        inputs=[],
-        outputs=[download_file]
-    )
-    demo.load(
-        fn=get_first_group,
-        outputs=[gallery, current_group, progress_text]
+        fn=lambda idx, batch: review_group("dislike", idx, batch),
+        inputs=[current_index, preloaded_batch],
+        outputs=[gallery, current_index, preloaded_batch, progress_text]
     )
+    download_btn.click(fn=prepare_download, inputs=[], outputs=[download_file])

+    demo.load(fn=get_first_batch, outputs=[gallery, current_index, preloaded_batch, progress_text])

 if __name__ == "__main__":
     demo.launch(allowed_paths=["/data"])
-
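
Style note on the wiring: the two lambdas exist only to bake in the decision string; functools.partial expresses the same binding without restating the argument list, since Gradio passes the inputs values positionally after the bound argument. Equivalent wiring, shown for illustration only:

    from functools import partial

    like.click(
        fn=partial(review_group, "like"),
        inputs=[current_index, preloaded_batch],
        outputs=[gallery, current_index, preloaded_batch, progress_text],
    )
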
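
One last observation: get_first_batch resumes at current_index = len(reviewed), which is correct only while the rows in review_results.csv line up one-to-one, in order, with the front of sample_names; a duplicated or out-of-order row would silently skip or repeat groups. A more defensive resume derives the index from the first unreviewed group. first_unreviewed_index below is a hypothetical helper, not part of the commit:

    def first_unreviewed_index():
        # Resume at the first group without a saved decision, independent
        # of row order or duplicates in the results CSV.
        _, remaining = get_remaining_groups()
        return sample_names.index(remaining[0]) if remaining else None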