import json
import re

import gradio as gr
import matplotlib.pyplot as plt
import requests
from requests.exceptions import HTTPError


def parse_roboflow_url(url):
    """
    Extract workspace/project and version from a Roboflow Universe URL.

    Args:
        url: a dataset URL of the form ``.../roboflow.com/<ws>/<proj>/dataset/<N>``.

    Returns:
        (workspace, project, version) as strings.

    Raises:
        ValueError: if the URL does not match the expected dataset pattern.
    """
    pattern = r"roboflow\.com/([^/]+)/([^/]+)/dataset/(\d+)"
    match = re.search(pattern, url)
    if not match:
        raise ValueError(f"Invalid Roboflow dataset URL: {url}")
    return match.groups()


def fetch_metadata(api_key, workspace, project, version):
    """
    Fetch dataset metadata from the Roboflow API.

    Returns:
        (total_images, classes) where ``classes`` maps class name -> image count.

    Raises:
        ValueError: on HTTP errors (401 gets a dedicated "check your key" message).
    """
    endpoint = f"https://api.roboflow.com/{workspace}/{project}/{version}"
    # A timeout keeps a hung connection from blocking the UI indefinitely.
    resp = requests.get(endpoint, params={"api_key": api_key}, timeout=30)
    try:
        resp.raise_for_status()
    except HTTPError as err:
        if resp.status_code == 401:
            raise ValueError("Unauthorized: check your API key.") from err
        raise ValueError(
            f"Error {resp.status_code} for {workspace}/{project}/{version}"
        ) from err
    data = resp.json()
    # Prefer the per-version image count; fall back to the project-level total.
    total = data.get("version", {}).get("images") or data.get("project", {}).get("images", 0)
    classes = data.get("project", {}).get("classes", {})
    return total, classes


def aggregate_datasets(api_key, entries):
    """
    Fetch and aggregate metadata for a list of dataset URLs.

    Args:
        api_key: Roboflow API key.
        entries: list of (url, filename, line_number) tuples; the file/line
            info is used only to contextualize error messages.

    Returns:
        total_images: summed image count across all datasets.
        class_counts: dict[lowercased class name] -> aggregated count.
        class_sources: dict[lowercased class name] -> set of source URLs.

    Raises:
        ValueError: if a URL is malformed or a metadata fetch fails.
    """
    total_images = 0
    class_counts = {}
    class_sources = {}
    for url, fname, lineno in entries:
        try:
            ws, proj, ver = parse_roboflow_url(url)
        except ValueError:
            # Re-raise with file/line context so the user can locate the bad URL.
            raise ValueError(
                f"Invalid URL '{url}' in file '{fname}', line {lineno}"
            ) from None
        imgs, cls_map = fetch_metadata(api_key, ws, proj, ver)
        total_images += imgs
        for cls, cnt in cls_map.items():
            # Normalize case/whitespace so "Car" and "car " merge into one class.
            norm = cls.strip().lower()
            class_counts[norm] = class_counts.get(norm, 0) + cnt
            class_sources.setdefault(norm, set()).add(url)
    return total_images, class_counts, class_sources


def make_bar_chart(counts):
    """
    Build a matplotlib bar chart from a {label: value} dict.

    Returns:
        A matplotlib Figure titled "Class Distribution".
    """
    fig, ax = plt.subplots()
    labels = list(counts.keys())
    values = list(counts.values())
    ax.bar(range(len(labels)), values)
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_ylabel("Image Count")
    ax.set_title("Class Distribution")
    fig.tight_layout()
    return fig


def _file_name(fobj):
    """Best-effort display name for an uploaded file object, dict, or path string."""
    name = getattr(fobj, "name", None)
    if name:
        return name
    if isinstance(fobj, dict):
        return fobj.get("name", "unknown")
    if isinstance(fobj, str):
        # Gradio may hand us a plain temp-file path; use it as the name.
        return fobj
    return "unknown"


def _read_upload(fobj):
    """
    Return the raw contents of an upload (bytes or str), or None if unreadable.

    Handles the three shapes Gradio may deliver: a file-like object, a dict
    with a "data" key, or a temp-file path string.
    """
    try:
        return fobj.read()
    except Exception:
        # Not file-like (or read failed) — fall through to the other shapes.
        pass
    if isinstance(fobj, dict):
        return fobj.get("data")
    if isinstance(fobj, str):
        with open(fobj, "rb") as fh:
            return fh.read()
    return None


