# Copyright (c) Meta Platforms, Inc. and affiliates.

from collections import Counter, defaultdict
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go

from .parser import (
    filter_area,
    filter_node,
    filter_way,
    match_to_group,
    parse_area,
    parse_node,
    parse_way,
    Patterns,
)
from .reader import OSMData


def recover_hierarchy(counter: Counter) -> Dict:
    """Recover a two-level hierarchy from the flat group labels.

    Labels of the form ``"prefix:group"`` are nested under ``prefix``; labels
    without a colon stay at the top level. When a bare label collides with a
    prefix (e.g. both ``"a"`` and ``"a:b"`` are present), the bare count is
    kept inside the nested dict under its own name instead of being dropped.

    Args:
        counter: Mapping from flat label to its count.

    Returns:
        A plain dict whose values are either a count (bare label) or a dict
        mapping sub-group name to count. Labels are visited by decreasing
        count, which determines the insertion order.
    """
    groups: Dict = {}
    for label, count in sorted(counter.items(), key=lambda item: -item[1]):
        # Split on the first colon only, so a label holding several colons
        # cannot break the two-way split.
        prefix, sep, group = label.partition(":")
        if not sep:
            group = prefix
        existing = groups.get(prefix)
        if existing is None:
            groups[prefix] = {group: count} if sep else count
            continue
        if not isinstance(existing, dict):
            # A bare label was stored first: promote it to a nested entry
            # under its own name so that its count is preserved.
            existing = groups[prefix] = {prefix: existing}
        # Merge rather than overwrite in case "a" and "a:a" both occur.
        existing[group] = existing.get(group, 0) + count
    return groups


def bar_autolabel(rects, fontsize):
    """Write each horizontal bar's width as text just past the end of the bar.

    Args:
        rects: Iterable of bar patches exposing ``get_width``, ``get_y`` and
            ``get_height``.
        fontsize: Font size used for the annotation text.
    """
    for bar in rects:
        length = bar.get_width()
        # Anchor the text at the tip of the bar, vertically centred on it.
        anchor = (length, bar.get_y() + bar.get_height() / 2)
        plt.gca().annotate(
            f"{length}",
            xy=anchor,
            xytext=(3, 0),  # shift the text 3 points along x, away from the bar
            textcoords="offset points",
            ha="left",
            va="center",
            fontsize=fontsize,
        )


def plot_histogram(counts, fontsize, dpi):
    """Draw grouped counts as a horizontal bar chart with a log-scaled x axis.

    Every entry of *counts* becomes one legend series: a nested dict
    contributes one bar per sub-entry, a plain count contributes a single bar.
    A new figure is created on each call; nothing is returned.

    Args:
        counts: Mapping from group name to either a count or a dict of
            sub-group name to count.
        fontsize: Font size for the y tick labels and the bar annotations.
        dpi: Resolution passed to ``plt.subplots``.
    """
    _, ax = plt.subplots(dpi=dpi, figsize=(8, 20))

    labels = []
    for series, entry in counts.items():
        if isinstance(entry, dict):
            names, values = list(entry.keys()), list(entry.values())
        else:
            names, values = [series], [entry]
        # Bars of this series occupy the rows right after those already used.
        first_row = len(labels)
        labels.extend(names)
        rows = first_row + np.arange(len(values))
        bars = plt.barh(rows, values, height=0.9, label=series)
        bar_autolabel(bars, fontsize)

    # NOTE(review): the tick labels are assigned before the tick positions are
    # fixed further down; confirm this order against the matplotlib docs.
    ax.set_yticklabels(labels, fontsize=fontsize)
    ax.axes.xaxis.set_ticklabels([])
    ax.xaxis.tick_top()
    ax.invert_yaxis()
    plt.yticks(np.arange(len(labels)))
    # NOTE(review): changing the scale after clearing the x tick labels may
    # reinstall default tick formatting - verify the labels stay hidden.
    plt.xscale("log")
    plt.legend(ncol=len(counts), loc="upper center")


def count_elements(elems: Dict[int, str], filter_fn, parse_fn) -> Dict:
    """Tally elements per group and return the counts as a nested hierarchy.

    Args:
        elems: Mapping whose values are the elements to count; each value's
            ``tags`` attribute is handed to *parse_fn*.
        filter_fn: Predicate selecting which elements are considered.
        parse_fn: Maps an element's tags to a group label, or ``None`` when
            the element should be skipped.

    Returns:
        The result of ``recover_hierarchy`` applied to the flat label counts.
    """
    kept = filter(filter_fn, elems.values())
    parsed = (parse_fn(elem.tags) for elem in kept)
    # Elements that map to no group are left out of the tally.
    tally = Counter(group for group in parsed if group is not None)
    return recover_hierarchy(tally)


