Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,6 @@ from requests.exceptions import HTTPError
|
|
9 |
def parse_roboflow_url(url):
|
10 |
"""
|
11 |
Extract workspace/project and version from a Roboflow Universe URL.
|
12 |
-
Example: https://universe.roboflow.com/.../dataset/6
|
13 |
Returns (workspace, project, version)
|
14 |
"""
|
15 |
pattern = r"roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)"
|
@@ -21,18 +20,17 @@ def parse_roboflow_url(url):
|
|
21 |
|
22 |
def fetch_metadata(api_key, workspace, project, version):
|
23 |
"""
|
24 |
-
Fetch metadata
|
25 |
-
Raises ValueError on HTTP errors.
|
26 |
"""
|
27 |
endpoint = f"https://api.roboflow.com/{workspace}/{project}/{version}"
|
|
|
28 |
try:
|
29 |
-
resp = requests.get(endpoint, params={"api_key": api_key})
|
30 |
resp.raise_for_status()
|
31 |
except HTTPError:
|
32 |
if resp.status_code == 401:
|
33 |
raise ValueError("Unauthorized: check your API key.")
|
34 |
else:
|
35 |
-
raise ValueError(f"Error
|
36 |
data = resp.json()
|
37 |
total = data.get("version", {}).get("images") or data.get("project", {}).get("images", 0)
|
38 |
classes = data.get("project", {}).get("classes", {})
|
@@ -41,15 +39,19 @@ def fetch_metadata(api_key, workspace, project, version):
|
|
41 |
|
42 |
def aggregate_datasets(api_key, entries):
|
43 |
"""
|
44 |
-
Given
|
45 |
-
|
46 |
-
|
|
|
47 |
"""
|
48 |
total_images = 0
|
49 |
class_counts = {}
|
50 |
class_sources = {}
|
51 |
for url, fname, lineno in entries:
|
52 |
-
|
|
|
|
|
|
|
53 |
imgs, cls_map = fetch_metadata(api_key, ws, proj, ver)
|
54 |
total_images += imgs
|
55 |
for cls, cnt in cls_map.items():
|
@@ -61,14 +63,14 @@ def aggregate_datasets(api_key, entries):
|
|
61 |
|
62 |
def make_bar_chart(counts):
|
63 |
"""
|
64 |
-
|
65 |
"""
|
66 |
fig, ax = plt.subplots()
|
67 |
-
|
68 |
-
|
69 |
-
ax.bar(range(len(
|
70 |
-
ax.set_xticks(range(len(
|
71 |
-
ax.set_xticklabels(
|
72 |
ax.set_ylabel("Image Count")
|
73 |
ax.set_title("Class Distribution")
|
74 |
fig.tight_layout()
|
@@ -77,14 +79,19 @@ def make_bar_chart(counts):
|
|
77 |
|
78 |
def load_datasets(api_key, file_objs):
|
79 |
"""
|
80 |
-
|
81 |
-
|
|
|
|
|
82 |
"""
|
|
|
|
|
|
|
83 |
entries = []
|
84 |
seen = set()
|
85 |
for fobj in file_objs:
|
86 |
fname = getattr(fobj, "name", None) or fobj.get("name", "unknown")
|
87 |
-
# read raw
|
88 |
try:
|
89 |
raw = fobj.read()
|
90 |
except:
|
@@ -93,6 +100,7 @@ def load_datasets(api_key, file_objs):
|
|
93 |
with open(fobj, "rb") as fh:
|
94 |
raw = fh.read()
|
95 |
text = raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else raw
|
|
|
96 |
for i, line in enumerate(text.splitlines(), start=1):
|
97 |
url = line.strip()
|
98 |
if url and url not in seen:
|
@@ -101,30 +109,41 @@ def load_datasets(api_key, file_objs):
|
|
101 |
|
102 |
total, counts, sources = aggregate_datasets(api_key, entries)
|
103 |
|
104 |
-
# build dataframe
|
105 |
-
|
106 |
|
107 |
-
# build markdown
|
108 |
md_lines = []
|
109 |
for cls in counts:
|
110 |
-
links = ", ".join(f"[{
|
111 |
md_lines.append(f"- **{cls}** ({counts[cls]} images): {links}")
|
112 |
md_sources = "\n".join(md_lines)
|
113 |
|
114 |
fig = make_bar_chart(counts)
|
115 |
-
return str(total),
|
116 |
|
117 |
|
118 |
def update_classes(df_data):
|
119 |
"""
|
120 |
-
|
121 |
-
|
|
|
122 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
combined = {}
|
124 |
-
for
|
|
|
|
|
|
|
125 |
if not name:
|
126 |
continue
|
127 |
-
key = name.strip().lower()
|
128 |
try:
|
129 |
val = int(cnt)
|
130 |
except:
|
@@ -132,10 +151,11 @@ def update_classes(df_data):
|
|
132 |
combined[key] = combined.get(key, 0) + val
|
133 |
|
134 |
total = sum(combined.values())
|
135 |
-
|
136 |
fig = make_bar_chart(combined)
|
137 |
md_summary = "\n".join(f"- **{k}** ({combined[k]} images)" for k in combined)
|
138 |
-
|
|
|
139 |
|
140 |
|
141 |
def build_ui():
|
@@ -149,7 +169,10 @@ def build_ui():
|
|
149 |
load_btn = gr.Button("Load Datasets")
|
150 |
total_out = gr.Textbox(label="Total Images", interactive=False)
|
151 |
df = gr.Dataframe(
|
152 |
-
headers=["Class Name", "Count"],
|
|
|
|
|
|
|
153 |
)
|
154 |
plot = gr.Plot()
|
155 |
json_out = gr.Textbox(label="Counts (JSON)", interactive=False)
|
@@ -160,12 +183,12 @@ def build_ui():
|
|
160 |
load_btn.click(
|
161 |
fn=load_datasets,
|
162 |
inputs=[api_input, files],
|
163 |
-
outputs=[total_out, df, plot, json_out, md_out]
|
164 |
)
|
165 |
update_btn.click(
|
166 |
fn=update_classes,
|
167 |
inputs=[df],
|
168 |
-
outputs=[total_out, df, plot, json_out, md_out]
|
169 |
)
|
170 |
|
171 |
return demo
|
|
|
9 |
def parse_roboflow_url(url):
|
10 |
"""
|
11 |
Extract workspace/project and version from a Roboflow Universe URL.
|
|
|
12 |
Returns (workspace, project, version)
|
13 |
"""
|
14 |
pattern = r"roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)"
|
|
|
20 |
|
21 |
def fetch_metadata(api_key, workspace, project, version):
|
22 |
"""
|
23 |
+
Fetch metadata from Roboflow. Raises ValueError on HTTP errors.
|
|
|
24 |
"""
|
25 |
endpoint = f"https://api.roboflow.com/{workspace}/{project}/{version}"
|
26 |
+
resp = requests.get(endpoint, params={"api_key": api_key})
|
27 |
try:
|
|
|
28 |
resp.raise_for_status()
|
29 |
except HTTPError:
|
30 |
if resp.status_code == 401:
|
31 |
raise ValueError("Unauthorized: check your API key.")
|
32 |
else:
|
33 |
+
raise ValueError(f"Error {resp.status_code} for {workspace}/{project}/{version}")
|
34 |
data = resp.json()
|
35 |
total = data.get("version", {}).get("images") or data.get("project", {}).get("images", 0)
|
36 |
classes = data.get("project", {}).get("classes", {})
|
|
|
39 |
|
40 |
def aggregate_datasets(api_key, entries):
|
41 |
"""
|
42 |
+
Given list of (url, file, line), returns:
|
43 |
+
- total_images
|
44 |
+
- dict[class_name_lowercase] = aggregated count
|
45 |
+
- dict[class_name_lowercase] = set(source URLs)
|
46 |
"""
|
47 |
total_images = 0
|
48 |
class_counts = {}
|
49 |
class_sources = {}
|
50 |
for url, fname, lineno in entries:
|
51 |
+
try:
|
52 |
+
ws, proj, ver = parse_roboflow_url(url)
|
53 |
+
except ValueError:
|
54 |
+
raise ValueError(f"Invalid URL '{url}' in file '{fname}', line {lineno}")
|
55 |
imgs, cls_map = fetch_metadata(api_key, ws, proj, ver)
|
56 |
total_images += imgs
|
57 |
for cls, cnt in cls_map.items():
|
|
|
63 |
|
64 |
def make_bar_chart(counts):
|
65 |
"""
|
66 |
+
Build a bar chart from a {label: value} dict.
|
67 |
"""
|
68 |
fig, ax = plt.subplots()
|
69 |
+
labels = list(counts.keys())
|
70 |
+
values = list(counts.values())
|
71 |
+
ax.bar(range(len(labels)), values)
|
72 |
+
ax.set_xticks(range(len(labels)))
|
73 |
+
ax.set_xticklabels(labels, rotation=45, ha="right")
|
74 |
ax.set_ylabel("Image Count")
|
75 |
ax.set_title("Class Distribution")
|
76 |
fig.tight_layout()
|
|
|
79 |
|
80 |
def load_datasets(api_key, file_objs):
|
81 |
"""
|
82 |
+
1) Ensure API key present
|
83 |
+
2) Read & dedupe URLs from each uploaded .txt
|
84 |
+
3) Fetch & aggregate metadata
|
85 |
+
Returns: total, table_data, figure, json_counts, markdown_sources
|
86 |
"""
|
87 |
+
if not api_key or not api_key.strip():
|
88 |
+
raise ValueError("Please enter your Roboflow API Key before loading datasets.")
|
89 |
+
|
90 |
entries = []
|
91 |
seen = set()
|
92 |
for fobj in file_objs:
|
93 |
fname = getattr(fobj, "name", None) or fobj.get("name", "unknown")
|
94 |
+
# read raw bytes or dict-data or file path
|
95 |
try:
|
96 |
raw = fobj.read()
|
97 |
except:
|
|
|
100 |
with open(fobj, "rb") as fh:
|
101 |
raw = fh.read()
|
102 |
text = raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else raw
|
103 |
+
|
104 |
for i, line in enumerate(text.splitlines(), start=1):
|
105 |
url = line.strip()
|
106 |
if url and url not in seen:
|
|
|
109 |
|
110 |
total, counts, sources = aggregate_datasets(api_key, entries)
|
111 |
|
112 |
+
# build dataframe rows
|
113 |
+
table_data = [[cls, counts[cls]] for cls in counts]
|
114 |
|
115 |
+
# build clickable markdown per-class
|
116 |
md_lines = []
|
117 |
for cls in counts:
|
118 |
+
links = ", ".join(f"[{url.split('/')[-1]}]({url})" for url in sources[cls])
|
119 |
md_lines.append(f"- **{cls}** ({counts[cls]} images): {links}")
|
120 |
md_sources = "\n".join(md_lines)
|
121 |
|
122 |
fig = make_bar_chart(counts)
|
123 |
+
return str(total), table_data, fig, json.dumps(counts, indent=2), md_sources
|
124 |
|
125 |
|
126 |
def update_classes(df_data):
|
127 |
"""
|
128 |
+
Convert df_data into a list-of-lists (if needed),
|
129 |
+
merge duplicate/lowercased class names, and recalc all outputs.
|
130 |
+
Returns: total, updated_table, figure, json_counts, markdown_summary
|
131 |
"""
|
132 |
+
# convert Pandas DataFrame or NumPy array into list-of-lists
|
133 |
+
if not isinstance(df_data, list):
|
134 |
+
if hasattr(df_data, "to_numpy"):
|
135 |
+
df_data = df_data.to_numpy().tolist()
|
136 |
+
elif hasattr(df_data, "tolist"):
|
137 |
+
df_data = df_data.tolist()
|
138 |
+
|
139 |
combined = {}
|
140 |
+
for row in df_data:
|
141 |
+
if len(row) < 2:
|
142 |
+
continue
|
143 |
+
name, cnt = row[0], row[1]
|
144 |
if not name:
|
145 |
continue
|
146 |
+
key = str(name).strip().lower()
|
147 |
try:
|
148 |
val = int(cnt)
|
149 |
except:
|
|
|
151 |
combined[key] = combined.get(key, 0) + val
|
152 |
|
153 |
total = sum(combined.values())
|
154 |
+
updated_table = [[k, combined[k]] for k in combined]
|
155 |
fig = make_bar_chart(combined)
|
156 |
md_summary = "\n".join(f"- **{k}** ({combined[k]} images)" for k in combined)
|
157 |
+
|
158 |
+
return str(total), updated_table, fig, json.dumps(combined, indent=2), md_summary
|
159 |
|
160 |
|
161 |
def build_ui():
|
|
|
169 |
load_btn = gr.Button("Load Datasets")
|
170 |
total_out = gr.Textbox(label="Total Images", interactive=False)
|
171 |
df = gr.Dataframe(
|
172 |
+
headers=["Class Name", "Count"],
|
173 |
+
row_count=(1, None),
|
174 |
+
col_count=2,
|
175 |
+
interactive=True
|
176 |
)
|
177 |
plot = gr.Plot()
|
178 |
json_out = gr.Textbox(label="Counts (JSON)", interactive=False)
|
|
|
183 |
load_btn.click(
|
184 |
fn=load_datasets,
|
185 |
inputs=[api_input, files],
|
186 |
+
outputs=[total_out, df, plot, json_out, md_out]
|
187 |
)
|
188 |
update_btn.click(
|
189 |
fn=update_classes,
|
190 |
inputs=[df],
|
191 |
+
outputs=[total_out, df, plot, json_out, md_out]
|
192 |
)
|
193 |
|
194 |
return demo
|