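"""Gradio app for manually reviewing groups of images hosted on Google Drive.

Each group is shown in a gallery; the reviewer marks it as like, dislike, or
review-later, and every decision is appended to a CSV on disk. Images are
downloaded from Drive, resized, and kept in an in-memory LRU-style cache
(an OrderedDict trimmed from the oldest end), with upcoming batches preloaded
in a background thread to keep the UI responsive.

Dependencies (from the imports below): gradio, pandas, pillow, requests.
"""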
import gradio as gr
from PIL import Image
import pandas as pd
import json
import os
import requests
import threading
import time
from datetime import datetime
from collections import OrderedDict

# ==== CONFIG ==== #
PERSISTENT_DIR = "/data" if os.path.exists("/data") else "."
RESULTS_FILE = os.path.join(PERSISTENT_DIR, "review_later.csv")
GROUPS_FILE = "review_later.json"
MAPPING_FILE = "file_mapping.csv"
DRIVE_LINK_TEMPLATE = "https://drive.google.com/uc?id={}"
PRELOAD_AHEAD = 20
CACHE_LIMIT = 100  # Cache up to 100 groups in memory

# ==== Globals ==== #
image_cache = OrderedDict()
preloaded_until = PRELOAD_AHEAD - 1 
loading_in_progress = set()
file_dict = pd.read_csv(MAPPING_FILE).set_index("name")["id"].to_dict()
with open(GROUPS_FILE) as f:
    sample_names = json.load(f)
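
# Expected input formats (inferred from how the files are read above; the
# example file names below are placeholders, not real data):
#   file_mapping.csv  -> columns "name" (image file name) and "id" (Google Drive file id)
#   review_later.json -> JSON list of groups, each group a list of file names, e.g.
#                        [["img_001.jpg", "img_002.jpg"], ["img_003.jpg", "img_004.jpg"]]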

# ==== Core Functions ==== #
def get_drive_image_url(file_name):
    file_id = file_dict.get(file_name)
    return DRIVE_LINK_TEMPLATE.format(file_id) if file_id else None

def load_group_with_cache(group, resize=(256, 256)):
    key = tuple(group)
    
    # ✅ If already cached, return it
    if key in image_cache:
        print(f"🟢 Getting key from cache: {key}")
        return image_cache[key]
    
    # ⏳ Wait if another thread is loading it
    if key in loading_in_progress:
        print(f"⏳ Load already in progress: {key}")
        while key in loading_in_progress:
            time.sleep(0.05)
        return image_cache.get(key, [])

    # 🔄 Start loading
    print(f"🔄 Key not in cache, loading from Drive: {key}")
    loading_in_progress.add(key)
    imgs = []

    try:
        for file_name in group:
            url = None  # defined up front so the error message below never hits an unbound name
            try:
                url = get_drive_image_url(file_name)
                if not url:
                    raise ValueError(f"No URL found for file: {file_name}")

                response = requests.get(url, stream=True, timeout=10)
                response.raise_for_status()
                response.raw.decode_content = True  # decode gzip/deflate before PIL reads the raw stream
                img = Image.open(response.raw).convert("RGB").resize(resize)
                imgs.append(img)

            except Exception as e:
                print(f"❌ Error loading {file_name} (URL={url}): {e}")
                print(f"⛔ Aborting group load: {key}")
                return []  # Immediately abort and skip this group

        # ✅ Only cache and return if all images loaded
        image_cache[key] = imgs
        if len(image_cache) > CACHE_LIMIT:
            image_cache.popitem(last=False)
        return imgs

    finally:
        loading_in_progress.discard(key)


def preload_ahead_async(remaining_groups, start_idx):
    if start_idx % PRELOAD_AHEAD != 0:
        return
    print(f"🚀 Preloading batch starting at index {start_idx}")
    def _preload():
        for offset in range(PRELOAD_AHEAD):
            idx = start_idx + offset
            if idx < len(remaining_groups):
                group = remaining_groups[idx]
                load_group_with_cache(group)
    threading.Thread(target=_preload, daemon=True).start()

def preload_ahead_sync(remaining_groups, start_idx):
    for offset in range(PRELOAD_AHEAD):
        idx = start_idx + offset
        if idx < len(remaining_groups):
            group = remaining_groups[idx]
            load_group_with_cache(group)
            
def load_reviewed_ids():
    try:
        reviewed = pd.read_csv(RESULTS_FILE).to_dict(orient="records")
        reviewed_ids = {tuple(sorted(json.loads(r["group"]))) for r in reviewed}
        return reviewed, reviewed_ids
    except FileNotFoundError:
        return [], set()

def get_remaining_groups():
    reviewed, reviewed_ids = load_reviewed_ids()
    remaining = [g for g in sample_names if tuple(sorted(g)) not in reviewed_ids]
    return reviewed, remaining

def review_group(decision, current_index, remaining_groups):
    global preloaded_until

    reviewed, _ = load_reviewed_ids()
    current_group = remaining_groups[current_index]
    reviewed.append({
        "group": json.dumps(sorted(current_group)),
        "decision": decision
    })
    pd.DataFrame(reviewed).to_csv(RESULTS_FILE, index=False)

    next_index = current_index + 1
    print(f"Next index= {next_index}, preloaded_until= {preloaded_until}")
    if next_index < len(remaining_groups):
        next_group = remaining_groups[next_index]

        # ✅ Preload next chunk if it's the first time we reach it
        if next_index % PRELOAD_AHEAD == 0 and next_index > preloaded_until:
            preload_ahead_async(remaining_groups, next_index + PRELOAD_AHEAD)
            preloaded_until = next_index + PRELOAD_AHEAD - 1

        return load_group_with_cache(next_group), next_index, remaining_groups, f"Group {next_index+1} / {len(remaining_groups)}"
    else:
        return [], next_index, remaining_groups, "✅ All groups reviewed!"



def prepare_download():
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    current_df = pd.read_csv(RESULTS_FILE)
    filename = f"reviewed_groups_{timestamp}_{current_df.shape[0]}.csv"
    dst = os.path.join(PERSISTENT_DIR, filename)
    current_df.to_csv(dst, index=False)
    return dst


def get_first_group():
    global preloaded_until
    preloaded_until = PRELOAD_AHEAD - 1  # 🛠 Reset preload window

    reviewed, remaining = get_remaining_groups()
    current_index = 0

    # Guard: nothing left to review (avoids an IndexError on an empty list)
    if not remaining:
        return [], current_index, remaining, "✅ All groups reviewed!"

    # 🧠 Block until groups 0–19 are cached
    print("🧠 Preloading first batch synchronously...")
    preload_ahead_sync(remaining, current_index)

    # ✅ Immediately start async preload for 20–39
    print("🧠 Preloading *next* batch asynchronously...")
    preload_ahead_async(remaining, PRELOAD_AHEAD)

    group = remaining[current_index]
    return load_group_with_cache(group), current_index, remaining, f"Group {current_index+1} / {len(remaining)}"

# ==== Gradio UI ==== #
with gr.Blocks() as demo:
    current_index = gr.State(0)
    remaining_groups = gr.State([])
    gallery = gr.Gallery(label="Group", columns=4, height="auto")
    progress_text = gr.Markdown()

    with gr.Row():
        like = gr.Button("👍 Like")
        dislike = gr.Button("👎 Dislike")
        review_later = gr.Button("🔁 Review Later")
        download_btn = gr.Button("⬇️ Download Results")
        download_file = gr.File(label="Download CSV")

    like.click(
        fn=lambda idx, groups: review_group("like", idx, groups),
        inputs=[current_index, remaining_groups],
        outputs=[gallery, current_index, remaining_groups, progress_text]
    )
    dislike.click(
        fn=lambda idx, groups: review_group("dislike", idx, groups),
        inputs=[current_index, remaining_groups],
        outputs=[gallery, current_index, remaining_groups, progress_text]
    )
    review_later.click(
        fn=lambda idx, groups: review_group("review_later", idx, groups),
        inputs=[current_index, remaining_groups],
        outputs=[gallery, current_index, remaining_groups, progress_text]
    )
    download_btn.click(fn=prepare_download, inputs=[], outputs=[download_file])

    demo.load(fn=get_first_group, outputs=[gallery, current_index, remaining_groups, progress_text])


if __name__ == "__main__":
    demo.launch(allowed_paths=["/data"])