# Vectorial Astigmatic Gaussian Beam

import numpy as np
from qosm.utils.Field import Field
from qosm.utils.Pose import Vector

import warnings
# NOTE(review): this installs a blanket "ignore" filter at import time, so
# every warning category is silenced for the whole process, not only for this
# module. The numerical code below divides by quantities that can be zero;
# presumably the filter is here to hide the resulting NumPy RuntimeWarnings --
# confirm, and consider a local `np.errstate(...)` context instead.
warnings.filterwarnings('ignore')


class VAGB:
    """Vectorial Astigmatic Gaussian Beam

    The beam is described in its own local frame by two complex beam
    parameters (``q1``, ``q2``), one per principal axis, and an optional
    rotation angle ``theta`` of those principal axes. At a height ``z`` the
    parameters are ``q1 + z`` and ``q2 + z``.

    Attributes
    ----------
    k0 : float
        Free-space wave number
    n : float
        Real index of refraction
    kappa : float
        Extinction coefficient
    Z : np.ndarray
        Complex wave impedance ``Field.Z0 / (n - j kappa)``, stored as a column array
    z0 : np.ndarray
        [meter] waist position offset for each axis w.r.t the beam origin (z=0)
    alpha : [complex, complex]
        Complex polarisation coefficient for each axis
    q1 : complex
        [complex meters] beam parameter of the first principal axis at z=0
        (inverse of the first principal curvature)
    q2 : complex
        [complex meters] beam parameter of the second principal axis at z=0
        (inverse of the second principal curvature)
    theta : complex
        Complex angle of the beam (rotation of the principal axes, 0 when the
        curvature matrix is diagonal)
    a0 : complex
        Complex initial amplitude
    phi0 : float
        [rad] Initial phase
    z_range : [float, float]
        [meters] minimal and maximal values of z in local beam's frame
    """

    def __init__(self,
                 polarisation: tuple[complex, complex] = (1, 0),
                 k0: float = 0.0,
                 ior: tuple[float, float] = (1, 0),
                 z0: tuple[float, float] = (0, 0),
                 initial_phase: float = 0,
                 initial_amplitude: float = 1,
                 w0: tuple[float, float] = None,
                 w0_list: np.ndarray = None,
                 q: tuple[complex, complex] = None,
                 curvature_matrix: np.ndarray = None,
                 z_limits: tuple[float, float] = (float('-inf'), float('inf'))):
        """
        The beam parameters are taken from the first of ``curvature_matrix``,
        ``q``, ``w0``, ``w0_list`` that is not None (in that order). If all of
        them are None, ``q1`` and ``q2`` stay None.

        Parameters
        ----------
        polarisation: [complex, complex]
            Complex polarisation coefficients for each axis
        k0 : float
            Free-space wave number
        ior : [float, float]
            Complex index of refraction : ior[0] - j ior[1].
        z0 : [float, float]
            [meter] waist position offset for each axis w.r.t the beam origin (z=0)
        initial_amplitude: float
            [V/m] Initial amplitude of the scalar Gaussian Function
        initial_phase: float
            [rad] Initial phase of the scalar Gaussian Function
        w0 : [float, float]
            [meter] waist size at z=z0, for each axis
        w0_list : np.ndarray((N, 2), dtype=float)
            [meter] multi-beams case - waist size of each beam (dim 0) and each axis (dim 1)
        q : [complex, complex]
            [complex meters] beam parameters (q1, q2) at z=0, stored as given.
            By-pass any value given for w0 or w0_list
        curvature_matrix: np.ndarray((2, 2))
            Curvature matrix of the beam. By-pass any value given for q, w0 or w0_list
        z_limits: [float, float]
            [meters] Optional, Min and Max values of z to be considered (spatial limitation of the beam)
        """

        self.z0 = np.array(z0)
        self.k0 = k0
        self.n = ior[0]
        self.kappa = ior[1]
        self.Z = (Field.Z0 / (np.array(ior[0]) - np.array(ior[1]) * 1j)).reshape(-1, 1)
        self.theta = 0
        self.q1 = None
        self.q2 = None
        self.a0 = initial_amplitude
        self.phi0 = initial_phase
        self.z_range = z_limits
        self.alpha = polarisation

        if curvature_matrix is not None:
            q1, q2, theta, g, general = self._principal_parameters(curvature_matrix)
            self.theta = theta
            if general:
                # GAGB
                # NOTE(review): z0 is ADDED here but SUBTRACTED in the SAGB
                # branch and in the w0 branches. Behaviour kept as found;
                # confirm which sign is intended.
                self.q1 = q1 + self.z0[0]
                self.q2 = q2 + self.z0[1]
            else:
                # SAGB
                self.q1 = q1 - self.z0[0]
                self.q2 = q2 - self.z0[1]

            # [m] Beam's origins wrt the pt where QI is defined
            # (the z0 given by the caller is kept when no offset can be derived)
            offsets = self._waist_offsets(self.q1, self.q2, g)
            if offsets is not None:
                self.z0 = offsets

        elif q is not None:
            # q1, q2 -> complex
            self.q1 = q[0]
            self.q2 = q[1]

        elif w0 is not None:
            k = k0 * (ior[0] - ior[1]*1j)
            # q1, q2 -> complex
            self.q1 = - (k * w0[0] ** 2) / 2j - self.z0[0]
            self.q2 = - (k * w0[1] ** 2) / 2j - self.z0[1]

        elif w0_list is not None:
            k = k0 * (ior[0] - ior[1]*1j)
            # q1, q2 -> complex[N,]
            # Same expression as the single-beam case above. The minus sign
            # applies to the waist term only, so z0 enters with the same sign
            # for one beam and for a list of beams.
            self.q1 = np.reshape(- (k * w0_list[:, 0] ** 2) / 2j - self.z0[0], (-1,))
            self.q2 = np.reshape(- (k * w0_list[:, 1] ** 2) / 2j - self.z0[1], (-1,))

    @staticmethod
    def _principal_parameters(Qi):
        """Split a 2x2 complex curvature matrix into principal beam parameters

        Returns
        -------
        (q1, q2, theta, g, general)
            q1, q2 : beam parameters along the two principal axes,
            theta : rotation angle of the principal axes (0 if Qi is diagonal),
            g : -Re/Im of the two eigenvalues of Qi,
            general : True when Qi has a non-negligible off-diagonal part (GAGB)
        """
        # Eigenvalues of Q from its trace and determinant
        t = Qi[0, 0] + Qi[1, 1]
        d = Qi[0, 0] * Qi[1, 1] - Qi[0, 1] * Qi[1, 0]
        root = (.25 * t ** 2 - d) ** .5
        ev = [.5 * t + root, .5 * t - root]
        g = -np.real(ev) / np.imag(ev)

        if np.abs(Qi[0, 1] + Qi[1, 0]) >= 1e-10:
            # GAGB: principal curvatures are the eigenvalues, axes are rotated
            theta = 0.5 * np.arctan((Qi[0, 1] + Qi[1, 0]) / (Qi[0, 0] - Qi[1, 1]))
            return 1 / ev[0], 1 / ev[1], theta, g, True

        # SAGB: matrix already diagonal, no rotation
        return 1 / Qi[0, 0], 1 / Qi[1, 1], 0, g, False

    @staticmethod
    def _waist_offsets(q1, q2, g):
        """Waist offsets derived from (q1, q2), or None when the denominator is zero"""
        denominator = np.real([1 / q1, 1 / q2]) * (1 + g ** 2)
        if np.prod(denominator) != 0:
            return -g ** 2 / denominator
        return None

    @staticmethod
    def _gouy_phase(q1, q2):
        """Mean of arctan(Re(q)/Im(q)) over the two principal axes"""
        return .5 * (
                np.arctan(np.real(q1) / np.imag(q1))
                + np.arctan(np.real(q2) / np.imag(q2))
        )

    def _amplitude_norm(self, g1, g2, g0):
        """Amplitude normalisation factor from the imaginary parts of Q

        g1, g2 are Im(Q11), Im(Q22) and g0 is Im(Q12 + Q21).
        """
        # @todo understand why G can be negative
        G = np.abs(4 * g1 * g2 - g0 ** 2)
        return np.sqrt(.5 * self.k0 * Field.Z0 * np.sqrt(G) / np.pi)

    @property
    def impedance(self) -> np.ndarray:
        """Complex wave impedance in the medium (column array, see ``Z``)
        """
        return self.Z

    @property
    def k(self) -> complex:
        """Complex wave number in the medium
        """
        return self.k0 * (self.n - self.kappa * 1j)

    @property
    def ior(self) -> tuple[float, float]:
        """Complex index of refraction

        Returns
        -------
        (float, float)
            The real index of refraction and the extinction coefficient
        """
        return self.n, self.kappa

    def set_z_limits(self, z_min: float, z_max: float) -> None:
        """Set spatial limitations of the beam

        Parameters
        ----------
        z_min : float
            [meters] minimal value of z in local beam's frame
        z_max : float
            [meters] maximum value of z in local beam's frame
        """

        self.z_range = [z_min, z_max]

    def init(self, Qi, a_i=None, phi_i=0) -> None:
        """Re-initialise the beam from a complex curvature matrix

        Overwrites ``q1``, ``q2``, ``theta``, ``z0`` and ``a0``
        (and ``phi0`` when ``a_i`` is given).

        Parameters
        ----------
        Qi : np.ndarray((2, 2))
            Complex curvature matrix at the beam origin
        a_i : complex, optional
            Amplitude to match at the origin. When given, ``a0`` is set to
            ``a_i`` divided by the normalisation factor of ``Qi`` and ``phi0``
            is set to ``phi_i`` minus the Gouy phase at the origin.
            When None, ``a0`` is set to 1 and ``phi0`` is left unchanged.
        phi_i : float, optional
            [rad] Phase to match at the origin (only used when ``a_i`` is given)
        """
        # theta is always assigned (0 for a diagonal Qi), so a rotation left
        # over from a previous configuration cannot leak into u().
        self.q1, self.q2, self.theta, g, _ = self._principal_parameters(Qi)

        # [m] Beam's origins wrt the pt where QI is defined
        offsets = self._waist_offsets(self.q1, self.q2, g)
        self.z0 = np.array([0, 0]) if offsets is None else offsets

        # matching term
        if a_i is not None:
            # power-normalised
            a_0 = self._amplitude_norm(
                np.imag(Qi[0, 0]),
                np.imag(Qi[1, 1]),
                np.imag(Qi[0, 1] + Qi[1, 0]),
            )
            self.a0 = a_i / a_0
            self.phi0 = phi_i - self._gouy_phase(self.q1, self.q2)
        else:
            self.a0 = 1

    def u(self, r_pts_beam: Vector) -> tuple[np.ndarray, np.ndarray]:
        """
        Compute the Local (scalar) Gaussian Function at the specified points in space

        No exception is raised on overflow: samples whose exponential term is
        infinite are set to 0, and values with a magnitude below 1e-7 are
        flushed to 0.

        Parameters
        ----------
        r_pts_beam : Vector
            [meters] Vector of points [Nx3] to be used, written in beam's frame

        Returns
        -------
        (np.ndarray, np.ndarray)
            ([N, ], [2, 2, N])
            Values of the Gaussian Function at each point and
            Complex curvature matrix of the beam at each point
        """
        N = r_pts_beam.shape[0]
        # Work on plain 1-D ndarrays so the Vector subclass does not interfere
        x_ = np.reshape(r_pts_beam[:, 0], (-1,)).view(np.ndarray)
        y_ = np.reshape(r_pts_beam[:, 1], (-1,)).view(np.ndarray)
        z_ = np.reshape(r_pts_beam[:, 2], (-1,)).view(np.ndarray)

        # Beam parameters propagated to each z
        q1_ = self.q1 + z_
        q2_ = self.q2 + z_

        # Curvature matrix
        Q_ = np.zeros((2, 2, N), dtype=np.complex128)
        if self.theta != 0:
            # principal curvatures rotated by theta
            Q_[0, 0, :] = np.cos(self.theta) ** 2 / q1_ + np.sin(self.theta) ** 2 / q2_
            Q_[0, 1, :] = Q_[1, 0, :] = .5 * np.sin(2 * self.theta) * (1. / q1_ - 1. / q2_)
            Q_[1, 1, :] = np.sin(self.theta) ** 2 / q1_ + np.cos(self.theta) ** 2 / q2_
        else:
            Q_[0, 0, :] = 1 / q1_
            Q_[1, 1, :] = 1 / q2_

        # Amplitude normalisation
        u0_ = self._amplitude_norm(
            np.imag(Q_[0, 0, :]),
            np.imag(Q_[1, 1, :]),
            np.imag(Q_[0, 1, :] + Q_[1, 0, :]),
        )

        # Gouy phase
        eta = self._gouy_phase(q1_, q2_)

        absorption = - self.kappa * self.k0 * z_

        # Accumulated phase (propagation), using the real part of the wave number
        k = self.k0 * self.n
        phi_ac_ = - k * z_

        # complex phase term
        phase = eta + phi_ac_ + self.phi0

        # [x y]*Qz*[x;y] -> quadratic part
        rQr = Q_[0, 0, :] * x_ ** 2 + Q_[1, 1, :] * y_ ** 2 + (Q_[0, 1, :] + Q_[1, 0, :]) * x_ * y_

        # Scalar Gaussian Function
        K = - .5j * self.k0 * (self.n - self.kappa * 1j)
        part1 = np.exp(K * rQr + absorption + 1j * phase)
        # overflowed samples are discarded rather than propagated as inf
        part1[np.isinf(part1)] = 0
        u_ = u0_ * self.a0 * part1
        u_[np.abs(u_) < 1e-7] = 0.
        return u_, Q_

    def compute_fields(self, r_pts_beam: Vector) -> tuple[Field, Field]:
        """Compute E and H Field at each point of space

        The transverse components are the scalar Gaussian function weighted by
        the polarisation coefficients; the longitudinal components are built
        from the curvature matrix applied to (x, y). Points whose z lies
        outside ``z_range`` get a scalar function of exactly 0.

        Parameters
        ----------
        r_pts_beam: Vector
            [meters] Vector of points [Nx3] to be used, written in beam's frame

        Returns
        -------
        (Field, Field)
            E and H complex fields computed for each point ([Nx3], [Nx3]), written in beam's frame

        """

        x_ = r_pts_beam[:, 0]
        y_ = r_pts_beam[:, 1]
        z_ = r_pts_beam[:, 2]

        # True where the point lies outside the spatial limits of the beam
        outside = np.zeros(z_.shape, dtype=bool)
        outside[z_ < self.z_range[0]] = True
        outside[z_ > self.z_range[1]] = True

        u, Q_ = self.u(r_pts_beam)
        # Select instead of multiplying by 0, so a NaN produced outside the
        # limits cannot survive the masking (NaN * 0 is NaN).
        u = np.where(outside, 0, u)

        # Symmetrised off-diagonal curvature
        Qd_ = .5*(Q_[0, 1, :] + Q_[1, 0, :])
        dux = (Q_[0, 0, :] * x_ + Qd_ * y_).view(Field)
        duy = (Q_[1, 1, :] * y_ + Qd_ * x_).view(Field)

        # Electric field
        E = Field(N=z_.shape[0])
        E[:, 0] = self.alpha[0] * u
        E[:, 1] = self.alpha[1] * u
        E[:, 2] = - (self.alpha[0] * dux + self.alpha[1] * duy) * u

        # Magnetic field
        H = Field(N=z_.shape[0])
        H[:, 0] = - self.alpha[1] * u
        H[:, 1] = self.alpha[0] * u
        H[:, 2] = (self.alpha[1] * dux - self.alpha[0] * duy) * u
        H /= self.Z

        return E, H