wuhp committed on
Commit
02f8610
·
verified ·
1 Parent(s): 5067213

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -32
app.py CHANGED
@@ -9,7 +9,6 @@ from requests.exceptions import HTTPError
9
  def parse_roboflow_url(url):
10
  """
11
  Extract workspace/project and version from a Roboflow Universe URL.
12
- Example: https://universe.roboflow.com/.../dataset/6
13
  Returns (workspace, project, version)
14
  """
15
  pattern = r"roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)"
@@ -21,18 +20,17 @@ def parse_roboflow_url(url):
21
 
22
  def fetch_metadata(api_key, workspace, project, version):
23
  """
24
- Fetch metadata for a given project version from Roboflow API.
25
- Raises ValueError on HTTP errors.
26
  """
27
  endpoint = f"https://api.roboflow.com/{workspace}/{project}/{version}"
 
28
  try:
29
- resp = requests.get(endpoint, params={"api_key": api_key})
30
  resp.raise_for_status()
31
  except HTTPError:
32
  if resp.status_code == 401:
33
  raise ValueError("Unauthorized: check your API key.")
34
  else:
35
- raise ValueError(f"Error fetching {workspace}/{project}/{version}: {resp.status_code}")
36
  data = resp.json()
37
  total = data.get("version", {}).get("images") or data.get("project", {}).get("images", 0)
38
  classes = data.get("project", {}).get("classes", {})
@@ -41,15 +39,19 @@ def fetch_metadata(api_key, workspace, project, version):
41
 
42
  def aggregate_datasets(api_key, entries):
43
  """
44
- Given API key and list of (url, file, line) tuples,
45
- returns total_images, aggregated lowercase class counts,
46
- and per-class source URLs.
 
47
  """
48
  total_images = 0
49
  class_counts = {}
50
  class_sources = {}
51
  for url, fname, lineno in entries:
52
- ws, proj, ver = parse_roboflow_url(url)
 
 
 
53
  imgs, cls_map = fetch_metadata(api_key, ws, proj, ver)
54
  total_images += imgs
55
  for cls, cnt in cls_map.items():
@@ -61,14 +63,14 @@ def aggregate_datasets(api_key, entries):
61
 
62
  def make_bar_chart(counts):
63
  """
64
- Return a matplotlib figure showing a bar chart of counts dict.
65
  """
66
  fig, ax = plt.subplots()
67
- keys = list(counts.keys())
68
- vals = list(counts.values())
69
- ax.bar(range(len(keys)), vals)
70
- ax.set_xticks(range(len(keys)))
71
- ax.set_xticklabels(keys, rotation=45, ha="right")
72
  ax.set_ylabel("Image Count")
73
  ax.set_title("Class Distribution")
74
  fig.tight_layout()
@@ -77,14 +79,19 @@ def make_bar_chart(counts):
77
 
78
  def load_datasets(api_key, file_objs):
79
  """
80
- Read uploaded .txt files, dedupe URLs, fetch metadata,
81
- and return all outputs for the UI.
 
 
82
  """
 
 
 
83
  entries = []
84
  seen = set()
85
  for fobj in file_objs:
86
  fname = getattr(fobj, "name", None) or fobj.get("name", "unknown")
87
- # read raw content
88
  try:
89
  raw = fobj.read()
90
  except:
@@ -93,6 +100,7 @@ def load_datasets(api_key, file_objs):
93
  with open(fobj, "rb") as fh:
94
  raw = fh.read()
95
  text = raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else raw
 
96
  for i, line in enumerate(text.splitlines(), start=1):
97
  url = line.strip()
98
  if url and url not in seen:
@@ -101,30 +109,41 @@ def load_datasets(api_key, file_objs):
101
 
102
  total, counts, sources = aggregate_datasets(api_key, entries)
103
 
104
- # build dataframe list
105
- df_data = [[cls, counts[cls]] for cls in counts]
106
 
107
- # build markdown of sources
108
  md_lines = []
109
  for cls in counts:
110
- links = ", ".join(f"[{s.split('/')[-1]}]({s})" for s in sources[cls])
111
  md_lines.append(f"- **{cls}** ({counts[cls]} images): {links}")
112
  md_sources = "\n".join(md_lines)
113
 
114
  fig = make_bar_chart(counts)
115
- return str(total), df_data, fig, json.dumps(counts, indent=2), md_sources
116
 
117
 
118
  def update_classes(df_data):
119
  """
120
- Take the edited table rows, merge duplicates (lowercase),
121
- and return updated total, df, chart, JSON, and markdown.
 
122
  """
 
 
 
 
 
 
 
123
  combined = {}
124
- for name, cnt in df_data:
 
 
 
125
  if not name:
126
  continue
127
- key = name.strip().lower()
128
  try:
129
  val = int(cnt)
130
  except:
@@ -132,10 +151,11 @@ def update_classes(df_data):
132
  combined[key] = combined.get(key, 0) + val
133
 
134
  total = sum(combined.values())
