Spaces:
Sleeping
Sleeping
| import plotly | |
| import pandas as pd | |
| from umap import UMAP | |
| import plotly.express as px | |
| from bertopic import BERTopic | |
| from typing import Dict, List | |
def custom_topic_barchart(model: BERTopic, topic_labels: Dict[int, str], top_n_topics: int = 10, n_words: int = 10) -> plotly.graph_objs.Figure:
    """
    Create a custom horizontal bar chart of top topics using plotly.express.

    Args:
        model: Fitted topic model; ``model.get_topic(topic_id)`` is queried
            for each labelled topic and is expected to return a list of
            ``(word, score)`` pairs. Any other result is skipped.
        topic_labels: Mapping of topic id to the label shown in the legend.
            Topic ``-1`` is always skipped.
        top_n_topics: Maximum number of topics to chart. Topics are taken in
            the iteration order of ``topic_labels``; only topics that yield
            at least one valid word/score pair count towards the limit.
            ``None`` disables the limit.
        n_words: Maximum number of words charted per topic.

    Returns:
        A grouped horizontal bar chart (one colour per topic), or an empty
        ``Figure`` when there is nothing to plot.
    """
    rows = []
    topics_shown = 0
    for topic_id, label in topic_labels.items():
        # Honour the requested topic limit (previously accepted but ignored).
        if top_n_topics is not None and topics_shown >= top_n_topics:
            break
        if topic_id == -1:
            continue

        topic = model.get_topic(topic_id)
        if not isinstance(topic, list) or not topic:
            continue

        # Keep only well-formed (word, score) pairs.
        topic_rows = [
            {"Topic": label, "Word": pair[0], "Score": pair[1]}
            for pair in topic[:n_words]
            if isinstance(pair, (list, tuple)) and len(pair) == 2
        ]
        if topic_rows:
            rows.extend(topic_rows)
            topics_shown += 1

    if not rows:
        print("[WARN] No topic-word-score data to visualize.")
        return plotly.graph_objs.Figure()

    # Every row carries the "Topic", "Word" and "Score" keys, so the frame
    # always has the columns referenced below.
    df = pd.DataFrame(rows)

    fig = px.bar(
        df,
        x="Score",
        y="Word",
        color="Topic",
        orientation="h",
        barmode="group",
    )
    fig.update_layout(
        margin=dict(l=40, r=20, t=40, b=20),
        yaxis=dict(title=""),
        xaxis=dict(title="Relevance"),
        legend_title_text="Topic",
    )
    return fig
def custom_umap_plot(embeddings: List[List[float]], topics: List[int], topic_labels: Dict[int, str]) -> plotly.graph_objs.Figure:
    """
    Custom UMAP plotting to work better with the Gradio layout.

    Args:
        embeddings: One embedding vector per document.
        topics: Topic id assigned to each document, aligned with
            ``embeddings``.
        topic_labels: Mapping of topic id to the label shown in the legend.
            Topics without an entry fall back to ``"Topic <id>"``.

    Returns:
        A 2-D scatter plot of the UMAP projection coloured by topic label,
        with topic ``-1`` removed, or an empty ``Figure`` when no embeddings
        are given.

    Raises:
        ValueError: If ``topics`` and ``embeddings`` differ in length.
    """
    n_samples = len(embeddings)
    if n_samples == 0:
        print("[WARN] No embeddings to visualize.")
        return plotly.graph_objs.Figure()
    if len(topics) != n_samples:
        # Fail fast, before the expensive UMAP fit, rather than when the
        # topic column is attached to the frame.
        raise ValueError(
            f"topics has {len(topics)} entries but embeddings has {n_samples}"
        )

    # Compute a safe number of neighbours: cap at 15 and at n_samples - 1,
    # but never request fewer than 2.
    safe_n_neighbours = min(15, max(2, n_samples - 1))
    reducer = UMAP(
        n_neighbors=safe_n_neighbours,
        min_dist=0.1,
        metric="cosine",
        random_state=42,  # fixed seed for a reproducible layout
    )
    umap_coords = reducer.fit_transform(embeddings)

    df = pd.DataFrame(umap_coords, columns=["x", "y"])
    df["topic"] = topics

    # Filter out topic -1 (noise) before resolving labels, so a missing
    # label for the discarded topic cannot raise KeyError.
    df = df[df["topic"] != -1].copy()
    df["label"] = [topic_labels.get(t, f"Topic {t}") for t in df["topic"]]

    fig = px.scatter(
        df,
        x="x",
        y="y",
        color="label",
        labels={"label": "Topic"},
    )
    fig.update_layout(margin=dict(l=20, r=20, t=40, b=20))
    return fig