/*******************************************************************************
* Copyright 2019-2023 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

#pragma once

#include <complex>
#include <vector>

// Wrappers for compile-time checking of whether we are in
// Real or Complex space
// Complex helpers.
// complex_info<T>: compile-time description of a scalar type.
//
// Primary template, selected for every T that is not a std::complex
// specialization:
//   real_type  - T itself
//   is_complex - false
template <typename T>
struct complex_info {
    using real_type = T;
    // constexpr static data members are implicitly inline since C++17, so the
    // flag may be ODR-used (e.g. bound to a const reference) without needing a
    // separate out-of-class definition.
    static constexpr bool is_complex = false;
};

// Partial specialization for std::complex<T>:
//   real_type  - the underlying component type T
//   is_complex - true
template <typename T>
struct complex_info<std::complex<T>> {
    using real_type = T;
    static constexpr bool is_complex = true;
};

// is_complex<T>::value is false for every type ...
template <class T>
struct is_complex : std::integral_constant<bool, false> {
};

// ... except specializations of std::complex, for which it is true.
template <class T>
struct is_complex<std::complex<T>> : std::integral_constant<bool, true> {
};

/* conjugate(): complex conjugate used by the conjtrans operations.
 *
 * The generic template exists only so that code instantiated with other
 * scalar types still builds; calling it throws a C string at run time.
 * The explicit specializations below cover the supported types: real values
 * are returned unchanged, complex values go through std::conj(). */
template <typename T>
static inline T conjugate(T /* unsupported */)
{
    throw("Invalid type for conjugate(); only float/double/complex<float>/complex<double> supported");
}

template <>
inline float conjugate(float value)
{
    return value;
}

template <>
inline double conjugate(double value)
{
    return value;
}

template <>
inline std::complex<float> conjugate(std::complex<float> value)
{
    return std::conj(value);
}

template <>
inline std::complex<double> conjugate(std::complex<double> value)
{
    return std::conj(value);
}

// Applies the optional conjugation of an operation to a single value:
// returns conjugate(t) when isConj is true, otherwise t unchanged.
template <typename T>
inline T opVal(const T t, const bool isConj)
{
    if (isConj) {
        return conjugate(t);
    }
    return t;
}

// Integer ceiling division: returns the smallest integer that is >= num / denom.
//
// Only participates in overload resolution for integral T.
// denom must be strictly positive (checked with assert()).
//
// The result is built from the truncated quotient plus a remainder test
// instead of (num + denom - 1) / denom, so it cannot overflow when num is
// close to the maximum of T and it is also correct for negative numerators.
template <typename T, typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
inline T ceil_div(T num, T denom) {
    assert(denom > 0);
    const T quotient = num / denom;
    // Truncation toward zero already rounds negative quotients up; only a
    // strictly positive remainder means the quotient was rounded down.
    return (num % denom > 0) ? static_cast<T>(quotient + 1) : quotient;
}

// Passes every pointer stored in fp_ptr_vec and int_ptr_vec to sycl::free()
// using the context cxt.
//
// The vectors themselves are not modified: after the call they still hold
// the (now released) pointer values, so they must not be cleaned up twice.
template <typename fp, typename intType>
void cleanup_arrays(std::vector<fp *> &fp_ptr_vec,
                    std::vector<intType *> &int_ptr_vec,
                    sycl::context cxt)
{
    // Range-for avoids comparing a signed int index against size().
    for (fp *ptr : fp_ptr_vec) {
        sycl::free(ptr, cxt);
    }

    for (intType *ptr : int_ptr_vec) {
        sycl::free(ptr, cxt);
    }
}

template <typename fp, typename intType>
void cleanup_arrays(std::vector<fp *> &fp_ptr_vec,
                    std::vector<intType *> &int_ptr_vec,
                    std::vector<std::int64_t *> &i64_ptr_vec,
                    std::vector<void *> &void_ptr_vec,
                    sycl::context cxt)
{
    for (int i = 0; i < fp_ptr_vec.size(); i++) {
        sycl::free(fp_ptr_vec[i], cxt);
    }

    for (int i = 0; i < int_ptr_vec.size(); i++) {
        sycl::free(int_ptr_vec[i], cxt);
    }

    for (int i = 0; i < i64_ptr_vec.size(); i++) {
        sycl::free(i64_ptr_vec[i], cxt);
    }

    for (int i = 0; i < void_ptr_vec.size(); i++) {
        sycl::free(void_ptr_vec[i], cxt);
    }
}