135
- updated_df = [[k, combined[k]] for k in combined]
136
  fig = make_bar_chart(combined)
137
  md_summary = "\n".join(f"- **{k}** ({combined[k]} images)" for k in combined)
138
- return str(total), updated_df, fig, json.dumps(combined, indent=2), md_summary
 
139
 
140
 
141
  def build_ui():
@@ -149,7 +169,10 @@ def build_ui():
149
  load_btn = gr.Button("Load Datasets")
150
  total_out = gr.Textbox(label="Total Images", interactive=False)
151
  df = gr.Dataframe(
152
- headers=["Class Name", "Count"], row_count=(1, None), col_count=2, interactive=True
 
 
 
153
  )
154
  plot = gr.Plot()
155
  json_out = gr.Textbox(label="Counts (JSON)", interactive=False)
@@ -160,12 +183,12 @@ def build_ui():
160
  load_btn.click(
161
  fn=load_datasets,
162
  inputs=[api_input, files],
163
- outputs=[total_out, df, plot, json_out, md_out],
164
  )
165
  update_btn.click(
166
  fn=update_classes,
167
  inputs=[df],
168
- outputs=[total_out, df, plot, json_out, md_out],
169
  )
170
 
171
  return demo
 
9
  def parse_roboflow_url(url):
10
  """
11
  Extract workspace/project and version from a Roboflow Universe URL.
 
12
  Returns (workspace, project, version)
13
  """
14
  pattern = r"roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)"
 
20
 
21
def fetch_metadata(api_key, workspace, project, version):
    """
    Fetch metadata for one project version from the Roboflow REST API.

    Args:
        api_key: Roboflow API key, sent as a query parameter.
        workspace: Workspace slug parsed from the Universe URL.
        project: Project slug parsed from the Universe URL.
        version: Dataset version number.

    Returns:
        (total_images, classes) where total_images is the version-level
        image count (falling back to the project-level count) and classes
        is the project's {class_name: count} mapping.

    Raises:
        ValueError: on any HTTP error; 401 gets a dedicated message.
    """
    endpoint = f"https://api.roboflow.com/{workspace}/{project}/{version}"
    # timeout keeps a stalled connection from hanging the UI forever
    resp = requests.get(endpoint, params={"api_key": api_key}, timeout=30)
    try:
        resp.raise_for_status()
    except HTTPError:
        if resp.status_code == 401:
            raise ValueError("Unauthorized: check your API key.")
        else:
            raise ValueError(f"Error {resp.status_code} for {workspace}/{project}/{version}")
    data = resp.json()
    # prefer the version-level image count; fall back to the project total
    total = data.get("version", {}).get("images") or data.get("project", {}).get("images", 0)
    classes = data.get("project", {}).get("classes", {})
    # NOTE(review): the final return is elided in this diff view; callers
    # unpack `imgs, cls_map = fetch_metadata(...)`, so (total, classes) is
    # the reconstructed contract — confirm against the full file.
    return total, classes
 
39
 
40
def aggregate_datasets(api_key, entries):
    """
    Aggregate image and class statistics across several Roboflow datasets.

    Args:
        api_key: Roboflow API key forwarded to fetch_metadata.
        entries: iterable of (url, filename, line_number) tuples.

    Returns:
        total_images: sum of image counts over all datasets.
        class_counts: {lowercased class name: aggregated count}.
        class_sources: {lowercased class name: set of source URLs}.

    Raises:
        ValueError: for an unparseable URL (annotated with the originating
        file and line number), or propagated from fetch_metadata.
    """
    total_images = 0
    class_counts = {}
    class_sources = {}
    for url, fname, lineno in entries:
        try:
            ws, proj, ver = parse_roboflow_url(url)
        except ValueError:
            raise ValueError(f"Invalid URL '{url}' in file '{fname}', line {lineno}")
        imgs, cls_map = fetch_metadata(api_key, ws, proj, ver)
        total_images += imgs
        # NOTE(review): the accumulation below is elided in this diff view;
        # it is reconstructed from the docstring contract (lowercase keys,
        # sets of source URLs) — confirm against the full file.
        for cls, cnt in cls_map.items():
            key = cls.lower()
            class_counts[key] = class_counts.get(key, 0) + cnt
            class_sources.setdefault(key, set()).add(url)
    return total_images, class_counts, class_sources
 
63
 
64
def make_bar_chart(counts):
    """
    Render a {label: value} dict as a matplotlib bar chart.

    Labels go on the x-axis (rotated for readability) and values become
    bar heights. Returns the Figure so the caller can hand it to gr.Plot.
    """
    fig, ax = plt.subplots()
    names = list(counts)
    positions = range(len(names))
    ax.bar(positions, [counts[name] for name in names])
    ax.set_xticks(positions)
    ax.set_xticklabels(names, rotation=45, ha="right")
    ax.set_ylabel("Image Count")
    ax.set_title("Class Distribution")
    fig.tight_layout()
    return fig
 
79
 
