/*******************************************************************************
* Copyright 2024 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneAPI Math Kernel Library (oneMKL)
*       DPCPP USM API oneapi::mkl::sparse::matmatd to perform general
*       sparse matrix-sparse matrix multiplication with a dense result matrix
*       on a SYCL device (GPU). This example uses sparse matrices in CSR format.
*
*           C = alpha * op(A) * op(B) + beta * C
*
*       where op() is defined by one of
*           oneapi::mkl::transpose::{nontrans,trans,conjtrans}
*
*       The supported floating point data types for matmatd matrix data are:
*           float
*           double
*           std::complex<float>
*           std::complex<double>
*
*       The supported matrix formats for matmatd are:
*           CSR
*
*******************************************************************************/

// stl includes
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <vector>

#include "mkl.h"
#include "oneapi/mkl.hpp"
#include <sycl/sycl.hpp>

// local includes
#include "common_for_examples.hpp"
#include "./include/common_for_sparse_examples.hpp"

//
// Main example for Sparse Matrix-Sparse Matrix Multiply with a dense result:
//
//     C = alpha * op(A) * op(B) + beta * C
//
// Steps performed:
//   1. build A and B in CSR format in host USM via generate_sparse_matrix()
//      (two separate calls with the same size; A uses zero-based indexing and
//      B uses one-based indexing),
//   2. copy both matrices to device USM and attach them to oneMKL sparse
//      matrix handles,
//   3. call oneapi::mkl::sparse::matmatd to form the dense matrix C on the
//      device,
//   4. copy the first rows of C back to host USM and print them,
//   5. release the matrix handles and free every USM allocation.
//
// NOTE(review): the original description states that A and B are square
// symmetric matrices with identical structure and values, so that C is the
// square of A; that depends on generate_sparse_matrix() -- confirm there.
//
// Template parameters:
//   fpType   floating point type of the matrix values
//   intType  integer type of the CSR row pointer / column index arrays
//
// Parameters:
//   dev  SYCL device on which the example queue is created
//
// Returns:
//   0 on success, or when oneMKL reports the operation as unimplemented
//   1 when a SYCL exception or any other std::exception is caught
//
// All USM allocations are made inside the try block so that an allocation
// failure is reported through the same handlers (and the same cleanup) as any
// other error, instead of escaping to the caller with memory still allocated.
//
template <typename fpType, typename intType>
int run_sparse_matrix_sparse_matrix_dense_result_multiply_example(const sycl::device &dev)
{
    // Initialize data for Sparse Matrix - Sparse Matrix Multiply
    oneapi::mkl::transpose opA = oneapi::mkl::transpose::nontrans;
    oneapi::mkl::transpose opB = oneapi::mkl::transpose::nontrans;

    oneapi::mkl::index_base a_index = oneapi::mkl::index_base::zero;
    oneapi::mkl::index_base b_index = oneapi::mkl::index_base::one;

    // Catch asynchronous exceptions
    auto exception_handler = [](sycl::exception_list exceptions) {
        for (std::exception_ptr const &eptr : exceptions) {
            try {
                std::rethrow_exception(eptr);
            }
            catch (sycl::exception const &e) {
                std::cout << "Caught asynchronous SYCL "
                             "exception during sparse::matmatd example:\n"
                          << e.what() << std::endl;
            }
        }
    };

    // create execution main_queue
    sycl::queue main_queue(dev, exception_handler);
    sycl::context cxt = main_queue.get_context();

    //
    // set up dimensions of matrix products
    //
    intType size = 4;

    intType a_nrows = size * size * size;
    intType a_ncols = a_nrows;
    intType a_nnz   = 27 * a_nrows; // allocation bound; actual nnz is read back after generation
    intType b_nrows = size * size * size;
    intType b_ncols = b_nrows;
    intType b_nnz   = 27 * b_nrows; // allocation bound; actual nnz is read back after generation
    intType c_nrows = size * size * size;
    intType c_ncols = c_nrows;

    //
    // array memory management tools
    //
    std::vector<intType *> int_ptr_vec;
    std::vector<fpType *> fp_ptr_vec;
    std::vector<std::int64_t *> i64_ptr_vec;
    std::vector<void *> void_ptr_vec;

    // Register a USM allocation for later release. Allocations are registered
    // before they are null-checked so that a partially successful group of
    // allocations is still freed; null pointers themselves are not recorded.
    auto track = [](auto &registry, auto *ptr) {
        if (ptr != nullptr)
            registry.push_back(ptr);
    };

    oneapi::mkl::sparse::matrix_handle_t A = nullptr;
    oneapi::mkl::sparse::matrix_handle_t B = nullptr;

    // Shared error-path cleanup: drain the queue so no submitted work still
    // touches the buffers, release whichever handles were created, then free
    // all registered USM allocations.
    auto release_after_failure = [&]() {
        main_queue.wait();
        if (A != nullptr)
            oneapi::mkl::sparse::release_matrix_handle(main_queue, &A).wait();
        if (B != nullptr)
            oneapi::mkl::sparse::release_matrix_handle(main_queue, &B).wait();

        cleanup_arrays<fpType, intType>(fp_ptr_vec, int_ptr_vec, i64_ptr_vec, void_ptr_vec, cxt);
    };

    try {
        //
        // setup A data locally in CSR format
        //
        intType *a_rowptr_host = sycl::malloc_host<intType>(a_nrows + 1, main_queue);
        intType *a_colind_host = sycl::malloc_host<intType>(a_nnz, main_queue);
        fpType *a_values_host  = sycl::malloc_host<fpType>(a_nnz, main_queue);

        track(int_ptr_vec, a_rowptr_host);
        track(int_ptr_vec, a_colind_host);
        track(fp_ptr_vec, a_values_host);

        if (!a_rowptr_host || !a_colind_host || !a_values_host)
            throw std::runtime_error("Failed to allocate USM memory");

        intType a_ind = a_index == oneapi::mkl::index_base::zero ? 0 : 1;
        generate_sparse_matrix<fpType, intType>(size, a_rowptr_host, a_colind_host, a_values_host, a_ind);
        a_nnz = a_rowptr_host[a_nrows] - a_ind; // actual number of stored entries

        intType *a_rowptr = sycl::malloc_device<intType>(a_nrows + 1, main_queue);
        intType *a_colind = sycl::malloc_device<intType>(a_nnz, main_queue);
        fpType *a_values  = sycl::malloc_device<fpType>(a_nnz, main_queue);

        track(int_ptr_vec, a_rowptr);
        track(int_ptr_vec, a_colind);
        track(fp_ptr_vec, a_values);

        if (!a_rowptr || !a_colind || !a_values)
            throw std::runtime_error("Failed to allocate USM memory");

        // copy A matrix USM data from host to device
        auto ev_cpy_ia = main_queue.copy<intType>(a_rowptr_host, a_rowptr, a_nrows + 1);
        auto ev_cpy_ja = main_queue.copy<intType>(a_colind_host, a_colind, a_nnz);
        auto ev_cpy_a  = main_queue.copy<fpType>(a_values_host, a_values, a_nnz);

        //
        // setup B data locally in CSR format
        //
        intType *b_rowptr_host = sycl::malloc_host<intType>(b_nrows + 1, main_queue);
        intType *b_colind_host = sycl::malloc_host<intType>(b_nnz, main_queue);
        fpType *b_values_host  = sycl::malloc_host<fpType>(b_nnz, main_queue);

        track(int_ptr_vec, b_rowptr_host);
        track(int_ptr_vec, b_colind_host);
        track(fp_ptr_vec, b_values_host);

        if (!b_rowptr_host || !b_colind_host || !b_values_host)
            throw std::runtime_error("Failed to allocate USM memory");

        intType b_ind = b_index == oneapi::mkl::index_base::zero ? 0 : 1;
        generate_sparse_matrix<fpType, intType>(size, b_rowptr_host, b_colind_host, b_values_host, b_ind);
        b_nnz = b_rowptr_host[b_nrows] - b_ind; // actual number of stored entries

        intType *b_rowptr = sycl::malloc_device<intType>(b_nrows + 1, main_queue);
        intType *b_colind = sycl::malloc_device<intType>(b_nnz, main_queue);
        fpType *b_values  = sycl::malloc_device<fpType>(b_nnz, main_queue);

        //
        // setup C data locally in dense matrix format
        //
        fpType alpha = fpType(1);
        fpType beta = fpType(0);
        auto c_layout = oneapi::mkl::layout::row_major;
        const bool c_is_row_major = (c_layout == oneapi::mkl::layout::row_major);
        intType ldc = c_is_row_major ? c_ncols : c_nrows;
        intType c_size = c_is_row_major ? ldc * c_nrows : ldc * c_ncols;
        auto c_values = sycl::malloc_device<fpType>(c_size, main_queue);

        track(int_ptr_vec, b_rowptr);
        track(int_ptr_vec, b_colind);
        track(fp_ptr_vec, b_values);
        track(fp_ptr_vec, c_values);

        if (!b_rowptr || !b_colind || !b_values || !c_values)
            throw std::runtime_error("Failed to allocate USM memory");

        // copy B matrix USM data from host to device
        auto ev_cpy_ib = main_queue.copy<intType>(b_rowptr_host, b_rowptr, b_nrows + 1);
        auto ev_cpy_jb = main_queue.copy<intType>(b_colind_host, b_colind, b_nnz);
        auto ev_cpy_b  = main_queue.copy<fpType>(b_values_host, b_values, b_nnz);

        //
        // Execute Matrix Multiply
        //

        std::cout << "\n\t\tsparse::matmatd parameters:\n";
        std::cout << "\t\t\topA = " << opA << std::endl;
        std::cout << "\t\t\topB = " << opB << std::endl;

        std::cout << "\t\t\tA_nrows = A_ncols = " << a_nrows << std::endl;
        std::cout << "\t\t\tB_nrows = B_ncols = " << b_nrows << std::endl;
        std::cout << "\t\t\tC_nrows = C_ncols = " << c_nrows << std::endl;

        std::cout << "\t\t\tA_index = " << a_index << std::endl;
        std::cout << "\t\t\tB_index = " << b_index << std::endl;

        oneapi::mkl::sparse::init_matrix_handle(&A);
        oneapi::mkl::sparse::init_matrix_handle(&B);

        auto ev_setA = oneapi::mkl::sparse::set_csr_data(main_queue, A, a_nrows, a_ncols, a_index,
                a_rowptr, a_colind, a_values, {ev_cpy_ia, ev_cpy_ja, ev_cpy_a});
        auto ev_setB = oneapi::mkl::sparse::set_csr_data(main_queue, B, b_nrows, b_ncols, b_index,
                b_rowptr, b_colind, b_values, {ev_cpy_ib, ev_cpy_jb, ev_cpy_b});
        auto ev_fillC = main_queue.fill<fpType>(c_values, fpType(0), c_size);

        //
        // compute sparse matrix-matrix product with dense result
        //
        auto ev2 = oneapi::mkl::sparse::matmatd(
            main_queue, c_layout, opA, opB, alpha, A, B, beta, c_values, c_nrows, c_ncols,
            ldc, {ev_setA, ev_setB, ev_fillC});

        // Copy C matrix to host for printing
        const intType c_nrows_copy = std::min<intType>(2, c_nrows); // only copy over this many rows of C to host

        if (!c_is_row_major && c_layout != oneapi::mkl::layout::col_major)
            throw std::runtime_error("Bad layout for C matrix");

        fpType * c_values_host = sycl::malloc_host<fpType>(c_ncols * c_nrows_copy, main_queue);
        track(fp_ptr_vec, c_values_host);
        if (!c_values_host)
            throw std::runtime_error("Failed to allocate USM memory");

        // The host buffer is packed: it holds c_nrows_copy x c_ncols values in
        // the same layout as C, so its leading dimension is the extent of the
        // fast index (c_ncols for row major, c_nrows_copy for column major) and
        // is independent of the device leading dimension ldc.
        const intType c_outer_extent = c_is_row_major ? c_nrows_copy : c_ncols;
        const intType c_host_ld      = c_is_row_major ? c_ncols : c_nrows_copy;

        auto ev_copy_c = main_queue.submit([&](sycl::handler& cgh) {
            cgh.depends_on(ev2);
            // dimension 0 walks the slow index of the layout, dimension 1 the
            // fast index (rows/cols for row major, cols/rows for column major)
            cgh.parallel_for(
                sycl::range<2>(c_outer_extent, c_host_ld),
                [=](sycl::item<2> item) {
                    c_values_host[c_host_ld * item.get_id(0) + item.get_id(1)] =
                        c_values[ldc * item.get_id(0) + item.get_id(1)];
                });
        });
        ev_copy_c.wait(); // make sure copy is done before reading from it

        // print out a portion of C solution
        sycl::event ev_print = main_queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on({ev_copy_c});
            auto kernel = [=]() {
                std::cout << "C matrix [first " << c_nrows_copy << " rows]:" << std::endl;
                for (intType row = 0; row < c_nrows_copy; ++row) {
                    for (intType col = 0; col < c_ncols; ++col) {
                        fpType val = c_is_row_major ?
                            c_values_host[c_host_ld*row + col] : c_values_host[c_host_ld*col + row];
                        std::cout << "C(" << row << ", " << col << ") = " << val
                                  << std::endl;
                    }
                }
            };
            cgh.host_task(kernel);
        });

        // clean up
        auto ev_relA = oneapi::mkl::sparse::release_matrix_handle(main_queue, &A, {ev_print});
        auto ev_relB = oneapi::mkl::sparse::release_matrix_handle(main_queue, &B, {ev_print});

        ev_relA.wait();
        ev_relB.wait();

    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception:\n" << e.what() << std::endl;

        release_after_failure();

        return 1;
    }
    catch (oneapi::mkl::unimplemented const &e) {
        std::cout << "\t\tCaught oneMKL unimplemented exception:\n" << e.what() << std::endl;

        release_after_failure();

        // an unimplemented configuration is reported but not treated as a failure
        return 0;
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception:\n" << e.what() << std::endl;

        release_after_failure();

        return 1;
    }

    cleanup_arrays<fpType, intType>(fp_ptr_vec, int_ptr_vec, i64_ptr_vec, void_ptr_vec, cxt);

    return 0;
}