// Fills the three CSR arrays (ia, ja, a) of a stencil-based matrix built on
// an nx * nx * nx grid (nx = ny = nz).
//
// Grid point (ix, iy, iz) maps to row iz*nx*ny + iy*nx + ix. Each row gets
// one entry for every neighbour within +-1 in x, y and z that lies inside the
// grid: the diagonal entry is set_fp_value(fp(26.0), fp(0.0)) and all others
// are set_fp_value(fp(-1.0), fp(0.0)).
//
// index is the indexing base added to every value stored in ia and ja.
// The vectors are written through operator[] and are never resized, so they
// must already be large enough when this function is called.
template <typename fp, typename intType>
void generate_sparse_matrix(const intType nx,
                            std::vector<intType, mkl_allocator<intType, 64>> &ia,
                            std::vector<intType, mkl_allocator<intType, 64>> &ja,
                            std::vector<fp, mkl_allocator<fp, 64>> &a,
                            const intType index = 0)
{
    const intType ny = nx;
    const intType nz = nx;
    intType count    = 0; // number of entries written so far

    ia[0] = index;

    for (intType iz = 0; iz < nz; iz++) {
        for (intType iy = 0; iy < ny; iy++) {
            for (intType ix = 0; ix < nx; ix++) {
                const intType row = iz * nx * ny + iy * nx + ix;

                for (intType sz = -1; sz <= 1; sz++) {
                    // skip neighbours outside the grid in z
                    if (!(iz + sz > -1 && iz + sz < nz))
                        continue;

                    for (intType sy = -1; sy <= 1; sy++) {
                        // skip neighbours outside the grid in y
                        if (!(iy + sy > -1 && iy + sy < ny))
                            continue;

                        for (intType sx = -1; sx <= 1; sx++) {
                            // skip neighbours outside the grid in x
                            if (!(ix + sx > -1 && ix + sx < nx))
                                continue;

                            const intType col = row + sz * nx * ny + sy * nx + sx;
                            const fp entry    = (col == row) ? fp(26.0) : fp(-1.0);

                            ja[count] = col + index;
                            a[count]  = set_fp_value(entry, fp(0.0));
                            ++count;
                        }
                    }
                }

                // one-past-the-end position of this row
                ia[row + 1] = count + index;
            }
        }
    }
}

// Raw-pointer variant: fills the three CSR arrays (ia, ja, a) of a
// stencil-based matrix built on an nx * nx * nx grid (nx = ny = nz).
//
// Grid point (ix, iy, iz) maps to row iz*nx*ny + iy*nx + ix. Each row gets
// one entry for every neighbour within +-1 in x, y and z that lies inside the
// grid: the diagonal entry is set_fp_value(fp(26.0), fp(0.0)) and all others
// are set_fp_value(fp(-1.0), fp(0.0)).
//
// index is the indexing base added to every value stored in ia and ja.
// The arrays are only indexed, never allocated here; the caller must provide
// enough room for all rows and entries.
template <typename fp, typename intType>
void generate_sparse_matrix(const intType nx,
                                  intType *ia,
                                  intType *ja,
                                  fp *a,
                            const intType index = 0)
{
    const intType ny = nx;
    const intType nz = nx;
    intType written  = 0; // number of entries written so far

    ia[0] = index;

    for (intType iz = 0; iz < nz; iz++) {
        for (intType iy = 0; iy < ny; iy++) {
            for (intType ix = 0; ix < nx; ix++) {
                const intType row = iz * nx * ny + iy * nx + ix;

                for (intType sz = -1; sz <= 1; sz++) {
                    // neighbour plane outside the grid in z
                    if (!(iz + sz > -1 && iz + sz < nz))
                        continue;

                    for (intType sy = -1; sy <= 1; sy++) {
                        // neighbour line outside the grid in y
                        if (!(iy + sy > -1 && iy + sy < ny))
                            continue;

                        for (intType sx = -1; sx <= 1; sx++) {
                            // neighbour point outside the grid in x
                            if (!(ix + sx > -1 && ix + sx < nx))
                                continue;

                            const intType col   = row + sz * nx * ny + sy * nx + sx;
                            const bool diagonal = (col == row);

                            ja[written] = col + index;
                            a[written]  = set_fp_value(diagonal ? fp(26.0) : fp(-1.0), fp(0.0));
                            ++written;
                        }
                    }
                }

                // one-past-the-end position of this row
                ia[row + 1] = written + index;
            }
        }
    }
}

