# Copyright (c) Together
# This software is distributed under the terms of the Apache License, Version 2.0
# Author: Michael Poli

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import conv1d_cpp
except:
    pass
from .utils import column_split


def canonicalize_modal_system(poles, residues):
    """Canonicalize a modal system.

    Args:
        poles (Tensor): The poles of the system.
        residues (Tensor): The residues of the system.

    Returns:
        Tuple[Tensor, Tensor]: The canonicalized poles and residues.
    """
    raise NotImplementedError


IIR_PREFILL_MODES = [
    "recurrence",
    "modal-fft",
    "hybrid-modal-recurrence",
    "modal-scan",
    "canonical-fft",
    "iir-fir-caching",
]


class HyenaInferenceEngine:
    def __init__(
        self, fir_fn=None, fftconv_fn=None, iir_prefill_style="modal-fft", layer_idx=None
    ) -> None:
        self.fir_fn = fir_fn
        self.fftconv_fn = fftconv_fn
        assert (
            iir_prefill_style in IIR_PREFILL_MODES
        ), f"iir_prefill_style must be one of {IIR_PREFILL_MODES}"
        self.iir_prefill_style = iir_prefill_style
        self.layer_idx = layer_idx
        self.low_mem_mode = False

    def parallel_fir(
        self,
        fir_fn,
        u,
        weight,
        bias,
        L,
        fir_length=3,
        inference_params=None,
        prefill_mode=None,
        padding_mask=None,
    ):
        """Compute the output state of the long convolutional filter."""
        # prepare input layout, dimensions and dispatch to fir kernel
        if fir_fn != torch.nn.functional.conv1d:
            z_pre = fir_fn(u)[:, :L]  # B, L, D
            z_pre = z_pre.permute(0, 2, 1)
        else:
            u = u.permute(0, 2, 1)  # B, D, L
            z_pre = fir_fn(
                u,
                weight,
                bias,
                stride=1,
                padding=fir_length - 1,
                groups=u.shape[1],
            )[..., :L]

        # handle padding post fir, the only place with biases
        if type(padding_mask) == torch.Tensor:
            z_pre = z_pre * padding_mask[:, None]

        if inference_params is not None:
            # handle seqlen last and dim last cases for `u`
            if fir_fn != torch.nn.functional.conv1d:
                fir_state = u[:, -fir_length + 1 :].permute(0, 2, 1)
            else:
                fir_state = u[..., -fir_length + 1 :]
        else:
            fir_state = None

        return z_pre, fir_state

    def parallel_iir(
        self,
        z_pre,
        h,
        D,
        L,
        poles,
        t,
        dims,
        layer_idx,
        inference_params=None,
        prefill_style="fft",
        fftconv_fn=None,
        padding_mask=None,
        use_flashfft=False,
        column_split_hyena=False,
        long_fir_threshold=None,
    ):
        """Compute the output state of the short convolutional filter."""
        fft_size = 2 * L
        hidden_size, num_attention_heads, hidden_size_per_attention_head, _, _ = dims
        # Compatibility with training infra that column splits the projections
        if column_split_hyena:
            z = z_pre.reshape(
                z_pre.shape[0],
                num_attention_heads,
                3 * hidden_size_per_attention_head,
                z_pre.shape[2],
            )
            x2, x1, v = (
                z[:, :, :hidden_size_per_attention_head],
                z[
                    :,
                    :,
                    hidden_size_per_attention_head : 2 * hidden_size_per_attention_head,
                ],
                z[:, :, 2 * hidden_size_per_attention_head :],
            )
            x2, x1, v = (
                x2.reshape(x2.shape[0], -1, x2.shape[-1]),
                x1.reshape(x1.shape[0], -1, x1.shape[-1]),
                v.reshape(v.shape[0], -1, v.shape[-1]),
            )
        else:
            x2, x1, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1)

        x1v = x1 * v

        if use_flashfft and (L % 2) == 0:  # only works with even L
            y = fftconv_fn(
                x1v.to(dtype=torch.bfloat16).contiguous(),
                h.to(dtype=torch.float32),
            )
            X_s = None

        elif long_fir_threshold is None:
            H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) / fft_size
            X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
            X = X_s[..., : H.shape[-1]]
            if len(z_pre.shape) > 3:
                H = H.unsqueeze(1)
            y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L]
        else:
            assert h.shape[0] == 1, "batch size must be 1 for long_fir_threshold"
            h = h[0][:, None]  # rearrange to d, 1, l for depthwise conv1d
            h = h[..., :long_fir_threshold]
            y = F.conv1d(
                x1v,
                h.to(dtype=x1v.dtype),
                stride=1,
                groups=x1v.shape[1],
                padding=h.shape[-1] - 1,
            )[..., :L]

        y = y.to(dtype=x1v.dtype)
        y = (y + x1v * D.unsqueeze(-1)) * x2
        if inference_params is not None:
            if prefill_style == "fft":
                self.prefill_via_modal_fft(
                    inference_params=inference_params,
                    x1v=x1v,
                    X_s=X_s,
                    L=L,
                    t=t,
                    poles=poles,
                    dims=dims,
                    layer_idx=layer_idx,
                    use_flashfft=use_flashfft,
                )

            elif prefill_style == "recurrence":
                self.prefill_via_direct_recurrence(
                    inference_params=inference_params,
                    x1v=x1v,
                    L=L,
                    poles=poles,
                )

            else:
                raise NotImplementedError
            if self.low_mem_mode:
                del z_pre, x2, x1, v, x1v, h
                torch.cuda.empty_cache()

        return y.permute(0, 2, 1)

    def step_fir(self, u, fir_state, weight, bias=None):
        """Step the FIR filter.

        Note:
        `fir_state` contains the last `short_filter_length - 1` elements of `u`: `u_(L-2), u_{L-1), ...`
        We assume dimensions of `short_filter_weight` to be `[d, 1, short_filter_len]` (SISO / multi SISO layout).
        """
        h0, h = weight[..., 0, -1], weight[..., 0, :-1]
        h0, h = h0[None], h[None]
        y = h0 * u + torch.sum(fir_state * h, dim=-1) + bias

        # update
        fir_state = torch.roll(fir_state, -1, dims=2)
        fir_state[..., -1] = u
        return y, fir_state

    def step_iir(self, x2, x1, v, D, residues, poles, iir_state, iir_groups=1):
        x1v = x1 * v

        residues, poles = (
            torch.view_as_complex(residues.to(torch.float32)),
            torch.view_as_complex(poles.to(torch.float32)),
        )
        # squeeze the dummy seqlen dimension
        # D, state_dim, 1 -> 1, D, state_dim
        residues, poles = residues[..., 0][None], poles[..., 0][None]
        iir_state = poles * iir_state + x1v[..., None]

        res_state = torch.sum(residues * iir_state, dim=-1).real

        if iir_groups > 1:
            raise NotImplementedError
        y = x2 * (res_state + D * x1v)

        return y, iir_state

    def prefill_via_fir_caching(self, u, inference_params, L, *args, **kwargs):
        """Turns the IIR filter into a FIR and uses a cache for decoding."""
        raise NotImplementedError(":)")

    def prefill_via_direct_recurrence(self, inference_params, x1v, L, poles, *args, **kwargs):
        """
        Compute the IIR state via explicit SSM recurrence (modal form)
        """
        x1v_ = x1v[..., None, None]  # b, d, l, sdim, reim
        x1v_ = x1v_.repeat(1, 1, 1, 1, 2)  # b, d, l, sdim, reim

        state = x1v_[:, :, 0]
        poles = poles[:, :, 0].to(dtype=torch.float32)

        for i in range(L):
            state = poles * state + x1v_[:, :, i]
        inference_params.state_dict[self.layer_idx] = torch.view_as_complex(
            state.to(dtype=torch.float32)
        )

    def prefill_via_hybrid_recurrence(
        self, inference_params, u, log_poles, x1v_f_a, L, *args, **kwargs
    ):
        """
        Compute the IIR state via hybrid recurrence-convolution over blocks
        """
        raise NotImplementedError(":)")

    def prefill_via_scan(self, u, inference_params=None, *args, **kwargs):
        raise NotImplementedError

    def prefill_via_canonical_fft(self, u, inference_params=None, *args, **kwargs):
        """
        Compute the IIR state via a single FFT with the denominator of the SSM in companion form.

        This is the most memory efficient "parallelized" prefilling method for Hyena.

        From: https://arxiv.org/abs/2310.18780
        """
        raise NotImplementedError(":)")

    def prefill_via_modal_fft(
        self,
        inference_params,
        x1v,
        L,
        poles,
        t,
        dims,
        layer_idx,
        X_s=None,
        use_flashfft=False,
        state_dtype=torch.complex64,
        *args,
        **kwargs,
    ):
        """
        Compute the IIR state via a single FFT, using the poles of the SSM in modal form.
        """
        # When the model has a long convolution derived from a SSM in modal form and prefill_style is "fft",
        # we split the filter into poles and residues and reuse FFT computation on the input.
        # This optimization is currently not supported when using flashfftconv.
        hidden_size, _, _, state_size, hyena_filter_groups = dims

        if use_flashfft:
            # using real states
            poles = poles.squeeze().reshape(poles.shape[0], -1)[..., None]

            state_s = poles**t
            if hyena_filter_groups > 1:
                raise NotImplementedError

            x1v = x1v[:, :, None].repeat(1, 1, 2 * state_size, 1)
            x1v = x1v.reshape(x1v.shape[0], -1, x1v.shape[-1])
            state_s = state_s[None]

            state = self.fftconv_fn(
                x1v.contiguous(),
                state_s.to(dtype=torch.float32),
            )
            state = state[..., L - 1].reshape(x1v.shape[0], hidden_size, state_size, 2)
            state = torch.view_as_complex(state.contiguous())
            inference_params.state_dict[self.layer_idx] = state.to(dtype=state_dtype)
        else:
            assert X_s is not None
            bs = x1v.shape[0]
            fft_size = 2 * L
            poles = torch.view_as_complex(poles.to(torch.float32))
            state_s = poles**t
            state_S = torch.fft.fft(state_s, n=fft_size).repeat(
                bs, 1, 1, 1
            )  # B, D, state_dim, 2 * L
            if hyena_filter_groups > 1:
                state_S = state_S.repeat_interleave(hidden_size // hyena_filter_groups, 1)
            state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size)
            inference_params.state_dict[layer_idx] = state[..., L - 1].to(dtype=state_dtype)

    def _compute_state(self, log_poles, u, t, L, *args, **kwargs):
        """
        Compute the IIR state given an input `u` and log_poles of the modal system.
        """
        bs = u.shape[0]
        fft_size = 2 * L
        U = torch.fft.rfft(u.to(torch.float32), n=fft_size)
        fft_size = 2 * L
        x = (log_poles * t).exp()
        # [batch, hidden_size, state_dim, 2 * seqlen]
        X = torch.fft.fft(x, n=fft_size).repeat(bs, 1, 1, 1)
        state = torch.fft.ifft(U[..., None, :] * X, n=fft_size)[..., :L]
        return state