//
// Print the example banner to std::cout: the operation being demonstrated,
// the oneMKL sparse APIs that are used, and the supported floating point
// type precisions.
//
void print_example_banner()
{
    // horizontal rule that frames the banner
    static const char *const rule =
        "###############################################################"
        "#########";

    // banner text between the two rules, one entry per output line
    static const char *const banner_lines[] = {
        "# Sparse Matrix-Sparse Matrix Multiply with dense result Example: ",
        "# ",
        "#    C = alpha * op(A) * op(B) + beta * C",
        "# ",
        "# where A and B are sparse matrices in CSR format, and C is the",
        "# dense matrix product",
        "# ",
        "# Using apis:",
        "#   sparse::matmatd",
        "# ",
        "#   sparse::init_matrix_handle",
        "#   sparse::set_csr_data",
        "#   sparse::release_matrix_handle",
        "# ",
        "# Supported floating point type precisions:",
        "#   float",
        "#   double",
        "#   std::complex<float>",
        "#   std::complex<double>",
        "# ",
    };

    std::cout << '\n' << rule << '\n';
    for (const char *text : banner_lines) {
        std::cout << text << '\n';
    }
    // closing rule, trailing blank line, and a flush of the whole banner
    std::cout << rule << '\n' << std::endl;
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//
//  For each device selected and each supported data type, MatrixMultiplyExample
//  runs is with all supported data types
//

int main(int argc, char **argv)
{

    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    int status = 0;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {

        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);

        if (my_dev_is_found) {
            std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

            std::cout << "\tRunning with single precision real data type:" << std::endl;
            status = run_sparse_matrix_sparse_matrix_dense_result_multiply_example<float, std::int32_t>(my_dev);
            if (status != 0)
                return status;

            if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                std::cout << "\tRunning with double precision real data type:" << std::endl;
                status = run_sparse_matrix_sparse_matrix_dense_result_multiply_example<double, std::int32_t>(
                        my_dev);
                if (status != 0)
                    return status;
            }

            std::cout << "\tRunning with single precision complex data type:" << std::endl;
            status = run_sparse_matrix_sparse_matrix_dense_result_multiply_example<std::complex<float>, std::int32_t>(my_dev);
            if (status != 0)
                return status;

            if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                std::cout << "\tRunning with double precision complex data type:" << std::endl;
                status = run_sparse_matrix_sparse_matrix_dense_result_multiply_example<std::complex<double>, std::int32_t>(
                        my_dev);
                if (status != 0)
                    return status;
            }
        }
        else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it]
                      << " devices found; Fail on missing devices "
                         "is enabled.\n";
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                      << sycl_device_names[*it] << " tests.\n";
#endif
        }
    }

    mkl_free_buffers();
    return status;
}
