from typing import Optional

import numpy as np
import os

import os
import sys
from contextlib import contextmanager


# Environment knobs read by TensorFlow / CUDA; they have to be in place before
# `import tensorflow` further down in this module.
#   TF_CPP_MIN_LOG_LEVEL='3'  -> raise TensorFlow's native (C++) log threshold
#   CUDA_VISIBLE_DEVICES='-1' -> expose no CUDA GPUs to the process
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '3',
    'CUDA_VISIBLE_DEVICES': '-1',
})
@contextmanager
def suppress_stderr():
    """Temporarily send everything written to stderr's file descriptor to os.devnull.

    The redirection happens at the OS level (``dup2`` on the descriptor behind
    ``sys.stderr``), so output written directly to that descriptor by native
    code is discarded as well, not only Python-level writes.

    The ``sys.stderr`` object itself is never closed or replaced; it is only
    flushed before and after the redirection. References to it held elsewhere
    (logging handlers, ``sys.__stderr__``) therefore remain usable afterwards.

    If ``sys.stderr`` is not backed by a file descriptor (``None``, an
    in-memory stream, a capture object), nothing is redirected and the body
    simply runs.
    """
    try:
        fd = sys.stderr.fileno()
    except (AttributeError, OSError, ValueError):
        # io.UnsupportedOperation derives from both OSError and ValueError;
        # AttributeError covers sys.stderr being None.
        fd = None

    if fd is None:
        yield
        return

    sys.stderr.flush()  # emit anything already buffered before muting
    saved_fd = os.dup(fd)
    try:
        devnull_fd = os.open(os.devnull, os.O_WRONLY)
        try:
            os.dup2(devnull_fd, fd)
        finally:
            os.close(devnull_fd)
        try:
            yield  # run the caller's code with the descriptor muted
        finally:
            # Flush while still muted so buffered writes made inside the block
            # are discarded, then point the descriptor back at the original.
            sys.stderr.flush()
            os.dup2(saved_fd, fd)
    finally:
        os.close(saved_fd)

# Usage: import TensorFlow with stderr's file descriptor muted, presumably to
# hide the diagnostics TensorFlow prints while it loads.
with suppress_stderr():
    import tensorflow as tf

# Hide GPUs from TensorFlow itself (in addition to CUDA_VISIBLE_DEVICES='-1'
# set earlier in this module) and limit TensorFlow to one thread both inside
# individual ops (intra-op) and across independent ops (inter-op).
# NOTE(review): these tf.config calls run at import time; TensorFlow presumably
# rejects them (RuntimeError) if its runtime was already initialized by code
# that imported this module later — confirm against the tf.config docs.
tf.config.set_visible_devices([], 'GPU')
tf.config.threading.set_intra_op_parallelism_threads(1)
tf.config.threading.set_inter_op_parallelism_threads(1)

import hdbscan
import cv2
from scipy.interpolate import interp1d

from src.spectral_analysis.peakers.base_peaker import Peaker
from src.spectral_analysis.models import Spectra
from src import root_dir


# Side length, in pixels, of the square image the network is fed
# (spectra are resized to SHAPE x SHAPE before prediction).
SHAPE = 256

# Seed NumPy's process-wide legacy RNG at import time. This affects every
# np.random user in the process, not just this module.
np.random.seed(42)