// Creating the 3arrays CSR representation (ia, ja, values)
// of general random sparse matrix
// with density (0 < density <= 1.0)
//
// Every (row, column) position is kept when a std::rand() draw scaled to
// [0, 1) is below density_val; kept entries get their value from
// rand_scalar<fp>() (defined elsewhere; the original note states
// -0.5 <= value < 0.5).
//
// index is the indexing base added to every value stored in ia and ja.
// Entries are appended with push_back(); the ia[i + 1] update below addresses
// the slot that was just appended only when ia is empty on entry.
template <typename fp, typename intType>
void generate_random_sparse_matrix(const intType nrows,
                                   const intType ncols,
                                   const double density_val,
                                   std::vector<intType, mkl_allocator<intType, 64>> &ia,
                                   std::vector<intType, mkl_allocator<intType, 64>> &ja,
                                   std::vector<fp, mkl_allocator<fp, 64>> &a,
                                   const intType index = 0)
{
    // Dividing by RAND_MAX + 1.0 (instead of RAND_MAX) keeps the draw strictly
    // below 1.0, so density_val == 1.0 selects every position, including the
    // draw where std::rand() returns RAND_MAX.
    const double rand_range = static_cast<double>(RAND_MAX) + 1.0;

    intType nnz = 0;
    ia.push_back(0 + index); // starting index of row0.

    for (intType i = 0; i < nrows; i++) {
        ia.push_back(nnz + index); // ending index of row_i.
        for (intType j = 0; j < ncols; j++) {
            if (static_cast<double>(std::rand()) / rand_range < density_val) {
                a.push_back(rand_scalar<fp>());
                ja.push_back(j + index);
                nnz++;
            }
        }

        ia[i + 1] = nnz + index; // update ending index of row_i.
    }
}

// Randomly reorders the entries inside each row of a CSR matrix given as raw
// arrays. Column indices (ja) and values (a) are always swapped together, so
// every row keeps the same set of (column, value) pairs.
//
//   ia       - row pointers, nrows + 1 entries, expressed with base `indexing`
//   ja, a    - column indices and values, permuted in place
//   indexing - indexing base subtracted from ia to obtain array positions
//   nrows    - number of rows to process
//   nnz      - not used by this function
//
// The permutation is driven by std::rand(), i.e. by the current random seed.
template <typename fp, typename intType>
void shuffle_matrix_data(const intType *ia,
                         intType *ja,
                         fp *a,
                         const intType indexing,
                         const intType nrows,
                         const intType nnz)
{
    for (intType row = 0; row < nrows; ++row) {
        const intType row_len   = ia[row + 1] - ia[row];
        const intType row_begin = ia[row] - indexing;
        const intType row_end   = ia[row + 1] - indexing;

        for (intType pos = row_begin; pos < row_end; ++pos) {
            // pick a random position of the same row and swap pos with it
            const intType partner = row_begin + std::rand() % (row_len);
            std::swap(ja[partner], ja[pos]);
            std::swap(a[partner], a[pos]);
        }
    }
}

// Shuffle the 3arrays CSR representation (ia, ja, values)
// of any sparse matrix, std::vector variant.
//
// Entries are permuted only within their own row; column indices (ja) and
// values (a) are swapped together. ia holds the row pointers expressed with
// base `indexing`, nrows is the number of rows to process and nnz is not
// needed for the shuffle itself.
//
// The work is forwarded to the raw-pointer overload declared above so that
// both variants share a single implementation.
template <typename fp, typename intType>
void shuffle_matrix_data(const std::vector<intType, mkl_allocator<intType, 64>> &ia,
                         std::vector<intType, mkl_allocator<intType, 64>> &ja,
                         std::vector<fp, mkl_allocator<fp, 64>> &a,
                         const intType indexing,
                         const intType nrows,
                         const intType nnz)
{
    // The pointer arguments cannot convert to the vector parameters, so this
    // call always resolves to the raw-pointer overload (no recursion).
    shuffle_matrix_data<fp, intType>(ia.data(), ja.data(), a.data(), indexing, nrows, nnz);
}