80
def load_datasets(api_key, file_objs):
    """
    Load Roboflow URLs from uploaded .txt files and aggregate their metadata.

    1) Ensure the API key is present.
    2) Read & dedupe URLs from each uploaded .txt (one URL per line).
    3) Fetch & aggregate metadata via aggregate_datasets.

    Args:
        api_key: Roboflow API key; must be non-blank.
        file_objs: uploaded file objects (file-like objects or plain paths).

    Returns:
        (total_str, table_data, figure, json_counts, markdown_sources)
        matching the Gradio outputs [total_out, df, plot, json_out, md_out].

    Raises:
        ValueError: if the API key is missing/blank, or propagated from
        aggregate_datasets on bad URLs / HTTP failures.
    """
    if not api_key or not api_key.strip():
        raise ValueError("Please enter your Roboflow API Key before loading datasets.")

    entries = []
    seen = set()
    for fobj in file_objs:
        fname = getattr(fobj, "name", None) or fobj.get("name", "unknown")
        # Read raw content: prefer the file-like interface, fall back to
        # treating the object as a filesystem path. Was a bare `except:`;
        # narrowed so KeyboardInterrupt/SystemExit still propagate.
        # NOTE(review): the original fallback body is partly elided in this
        # diff view (a comment mentions "dict-data") — confirm against the
        # full file that no dict-shaped upload case was dropped.
        try:
            raw = fobj.read()
        except (AttributeError, OSError, TypeError):
            with open(fobj, "rb") as fh:
                raw = fh.read()
        text = raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else raw

        for i, line in enumerate(text.splitlines(), start=1):
            url = line.strip()
            if url and url not in seen:
                # NOTE(review): these two lines are elided in the diff view;
                # reconstructed from the dedupe contract and the (url, file,
                # line) tuples aggregate_datasets unpacks — confirm.
                seen.add(url)
                entries.append((url, fname, i))

    total, counts, sources = aggregate_datasets(api_key, entries)

    # build dataframe rows
    table_data = [[cls, counts[cls]] for cls in counts]

    # build clickable markdown per-class
    md_lines = []
    for cls in counts:
        links = ", ".join(f"[{url.split('/')[-1]}]({url})" for url in sources[cls])
        md_lines.append(f"- **{cls}** ({counts[cls]} images): {links}")
    md_sources = "\n".join(md_lines)

    fig = make_bar_chart(counts)
    return str(total), table_data, fig, json.dumps(counts, indent=2), md_sources
124
 
125
 
126
def update_classes(df_data):
    """
    Re-aggregate the user-edited class table.

    Accepts the Gradio Dataframe value (pandas DataFrame, NumPy array, or
    a list of [name, count] rows), merges rows whose names match
    case-insensitively, and rebuilds every output.

    Args:
        df_data: edited table rows in any of the forms above.

    Returns:
        (total_str, updated_table, figure, json_counts, markdown_summary)
        matching the Gradio outputs [total_out, df, plot, json_out, md_out].
    """
    # Normalize pandas / NumPy values into a plain list of rows.
    if not isinstance(df_data, list):
        if hasattr(df_data, "to_numpy"):
            df_data = df_data.to_numpy().tolist()
        elif hasattr(df_data, "tolist"):
            df_data = df_data.tolist()

    combined = {}
    for row in df_data:
        if len(row) < 2:
            continue
        name, cnt = row[0], row[1]
        if not name:
            continue
        key = str(name).strip().lower()
        # Was a bare `except:`, which also swallows KeyboardInterrupt and
        # SystemExit; narrowed to the errors int() actually raises.
        # NOTE(review): the original except body is elided in this diff view;
        # skipping unparseable counts is the reconstructed behavior — confirm.
        try:
            val = int(cnt)
        except (TypeError, ValueError):
            continue
        combined[key] = combined.get(key, 0) + val

    total = sum(combined.values())
    updated_table = [[k, combined[k]] for k in combined]
    fig = make_bar_chart(combined)
    md_summary = "\n".join(f"- **{k}** ({combined[k]} images)" for k in combined)

    return str(total), updated_table, fig, json.dumps(combined, indent=2), md_summary
159
 
160
 
161
  def build_ui():
 
169
  load_btn = gr.Button("Load Datasets")
170
  total_out = gr.Textbox(label="Total Images", interactive=False)
171
  df = gr.Dataframe(
172
+ headers=["Class Name", "Count"],
173
+ row_count=(1, None),
174
+ col_count=2,
175
+ interactive=True
176
  )
177
  plot = gr.Plot()
178
  json_out = gr.Textbox(label="Counts (JSON)", interactive=False)
 
183
  load_btn.click(
184
  fn=load_datasets,
185
  inputs=[api_input, files],
186
+ outputs=[total_out, df, plot, json_out, md_out]
187
  )
188
  update_btn.click(
189
  fn=update_classes,
190
  inputs=[df],
191
+ outputs=[total_out, df, plot, json_out, md_out]
192
  )
193
 
194
  return demo