def load_datasets(api_key, file_objs):
    """
    Load, dedupe, and aggregate dataset URLs from uploaded .txt files.

    1) Ensure the API key is present.
    2) Read & dedupe URLs (one per line) from each uploaded file.
    3) Fetch & aggregate metadata.

    Returns:
        total, table_data, figure, json_counts, markdown_sources

    Raises:
        ValueError: if the API key is missing or a URL is invalid.
    """
    if not api_key or not api_key.strip():
        raise ValueError("Please enter your Roboflow API Key before loading datasets.")

    entries = []
    seen = set()
    for fobj in file_objs:
        fname = _file_name(fobj)
        raw = _read_upload(fobj)
        if raw is None:
            # Unreadable upload shape: skip rather than crash on None.splitlines().
            continue
        text = raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else raw
        for i, line in enumerate(text.splitlines(), start=1):
            url = line.strip()
            if url and url not in seen:
                seen.add(url)
                entries.append((url, fname, i))

    total, counts, sources = aggregate_datasets(api_key, entries)

    # Dataframe rows: one [class, count] pair per class.
    table_data = [[cls, counts[cls]] for cls in counts]

    # Clickable per-class markdown linking each class back to its source datasets.
    md_lines = []
    for cls in counts:
        links = ", ".join(f"[{url.split('/')[-1]}]({url})" for url in sources[cls])
        md_lines.append(f"- **{cls}** ({counts[cls]} images): {links}")
    md_sources = "\n".join(md_lines)

    fig = make_bar_chart(counts)
    return str(total), table_data, fig, json.dumps(counts, indent=2), md_sources


def update_classes(df_data):
    """
    Recompute all outputs after the user edits the class table.

    Converts ``df_data`` into a list-of-lists (if needed), merges duplicate /
    differently-cased class names, and recalculates every output.

    Args:
        df_data: the edited table as a list of rows, a pandas DataFrame,
            or a NumPy array.

    Returns:
        total, updated_table, figure, json_counts, markdown_summary
    """
    # Normalize DataFrame / ndarray input to a plain list of rows.
    if not isinstance(df_data, list):
        if hasattr(df_data, "to_numpy"):
            df_data = df_data.to_numpy().tolist()
        elif hasattr(df_data, "tolist"):
            df_data = df_data.tolist()

    combined = {}
    for row in df_data:
        if len(row) < 2:
            continue
        name, cnt = row[0], row[1]
        if not name:
            continue
        key = str(name).strip().lower()
        try:
            val = int(cnt)
        except (TypeError, ValueError):
            # Non-numeric edits count as zero rather than aborting the update.
            val = 0
        combined[key] = combined.get(key, 0) + val

    total = sum(combined.values())
    updated_table = [[k, combined[k]] for k in combined]
    fig = make_bar_chart(combined)
    md_summary = "\n".join(f"- **{k}** ({combined[k]} images)" for k in combined)
    return str(total), updated_table, fig, json.dumps(combined, indent=2), md_summary


def build_ui():
    """Assemble and return the Gradio Blocks interface."""
    with gr.Blocks() as demo:
        gr.Markdown("## Roboflow Dataset Inspector")
        with gr.Row():
            api_input = gr.Textbox(label="API Key", type="password")
            files = gr.Files(label="Upload .txt files", file_types=[".txt"])
        load_btn = gr.Button("Load Datasets")
        total_out = gr.Textbox(label="Total Images", interactive=False)
        df = gr.Dataframe(
            headers=["Class Name", "Count"],
            row_count=(1, None),
            col_count=2,
            interactive=True
        )
        plot = gr.Plot()
        json_out = gr.Textbox(label="Counts (JSON)", interactive=False)
        md_out = gr.Markdown(label="Class Sources")
        update_btn = gr.Button("Apply Class Edits")

        load_btn.click(
            fn=load_datasets,
            inputs=[api_input, files],
            outputs=[total_out, df, plot, json_out, md_out]
        )
        update_btn.click(
            fn=update_classes,
            inputs=[df],
            outputs=[total_out, df, plot, json_out, md_out]
        )
    return demo


if __name__ == "__main__":
    build_ui().launch()