// Compares a computed value x against its reference x_ref.
//
// The absolute error is |x - x_ref|; the relative error divides it by
// |x_ref| plus the machine epsilon of fp_real (which keeps the denominator
// non-zero). The value is accepted when either error is <= bound.
//
// On failure both errors and the bound are written to std::cout without a
// trailing newline. Returns true when the value is accepted.
template <typename fp, typename fp_real>
bool check_errors(fp x, fp x_ref, double bound)
{
    const fp_real abs_err = std::abs(x - x_ref);
    const fp_real rel_err =
            abs_err / (std::abs(x_ref) + std::numeric_limits<fp_real>::epsilon());

    if ((rel_err <= bound) || (abs_err <= bound)) {
        return true;
    }

    std::cout << "relative error = " << rel_err << " absolute error = " << abs_err
              << " limit = " << bound;
    return false;
}

template <typename fp, typename intType>
bool check_result(fp res, fp ref, intType index)
{
    bool check;
    using fp_real = typename complex_info<fp>::real_type;
    fp_real bound = std::numeric_limits<fp_real>::epsilon();
    check = check_errors<fp, fp_real>(res, ref, bound);
    if (!check)
        std::cout << " in index: " << index << std::endl;
    return check;
}

template <typename fp, typename intType>
bool check_result(fp res, fp ref, intType nFlops, intType index)
{
    bool check;
    using fp_real = typename complex_info<fp>::real_type;
    fp_real bound = std::numeric_limits<fp_real>::epsilon() * static_cast<double>(nFlops);
    check = check_errors<fp, fp_real>(res, ref, bound);
    if (!check)
        std::cout << " in index: " << index << std::endl;
    return check;
}

// Writes a readable name for a oneapi::mkl::sparse::matrix_view_descr value
// to out and returns out. Enumerators without a dedicated label are printed
// as "matrix_view_descr{<integer value>}".
//
// Declared inline because this header defines the function: without it,
// including the header from more than one translation unit produces
// multiple-definition errors at link time.
inline std::ostream& operator<< ( std::ostream& out, oneapi::mkl::sparse::matrix_view_descr view )
{
    using oneapi::mkl::sparse::matrix_view_descr;
    switch(view)
    {
        case matrix_view_descr::general :
            return out << "matrix_view_descr::general" ;
        default:
            return out << "matrix_view_descr{" << static_cast<int>(view) << "}" ;
    }
}

// Writes a readable name for a oneapi::mkl::sparse::matmat_request value to
// out and returns out. Enumerators without a dedicated label are printed as
// "matmat_request{<integer value>}".
//
// Declared inline because this header defines the function: without it,
// including the header from more than one translation unit produces
// multiple-definition errors at link time.
inline std::ostream& operator<< ( std::ostream& out, oneapi::mkl::sparse::matmat_request req )
{
    using oneapi::mkl::sparse::matmat_request;
    switch(req)
    {
        case matmat_request::get_work_estimation_buf_size :
            return out << "matmat_request::get_work_estimation_buf_size" ;
        case matmat_request::work_estimation :
            return out << "matmat_request::work_estimation" ;
        case matmat_request::get_compute_structure_buf_size :
            return out << "matmat_request::get_compute_structure_buf_size" ;
        case matmat_request::compute_structure :
            return out << "matmat_request::compute_structure" ;
        case matmat_request::finalize_structure :
            return out << "matmat_request::finalize_structure" ;
        case matmat_request::get_compute_buf_size :
            return out << "matmat_request::get_compute_buf_size" ;
        case matmat_request::compute :
            return out << "matmat_request::compute" ;
        case matmat_request::get_nnz :
            return out << "matmat_request::get_nnz" ;
        case matmat_request::finalize :
            return out << "matmat_request::finalize" ;
        default:
            return out << "matmat_request{" << static_cast<int>(req) << "}" ;
    }
}


