import os
import sys
import time
from pathlib import Path

import numpy as np
import scipy as sp
from scipy.sparse.linalg import cg
from scipy.sparse.linalg import gmres

from pheromone_dispersion.advection_operator import Advection
from pheromone_dispersion.diffusion_operator import Diffusion
from pheromone_dispersion.identity_operator import Id
from pheromone_dispersion.reaction_operator import Reaction
from pheromone_dispersion.source_term import Source

class DiffusionConvectionReaction2DEquation:
    r"""Class containing the pheromone propagation model given by the 2D diffusion-convection-reaction PDE:

    .. math::
        \frac{\partial c}{\partial t}-\nabla\cdot(\mathbf{K}\nabla c)+\nabla\cdot(\vec{u}c)+\tau_{loss}c=s
        ~\forall (x,y)\in\Omega~\forall t\in]0;T]

    with the initial and boundary conditions:

        - a null initial condition :math:`c(x,y,t=0)=0~\forall (x,y)\in\Omega`,
        - a null diffusive flux :math:`\mathbf{K}\nabla c\cdot\vec{n}=0~\forall (x,y)\in\partial\Omega`,
        - a null convective influx
          :math:`\vec{u}c\cdot\vec{n}=0~\forall (x,y)\in\partial\Omega\cap \{(x,y)|\vec{u}(x,y,t)\cdot\vec{n}<0\}~\forall t\in]0;T]`
          with :math:`\vec{n}` the outgoing normal vector,

    and its solvers.

    Attributes
    ----------
    msh: ~pheromone_dispersion.geom.MeshRect2D
        The geometry of the domain.
    A: ~pheromone_dispersion.advection_operator.Advection
        The advection linear operator :math:`A:c\mapsto \nabla\cdot(\vec{u}c)~\forall (x,y)\in\Omega~\forall t\in]0;T]`
        with :math:`\vec{u}c\cdot\vec{n}=0~\forall (x,y)\in\partial\Omega\cap \{(x,y)|\vec{u}(x,y,t)\cdot\vec{n}<0\}~\forall t\in]0;T]`.
    D: ~pheromone_dispersion.diffusion_operator.Diffusion
        The diffusion linear operator :math:`D:c\mapsto -\nabla\cdot(\mathbf{K}\nabla c)~\forall (x,y)\in\Omega~\forall t\in]0;T]`
        with :math:`\mathbf{K}\nabla c\cdot\vec{n}=0~\forall (x,y)\in\partial\Omega`.
    R: ~pheromone_dispersion.reaction_operator.Reaction
        The reaction linear operator :math:`R:c\mapsto \tau_{loss}c~\forall (x,y)\in\Omega~\forall t\in]0;T]`.
    Id: ~pheromone_dispersion.identity_operator.Id
        The identity operator :math:`Id:c\mapsto c~\forall (x,y)\in\Omega~\forall t\in]0;T]`.
    S: ~pheromone_dispersion.source_term.Source
        The source term :math:`s(x,y,t)`.
    implemented_solver_type: list of str
        The list of keywords of the implemented types of time discretization:
        `['implicit', 'semi-implicit', 'semi-implicit with matrix inversion', 'implicit with stationnary matrix inversion']`
    time_discretization: str
        The keyword identifying the type of time discretization.
    tol_inversion: float
        Tolerance for the inversion estimation algorithm called at each time step.
    inv_matrix_implicit_part: ~numpy.ndarray or None
        Inverse matrix of the implicit part matrix of the time discretization.
        Initialized to `None`; set by :meth:`init_inverse_matrix`.
    """

    # TODO: add checks on the remaining inputs of __init__ (U, K, coeff_depot, S, msh).

    # Discretizations relying on a precomputed inverse matrix, mapped to the
    # default file name used by init_inverse_matrix to store/load that matrix.
    _INVERSE_MATRIX_FILE_NAMES = {
        'semi-implicit with matrix inversion': 'inv_matrix_semi_implicit_scheme',
        'implicit with stationnary matrix inversion': 'inv_matrix_implicit_scheme',
    }

    def __init__(self, U, K, coeff_depot, S, msh, time_discretization='semi-implicit', tol_inversion=1e-14):
        r"""Constructor method

        Parameters
        ----------
        U: ~pheromone_dispersion.velocity.Velocity
            The wind field :math:`\vec{u}(x,y,t)`.
        K: ~pheromone_dispersion.diffusion_tensor.DiffusionTensor
            The diffusion tensor :math:`\mathbf{K}(x,y,t)`.
        coeff_depot: ~numpy.ndarray
            The deposition coefficient :math:`\tau_{loss}(x,y)`.
        S: ~pheromone_dispersion.source_term.Source
            The source term :math:`s(x,y,t)`.
        msh: ~pheromone_dispersion.geom.MeshRect2D
            The geometry of the domain.
        time_discretization: str, default: 'semi-implicit'
            The keyword identifying the type of time discretization.
        tol_inversion: float, optional, default: 1e-14
            Tolerance for the inversion estimation algorithm called at each time step.

        Raises
        ------
        ValueError
            if the type of discretization is not implemented,
            i.e. if :attr:`time_discretization` is not in :attr:`implemented_solver_type`
        """
        self.implemented_solver_type = [
            'implicit',
            'semi-implicit',
            'semi-implicit with matrix inversion',
            'implicit with stationnary matrix inversion',
        ]
        # validate before building the (possibly costly) operators
        if time_discretization not in self.implemented_solver_type:
            raise ValueError("The given time discretization is not implemented.")
        self.time_discretization = time_discretization
        self.tol_inversion = tol_inversion

        self.msh = msh
        self.A = Advection(U, msh)
        self.D = Diffusion(K, msh)
        self.R = Reaction(coeff_depot, msh)
        self.Id = Id(msh)
        self.S = S
        self.inv_matrix_implicit_part = None

    def init_inverse_matrix(self, path_to_matrix=None, matrix_file_name=None):
        """Initialize the attribute :attr:`inv_matrix_implicit_part`.

        If the inverse matrix of the implicit part matrix of the time discretization has never been computed,
        it is computed and stored. Otherwise, the previously computed inverse matrix is loaded.

        Parameters
        ----------
        path_to_matrix: str, optional, default: None
            Path where to save or load the inverse matrix (created if missing).
            If not provided, set to `'./data'`.
        matrix_file_name: str, optional, default: None
            Name of the file where the matrix is stored or will be saved.
            If not provided, set to `'inv_matrix_**_scheme'`
            with ** either `'implicit'` or `'semi_implicit'` depending on the time discretization.
            The `'.npy'` extension is appended when missing.

        Notes
        -----
        This method is useful only if :attr:`time_discretization` is
        either `'semi-implicit with matrix inversion'` or `'implicit with stationnary matrix inversion'`.
        Otherwise, the linear system to solve at each time step is solved using the conjugate gradient or GMRES algorithm,
        :attr:`inv_matrix_implicit_part` is not used, and this method returns without doing anything.
        """
        if self.time_discretization not in self._INVERSE_MATRIX_FILE_NAMES:
            return

        if path_to_matrix is None:
            path_to_matrix = './data'
        Path(path_to_matrix).mkdir(parents=True, exist_ok=True)
        if matrix_file_name is None:
            matrix_file_name = self._INVERSE_MATRIX_FILE_NAMES[self.time_discretization]
        if not matrix_file_name.endswith('.npy'):
            matrix_file_name += '.npy'
        matrix_path = Path(path_to_matrix) / matrix_file_name

        # Load the inverse of the matrix of the implicit part if the file is in the folder
        if matrix_path.exists():
            print("=== Load of the inverse of the matrix of the implicit part of the " + self.time_discretization + " scheme ===")
            self.inv_matrix_implicit_part = np.load(matrix_path)
            return

        # Otherwise compute it (dense inversion of an N x N matrix, N = number of grid cells) and save it
        print("=== Computation of the inverse of the matrix of the implicit part of the " + self.time_discretization + " scheme ===")
        t_i = time.time()
        identity = np.identity(self.msh.y.size * self.msh.x.size)
        implicit_part = -self.D._matmat(identity) + self.R._matmat(identity)
        if self.time_discretization == 'implicit with stationnary matrix inversion':
            # the advection is also treated implicitly in this scheme
            implicit_part = implicit_part + self.A._matmat(identity)
        self.inv_matrix_implicit_part = sp.linalg.inv(identity + self.msh.dt * implicit_part)
        print("--- Computation at time t= ", self.msh.t, "in ", time.time() - t_i, " s")
        np.save(matrix_path, self.inv_matrix_implicit_part)

    def set_source(self, value, t=None):
        """Update the attribute :attr:`S` with the provided new values.

        Parameters
        ----------
        value: ~numpy.ndarray
            The new values of :attr:`S`.
        t: ~numpy.ndarray, optional, default: None
            The associated time array. `None` if the source is stationary.
        """
        self.S = Source(self.msh, value, t=t)

    def at_current_time(self, tc):
        """Update the attributes :attr:`S`, :attr:`D` and :attr:`A` at a given time.

        Parameters
        ----------
        tc : float or integer
            The current time.

        Notes
        -----
        Updates the attributes :attr:`S`, :attr:`D` and :attr:`A` and their own attributes
        using the method `at_current_time` of resp.
        the class :class:`~pheromone_dispersion.source_term.Source`,
        the class :class:`~pheromone_dispersion.diffusion_operator.Diffusion` and
        the class :class:`~pheromone_dispersion.advection_operator.Advection`.
        With `'implicit with stationnary matrix inversion'`, :attr:`D` and :attr:`A` are kept
        stationary so that they stay consistent with the precomputed inverse matrix.
        """
        self.S.at_current_time(tc)
        if self.time_discretization != 'implicit with stationnary matrix inversion':
            self.D.at_current_time(tc)
            self.A.at_current_time(tc)

    @staticmethod
    def _krylov_solve(solve, operator, rhs, tol):
        """Call a SciPy Krylov solver (`cg` or `gmres`) with a relative tolerance.

        SciPy >= 1.12 names the relative tolerance `rtol` (and removed `tol` in 1.14),
        older versions only know `tol`; the keyword is chosen from the solver signature.

        Returns
        -------
        tuple
            The `(solution, info)` pair returned by the solver.
        """
        import inspect

        tol_keyword = 'rtol' if 'rtol' in inspect.signature(solve).parameters else 'tol'
        return solve(operator, rhs, **{tol_keyword: tol})

    def _time_step(self, c_old, skip_null_rhs=False):
        """Compute the concentration at the current time step from the one at the previous time step.

        Parameters
        ----------
        c_old: ~numpy.ndarray
            Flattened concentration at the previous time step; it is never modified in place.
        skip_null_rhs: bool, default: False
            Only used by `'implicit with stationnary matrix inversion'`: if `True` and the right-hand side
            is numerically null (infinity norm below 1e-16), the new concentration is set to zero
            without applying the inverse matrix.

        Returns
        -------
        ~numpy.ndarray
            Flattened concentration at the current time step.

        Raises
        ------
        ValueError
            if the time discretization is unknown, if the inverse matrix required by the time discretization
            has not been initialized, or if the iterative solver fails or does not converge.
        """
        if (
            self.time_discretization in self._INVERSE_MATRIX_FILE_NAMES
            and self.inv_matrix_implicit_part is None
        ):
            raise ValueError(
                "The inverse matrix of the implicit part is not initialized: "
                "call init_inverse_matrix() before running the solver."
            )

        dt = self.msh.dt
        source = self.S.value.ravel()
        info = 0

        if self.time_discretization == 'semi-implicit with matrix inversion':
            # semi-implicit scheme (explicit advection) using the pre-computed inverse matrix
            c = self.inv_matrix_implicit_part.dot(c_old + dt * (-self.A.matvec(c_old) + source))
        elif self.time_discretization == 'semi-implicit':
            # semi-implicit scheme (explicit advection); the implicit part is symmetric,
            # hence solved with a conjugate gradient method
            c, info = self._krylov_solve(
                cg,
                self.Id + dt * (-self.D + self.R),
                c_old + dt * (-self.A.matvec(c_old) + source),
                self.tol_inversion,
            )
        elif self.time_discretization == 'implicit':
            # fully implicit scheme; the advection makes the matrix non-symmetric, hence GMRES
            c, info = self._krylov_solve(
                gmres,
                self.Id + dt * (-self.D + self.R + self.A),
                c_old + dt * source,
                self.tol_inversion,
            )
        elif self.time_discretization == 'implicit with stationnary matrix inversion':
            # fully implicit scheme using the pre-computed inverse matrix
            rhs = c_old + dt * source
            if skip_null_rhs and np.linalg.norm(rhs, ord=np.inf) < 1e-16:
                c = np.zeros_like(c_old)
            else:
                # skip the dense product when c_old already solves the system (steady state reached)
                lhs = (self.Id + dt * (-self.D + self.R + self.A)).matvec(c_old)
                if np.linalg.norm(rhs - lhs) < self.tol_inversion * np.linalg.norm(rhs):
                    c = c_old
                else:
                    c = self.inv_matrix_implicit_part.dot(rhs)
        else:
            raise ValueError("The given time discretization is not implemented.")

        if info > 0:
            raise ValueError(
                "The algorithm used to solve the linear system has not converged "
                + "to the expected tolerance or within the maximum number of iterations."
            )
        if info < 0:
            raise ValueError("The algorithm used to solve the linear system could not proceed due to illegal input or breakdown.")
        return c

    def solver(self, save_flag=False, path_save='./save/', display_flag=True, store_rate=1):
        """Compute the concentration :math:`c(x,y,t)` by solving the pheromone propagation model on the whole time window.

        Parameters
        ----------
        save_flag: bool, optional, default: False
            If `True`, the stored concentration maps and their times are saved
            in `path_save` as `'c_save.npy'` and `'t_save.npy'`.
        path_save: str, optional, default: './save/'
            Path of the directory in which the outputs are saved (created if missing).
        display_flag: bool, default: True
            If `True`, print the evolution of the solver through the time iterations.
        store_rate: int, optional, default: 1
            Number of time steps between two stored concentration maps.

        Returns
        -------
        t_save: ~numpy.ndarray
            The array containing the times :math:`t` at which the concentration maps are stored.
        c_save: ~numpy.ndarray
            The concentration maps :math:`c(x,y)` at the times `t_save`, of shape `(len(t_save), ny, nx)`.
        """
        map_shape = (self.msh.y.shape[0], self.msh.x.shape[0])
        # null initial condition
        c = np.zeros((map_shape[0] * map_shape[1],))

        if save_flag:
            Path(path_save).mkdir(parents=True, exist_ok=True)
        t_save = []
        c_save = []

        for it, self.msh.t in enumerate(self.msh.t_array[1:]):
            if display_flag:
                sys.stdout.write(f'\rt = {self.msh.t:.3f} / {self.msh.T_final:.3f} s')

            # update the coefficients of the equation at the current time;
            # c is rebound (never modified in place), so no copy of the previous state is needed
            self.at_current_time(self.msh.t)
            c = self._time_step(c)

            if it % store_rate == 0:
                t_save.append(self.msh.t)
                c_save.append(c.reshape(map_shape))

        if save_flag:
            np.save(Path(path_save) / 'c_save.npy', c_save)
            np.save(Path(path_save) / 't_save.npy', t_save)

        return np.array(t_save), np.array(c_save)

    def _store_estimates(self, obs, c, index_time):
        """Copy the concentration at the grid point nearest to each observation whose estimation
        is required at the time index `index_time` into :attr:`~source_localization.obs.Obs.c_est`.

        Parameters
        ----------
        obs: ~source_localization.obs.Obs
            Object in which the estimations are stored.
        c: ~numpy.ndarray
            Flattened concentration at the time index `index_time`.
        index_time: int
            Index of the current time in :attr:`msh.t_array`.
        """
        if index_time not in obs.index_obs_to_index_time_est:
            return
        c_map = c.reshape((self.msh.y.shape[0], self.msh.x.shape[0]))
        for index_obs in obs.index_time_est_to_index_obs[index_time]:
            columns = np.flatnonzero(obs.index_obs_to_index_time_est[index_obs, :] == index_time)
            index_x_est = np.argmin(np.abs(self.msh.x - obs.X_obs[index_obs, 0]))
            index_y_est = np.argmin(np.abs(self.msh.y - obs.X_obs[index_obs, 1]))
            obs.c_est[index_obs, columns] = c_map[index_y_est, index_x_est]

    def solver_est_at_obs_times(self, obs, display_flag=True):
        """Compute the concentration :math:`c(x,y,t)` by solving the pheromone propagation model on the whole time window
        and store the results at the times and positions required to estimate the observed variable
        in the attribute :attr:`~source_localization.obs.Obs.c_est` of the given object of the class
        :class:`~source_localization.obs.Obs`.

        Parameters
        ----------
        obs: ~source_localization.obs.Obs
            Object containing all the features related to the observations and estimation of the observed variables
            and in which :math:`c(x,y,t)` is stored at the times and positions required to estimate the observed variable.
        display_flag: bool, default: True
            If `True`, print the evolution of the solver through the time iterations.
        """
        # null initial condition, possibly required by the observations at the initial time
        c = np.zeros((self.msh.y.shape[0] * self.msh.x.shape[0],))
        self._store_estimates(obs, c, 0)

        # index_time is the index of the current time in msh.t_array
        for index_time, self.msh.t in enumerate(self.msh.t_array[1:], start=1):
            if display_flag:
                sys.stdout.write(f'\rt = {self.msh.t:.3f} / {self.msh.T_final:.3f} s')

            # update the coefficients of the equation at the current time and advance one step
            self.at_current_time(self.msh.t)
            c = self._time_step(c, skip_null_rhs=True)

            self._store_estimates(obs, c, index_time)