def plot_osm_histograms(osm: OSMData, fontsize=8, dpi=150):
    """Plot one histogram figure each for nodes, ways and areas of *osm*.

    Areas are counted from ``osm.ways`` using the area filter and parser.

    Args:
        osm: Data exposing ``nodes`` and ``ways`` mappings.
        fontsize: Font size forwarded to ``plot_histogram``.
        dpi: Figure resolution forwarded to ``plot_histogram``.
    """
    # (figure title, attribute of `osm` to read, element filter, tag parser)
    panels = (
        ("nodes", "nodes", filter_node, parse_node),
        ("ways", "ways", filter_way, parse_way),
        ("areas", "ways", filter_area, parse_area),
    )
    for title, source, keep, parse in panels:
        # The attribute is read lazily so each figure is built in turn.
        tallies = count_elements(getattr(osm, source), keep, parse)
        plot_histogram(tallies, fontsize, dpi)
        plt.title(title)


def plot_sankey_hierarchy(osm: OSMData):
    """Show a Sankey diagram of node labels flowing key -> tag -> group.

    Every node of ``osm.nodes`` that passes ``filter_node`` and receives a
    label from ``parse_node`` contributes one ``(key, tag, group)`` triplet:

    * ``key`` / ``tag`` come from splitting the label on its first colon; a
      tag equal to ``"yes"`` is replaced by the key, and a label without a
      colon is used as both key and tag.
    * ``group`` is the match against ``Patterns.nodes``, falling back to
      ``Patterns.ways`` and finally to the literal ``"null"``.

    Args:
        osm: Data exposing a ``nodes`` mapping.

    Returns:
        The plotly figure, after ``fig.show()`` has been called on it.

    Raises:
        ValueError: If no node yields a label, so there is nothing to draw.
    """
    triplets = []
    for node in filter(filter_node, osm.nodes.values()):
        label = parse_node(node.tags)
        if label is None:
            continue
        group = match_to_group(label, Patterns.nodes)
        if group is None:
            group = match_to_group(label, Patterns.ways)
        if group is None:
            group = "null"
        if ":" in label:
            # Split on the first colon only so extra colons stay in the tag.
            key, tag = label.split(":", 1)
            if tag == "yes":
                tag = key
        else:
            key = tag = label
        triplets.append((key, tag, group))
    if not triplets:
        # Unpacking an empty zip below would fail with an opaque message.
        raise ValueError("no labeled nodes found: cannot build the Sankey diagram")
    keys, tags, groups = zip(*triplets)
    counts_key_tag = Counter(zip(keys, tags))
    counts_key_tag_group = Counter(triplets)

    key2tags = defaultdict(set)
    for k, t in zip(keys, tags):
        key2tags[k].add(t)
    key2tags = {k: sorted(t) for k, t in key2tags.items()}
    # When a (key, tag) pair occurs with several groups, the last one wins.
    keytag2group = dict(zip(zip(keys, tags), groups))
    key_names = sorted(set(keys))
    tag_names = [(k, t) for k in key_names for t in key2tags[k]]

    # Groups in order of first appearance among the sorted tags; "null" last.
    group_names = []
    for k, t in tag_names:
        g = keytag2group[k, t]
        if g not in group_names and g != "null":
            group_names.append(g)
    group_names += ["null"]

    # Sankey nodes share one index space: keys first, then tags, then groups.
    key2idx = {k: i for i, k in enumerate(key_names)}
    tag2idx = {kt: i + len(key2idx) for i, kt in enumerate(tag_names)}
    group2idx = {n: i + len(key2idx) + len(tag2idx) for i, n in enumerate(group_names)}

    key_counts = Counter(keys)
    key_text = [f"{k} {key_counts[k]}" for k in key_names]
    tag_text = [f"{t} {counts_key_tag[k, t]}" for k, t in tag_names]
    group_counts = Counter(groups)
    group_text = [f"{k} {group_counts[k]}" for k in group_names]

    fig = go.Figure(
        data=[
            go.Sankey(
                orientation="h",
                node=dict(
                    pad=15,
                    thickness=20,
                    line=dict(color="black", width=0.5),
                    label=key_text + tag_text + group_text,
                    # One x value per layer: keys, tags, groups.
                    x=[0] * len(key_names)
                    + [1] * len(tag_names)
                    + [2] * len(group_names),
                    color="blue",
                ),
                arrangement="fixed",
                # Links: key -> tag first, then tag -> group; the source,
                # target and value lists iterate the same counters in step.
                link=dict(
                    source=[key2idx[k] for k, _ in counts_key_tag]
                    + [tag2idx[k, t] for k, t, _ in counts_key_tag_group],
                    target=[tag2idx[k, t] for k, t in counts_key_tag]
                    + [group2idx[g] for _, _, g in counts_key_tag_group],
                    value=list(counts_key_tag.values())
                    + list(counts_key_tag_group.values()),
                ),
            )
        ]
    )
    fig.update_layout(autosize=False, width=800, height=2000, font_size=10)
    fig.show()
    return fig