class PeakerAE(Peaker):
    """Pick dispersion curves from velocity-frequency spectra with a neural network.

    A Keras model loaded by ``load_model`` predicts a multi-channel mask from
    the spectrum. ``peak_dc`` reads the fundamental mode from mask channel 0
    and an optional higher mode from mask channel 2.
    """

    @staticmethod
    def _segments_fit(x_test: np.ndarray, y_test: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Fit segments by averaging y-values for each unique x-value.

        Args:
            x_test: 1D array of x-coordinates (values may repeat).
            y_test: 1D array of y-coordinates aligned with ``x_test``.

        Returns:
            tuple: Sorted unique x-values and the mean y-value observed at each.
        """
        # Group once with unique/bincount instead of rescanning x_test for every
        # unique value (O(n log n) instead of O(n * n_unique)).
        x_uniq, group = np.unique(x_test, return_inverse=True)
        sums = np.bincount(group, weights=y_test, minlength=x_uniq.size)
        counts = np.bincount(group, minlength=x_uniq.size)
        return x_uniq, sums / counts

    @staticmethod
    def _get_coeffs(x: np.ndarray, y: np.ndarray) -> tuple[float, float]:
        """
        Least-squares fit of the straight line y = a*x + b.

        Args:
            x: Independent variable values (frequency/column indices).
            y: Dependent variable values (velocity/row indices).

        Returns:
            tuple: Intercept ``b`` and slope ``a`` (in that order).
        """
        design = np.array([np.ones_like(x), x]).T  # columns: [1, x]
        intercept, slope = np.linalg.pinv(design) @ y
        return intercept, slope

    @staticmethod
    def _remove_outbreaks(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Remove outliers from 2D points using HDBSCAN clustering.

        Among the clusters found (HDBSCAN's noise label -1 is ignored), keeps
        the one that covers the most distinct x-values; on a tie the first
        such cluster wins. If every point is labelled noise, all points are
        returned unchanged.

        Args:
            data: Array of shape (n_samples, 2) holding (x, y) pairs.

        Returns:
            tuple: x and y coordinates of the retained points.
        """
        labels = hdbscan.HDBSCAN(min_samples=4, core_dist_n_jobs=1).fit_predict(data)

        best_label = None
        best_width = -1
        for label in np.unique(labels):
            if label == -1:  # noise
                continue
            # "Width" = number of distinct x-values (columns) the cluster spans.
            width = len(np.unique(data[labels == label, 0]))
            if width > best_width:
                best_width, best_label = width, label

        if best_label is None:
            return data[:, 0], data[:, 1]
        in_cluster = labels == best_label
        return data[in_cluster, 0], data[in_cluster, 1]

    @staticmethod
    def _get_curve(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Extract a single curve from a 2D amplitude map.

        Steps:
        1. Keep the pixels whose value is at least 95% of the map's maximum.
        2. Shift the row index of every kept pixel whose row exceeds
           ``round(n_rows / SHAPE)`` down by that amount, then drop outliers
           with ``_remove_outbreaks``.
        3. Fit a straight line row = a*col + b and keep points whose residual
           lies strictly inside a band of ±3·(3σ) (see the note in the body).
        4. Average the remaining rows per column (``_segments_fit``).

        Args:
            data: 2D map indexed as [row (velocity index), column (frequency index)].

        Returns:
            tuple: Sorted unique column indices and the mean (shifted) row
            index at each of them.
        """
        n_rows = data.shape[0]

        # np.nonzero yields indices in the same row-major order as the original
        # meshgrid + boolean-mask selection.
        rows, cols = np.nonzero(data >= np.max(data) * 0.95)
        points = np.column_stack((cols, rows))

        # NOTE(review): presumably compensates for the SHAPE <-> n_rows resize
        # in predict_mask; confirm the intent of this offset.
        offset = round(n_rows / SHAPE)
        above = points[:, 1] > offset
        points[above, 1] = points[above, 1] - offset

        x_, y_ = PeakerAE._remove_outbreaks(points)
        b, a = PeakerAE._get_coeffs(x_, y_)
        trend = a * x_ + b
        residuals = y_ - trend

        # NOTE(review): the residual std is scaled by 3 twice, so the band is
        # effectively ±9σ rather than ±3σ. Kept as-is to avoid changing picked
        # curves; confirm which width was intended.
        band = 3 * (3 * np.std(residuals))
        if band > 0:
            keep = ((trend - band) < y_) & ((trend + band) > y_)
        else:
            # Zero spread (e.g. a single point or perfectly collinear points):
            # a zero-width strict band would discard every point, so keep them all.
            keep = np.ones(y_.shape, dtype=bool)

        return PeakerAE._segments_fit(x_[keep], y_[keep])

    @staticmethod
    def _image_resize(img, target_size=(SHAPE, SHAPE)):
        """
        Resize an image with OpenCV area interpolation.

        Args:
            img: Input 2D image; converted to float32 before resizing.
            target_size: Target size as (width, height), per cv2.resize's dsize.

        Returns:
            np.ndarray: Resized float32 image of shape (height, width).
        """
        return cv2.resize(img.astype(np.float32), target_size, interpolation=cv2.INTER_AREA)

    def predict_mask(self, vf_spectra: np.ndarray) -> np.ndarray:
        """
        Predict segmentation masks for a velocity-frequency spectrum.

        The spectrum is resized to SHAPE x SHAPE, passed through the model,
        and the first four output channels are resized back to the input size.

        Args:
            vf_spectra: 2D velocity-frequency spectrum.

        Returns:
            np.ndarray: Array of shape (4, n_rows, n_cols), one mask per channel.
        """
        height, width = vf_spectra.shape[0], vf_spectra.shape[1]
        resized = PeakerAE._image_resize(vf_spectra, target_size=(SHAPE, SHAPE))

        input_tensor = tf.convert_to_tensor(resized.reshape(1, SHAPE, SHAPE, 1), dtype=tf.float32)
        prediction = self.encoder_model(input_tensor, training=False).numpy()

        return np.array([
            PeakerAE._image_resize(prediction[0, :, :, channel], target_size=(width, height))
            for channel in range(4)
        ])

    def load_model(self) -> None:
        """
        Load the pre-trained autoencoder from disk into ``self.encoder_model``.

        The model is read from 'src/spectral_analysis/peakers/AE_model.keras'
        (relative to ``root_dir``) without compiling it, then called once on a
        zero batch of the input shape used by ``predict_mask``.
        """
        model_path = root_dir / "src/spectral_analysis/peakers/AE_model.keras"
        self.encoder_model = tf.keras.models.load_model(model_path, compile=False)
        # Single call on dummy input; presumably to build/warm up the model
        # before the first real prediction.
        _ = self.encoder_model(tf.zeros((1, SHAPE, SHAPE, 1)))

    @staticmethod
    def _peak_high_mode(px_tmp: np.ndarray, py_tmp: np.ndarray, f_indx_spectrum: np.ndarray,
                        shift: np.ndarray, min_vel: np.ndarray, mask: np.ndarray) \
            -> Optional[tuple[np.ndarray, np.ndarray]]:
        """
        Extract a higher-mode dispersion curve above the fundamental mode.

        1. Linearly interpolate the fundamental-mode rows over every column in
           ``f_indx_spectrum``, holding the end values constant outside the
           picked range.
        2. In each column, keep ``mask`` only from row 1.3 x (fundamental row)
           downwards in the array, i.e. at higher row indices.
        3. If anything positive remains, extract a curve from it with ``_get_curve``.

        Args:
            px_tmp: Column indices of the fundamental mode (raw mask columns,
                not shift-adjusted). Needs at least two points for interp1d.
            py_tmp: Row indices of the fundamental mode (raw mask rows, not
                converted to velocity).
            f_indx_spectrum: Column indices of the analysed frequency range.
            shift: Scalar column offset (``f_indx_spectrum[0]`` in peak_dc).
            min_vel: Scalar velocity of row 0, added to convert rows to velocity.
            mask: 2D mask used for higher-mode detection (channel 2 in peak_dc).

        Returns:
            Optional[tuple]:
                - Column indices of the higher mode minus ``shift``.
                - Row indices of the higher mode plus ``min_vel``.
                Returns None if the restricted mask has no positive value.
        """
        order = np.argsort(px_tmp)
        px_sorted, py_sorted = px_tmp[order], py_tmp[order]
        py_extrapolated = interp1d(
            px_sorted, py_sorted,
            kind='linear', bounds_error=False, fill_value=(py_sorted[0], py_sorted[-1]),
        )(f_indx_spectrum)

        low_border = np.int32(py_extrapolated + py_extrapolated * 0.3)

        high_mode_mask = np.zeros_like(mask)
        n_rows = high_mode_mask.shape[0]
        for idx, border_row in enumerate(low_border):
            column = idx + shift
            start = min(border_row, n_rows)
            high_mode_mask[start:, column] = mask[start:, column]

        if np.max(high_mode_mask) > 0:
            px_high, py_high = PeakerAE._get_curve(high_mode_mask)
            order = np.argsort(px_high)
            return px_high[order] - shift, py_high[order] + min_vel
        return None

    @staticmethod
    def _copy_within_limits(source: np.ndarray, target: np.ndarray, lower_limits, upper_limits,
                            min_vel, shift, n_freq: int) -> None:
        """Copy, per analysed column, the rows between the velocity limits from ``source`` into ``target`` in place.

        Column ``idx + shift`` receives rows ``int(lower_limits[idx] - min_vel)``
        up to (excluding) ``int(upper_limits[idx] - min_vel)``, for idx in range(n_freq).
        """
        for idx in range(n_freq):
            rows = slice(int(lower_limits[idx] - min_vel), int(upper_limits[idx] - min_vel))
            target[rows, idx + shift] = source[rows, idx + shift]

    def _pick_modes(self, mask: np.ndarray, vf_spectra: np.ndarray, f_indx_spectrum: np.ndarray,
                    shift, min_vel) -> tuple[list, list, list]:
        """Pick the fundamental mode from mask channel 0 and, if found, a higher mode from channel 2.

        Returns freshly built lists (px, py, ampl), one entry per mode, so a failed
        attempt can never leave partial results behind for a retry:
            - px: column indices minus ``shift``
            - py: row indices plus ``min_vel``
            - ampl: ``vf_spectra`` values at the picked (row, column) points
        """
        px_tmp, py_tmp = self._get_curve(mask[0, :, :])
        order = np.argsort(px_tmp)

        px = [px_tmp[order] - shift]
        py = [py_tmp[order] + min_vel]
        ampl = [vf_spectra[np.int32(py_tmp), np.int32(px_tmp)]]

        if np.max(mask[2, :, :]) > 0:
            # Pass the raw (unshifted) indices: _peak_high_mode interpolates
            # over absolute columns and treats py as row indices.
            higher = PeakerAE._peak_high_mode(
                px_tmp, py_tmp, f_indx_spectrum, shift, min_vel, mask[2, :, :]
            )
            if higher is not None:
                px_high, py_high = higher
                px.append(px_high)
                py.append(py_high)
                ampl.append(vf_spectra[np.int32(py_high - min_vel), np.int32(px_high + shift)])

        return px, py, ampl

    def peak_dc(self, spectra: Spectra, peak_fraction: float, cutoff_fraction: float) \
            -> tuple[list, list, np.ndarray, np.ndarray, np.ndarray, list]:
        """
        Extract dispersion curves (fundamental and optional higher mode) from a spectrum.

        Flow:
        - Amplitudes at or below ``peak_fraction`` of their column maximum are
          zeroed over the analysed columns.
        - The network predicts masks. Channels 0 and 2 are then restricted to
          the per-column velocity limits.
        - Modes are picked from those masks.
        - If picking raises, a fallback runs. It re-predicts masks from the
          raw spectrum restricted to the velocity limits and picks again from
          scratch; an error in the fallback propagates.

        Args:
            spectra: Spectra object providing ``vf_spectra`` and ``velocities``.
            peak_fraction: Fraction (0-1) of each column's maximum amplitude
                below which amplitudes are zeroed.
            cutoff_fraction: Unused by this implementation.

        Returns:
            tuple: Six items:
                - List of frequency arrays, one per picked mode.
                - List of velocity arrays, one per picked mode.
                - Copy of the frequency array returned by ``_find_limits``.
                - Lower velocity limits.
                - Upper velocity limits.
                - List of spectrum amplitudes at the picked points, one array per mode.
        """
        freq, upper_limits, lower_limits, f_indx_spectrum = self._find_limits(spectra)
        shift = f_indx_spectrum[0]
        min_vel = spectra.velocities[0]
        freq_limits = np.copy(freq)
        n_freq = len(freq)

        thresholded = np.copy(spectra.vf_spectra)
        for idx in range(n_freq):
            column = thresholded[:, idx + shift]  # view: edits write through
            column[column <= np.max(column) * peak_fraction] = 0

        raw_mask = self.predict_mask(thresholded)
        # Channels 1 and 3 are kept whole; channels 0 and 2 only within the
        # velocity limits of each analysed column.
        mask = np.zeros_like(raw_mask)
        mask[1::2, :, :] = raw_mask[1::2, :, :]
        for channel in (0, 2):
            PeakerAE._copy_within_limits(
                raw_mask[channel], mask[channel], lower_limits, upper_limits, min_vel, shift, n_freq
            )

        try:
            px, py, ampl = self._pick_modes(mask, spectra.vf_spectra, f_indx_spectrum, shift, min_vel)
        except Exception:
            # Deliberately broad: any failure in the primary attempt (clustering,
            # interpolation, indexing) triggers the fallback. _pick_modes
            # builds new lists, so nothing from the failed attempt is kept.
            limited = np.zeros_like(spectra.vf_spectra)
            PeakerAE._copy_within_limits(
                spectra.vf_spectra, limited, lower_limits, upper_limits, min_vel, shift, n_freq
            )
            fallback_mask = self.predict_mask(limited)
            px, py, ampl = self._pick_modes(fallback_mask, spectra.vf_spectra, f_indx_spectrum, shift, min_vel)

        freq_all = [freq[columns] for columns in px]
        return freq_all, py, freq_limits, lower_limits, upper_limits, ampl