/*******************************************************************************
* Copyright 2023 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates usage of the gaussian distribution with the
*       Device Vector Math APIs based implementation on a SYCL device (CPU, GPU).
*
*******************************************************************************/

// stl includes
#include <cstddef>
#include <exception>
#include <iostream>
#include <limits>
#include <list>
#include <numeric>
#include <vector>

#include <sycl/sycl.hpp>

// the macro is needed to switch the Vector Math APIs based implementation on
#define MKL_RNG_USE_BINARY_CODE 1
#include "oneapi/mkl/rng/device.hpp"

// local includes
#include "common_for_rng_examples.hpp"

// example parameters
constexpr int seed{777};             // seed handed to every mcg59 engine instance
constexpr std::size_t n{1'000'000};  // count of random numbers produced per example
constexpr int n_print{10};           // how many leading numbers are echoed to stdout

// Short alias for the oneMKL device RNG namespace; the example reaches the
// mcg59 engine, the gaussian distribution and generate() through it.
namespace rng_device = oneapi::mkl::rng::device;

// Fills n values from rng_device::gaussian<Type, Method>(mean, stddev), one
// per work-item, using a scalar mcg59 engine offset by the work-item index.
// Prints the first n_print values and runs the statistics check on the result.
// Returns 1 if the kernel raised a SYCL exception or the check reported 1,
// otherwise 0.
template <typename Method, typename Type>
int run_scalar_example(sycl::queue& q, Type mean, Type stddev) {
    using Distr = rng_device::gaussian<Type, Method>;
    std::cout << "\tRunning scalar example" << std::endl;

    // Output buffer is shared USM: the kernel writes it, the host reads it below.
    sycl::usm_allocator<Type, sycl::usm::alloc::shared> shared_alloc(q);
    std::vector<Type, decltype(shared_alloc)> samples(n, shared_alloc);
    Type* out = samples.data();

    // Launch one work-item per output element and block until it completes.
    try {
        auto done = q.parallel_for(sycl::range<1>(n), [=](sycl::id<> wi) {
            const auto pos = wi.get(0);
            rng_device::mcg59 engine(seed, pos);
            Distr distr(mean, stddev);
            out[pos] = rng_device::generate(distr, engine);
        });
        done.wait_and_throw();
    }
    catch (sycl::exception const& ex) {
        std::cout << "\t\tSYCL exception\n" << ex.what() << std::endl;
        return 1;
    }

    std::cout << "\t\tOutput of generator:" << std::endl;
    std::cout << "first " << n_print << " numbers of " << n << ": " << std::endl;
    for (int k = 0; k < n_print; ++k) {
        std::cout << out[k] << " ";
    }
    std::cout << std::endl;

    // A result of 1 from the statistics check is treated as a failure.
    const bool stats_failed =
        statistics<Type, Distr>{}.check(samples, Distr(mean, stddev)) == 1;
    return stats_failed ? 1 : 0;
}

// Fills n values from rng_device::gaussian<Type, Method>(mean, stddev) using
// a vectorized mcg59<VecSize> engine. Each work-item produces one
// sycl::vec<Type, VecSize> and stores it at offset work-item * VecSize.
// Prints the first n_print values and runs the statistics check on the result.
// Returns 1 if the kernel raised a SYCL exception or the check reported 1,
// otherwise 0.
template <int VecSize, typename Method, typename Type>
int run_vector_example(sycl::queue& q, Type mean, Type stddev) {
    std::cout << "\tRunning vector example with " << VecSize << " vector size" << std::endl;
    // prepare array for random numbers (shared USM: written on device, read on host)
    sycl::usm_allocator<Type, sycl::usm::alloc::shared> allocator_gpu(q);
    std::vector<Type, decltype(allocator_gpu)> r_vec(n, allocator_gpu);
    Type* r = r_vec.data();
    using Distr = rng_device::gaussian<Type, Method>;

    // Round the chunk count up. With a plain n / VecSize, a trailing partial
    // chunk would never be generated when n is not a multiple of VecSize. Those
    // value-initialized (zero) elements would then skew the statistics check.
    constexpr std::size_t n_chunks = (n + VecSize - 1) / VecSize;

    // submit a kernel to generate numbers on device
    try {
        q.parallel_for(sycl::range<1>(n_chunks), [=](sycl::id<> id) {
            const std::size_t base = id.get(0) * VecSize;
            rng_device::mcg59<VecSize> engine(seed, base);
            Distr distr(mean, stddev);

            sycl::vec<Type, VecSize> res = rng_device::generate(distr, engine);

            // Only the last chunk can extend past n; its surplus values are dropped.
            for (int i = 0; i < VecSize; ++i) {
                const std::size_t pos = base + static_cast<std::size_t>(i);
                if (pos < n) {
                    r[pos] = res[i];
                }
            }
        })
        .wait_and_throw();
    }
    catch (sycl::exception const& e) {
        std::cout << "\t\tSYCL exception\n" << e.what() << std::endl;
        return 1;
    }

    std::cout << "\t\tOutput of generator:" << std::endl;
    std::cout << "first " << n_print << " numbers of " << n << ": " << std::endl;
    for (int i = 0; i < n_print; i++) {
        std::cout << r[i] << " ";
    }
    std::cout << std::endl;

    // check if mean and stddev are the same as we expect
    if (statistics<Type, Distr>{}.check(r_vec, Distr(mean, stddev)) == 1) {
        return 1;
    }
    return 0;
}

//
// Prints a banner describing this example and the RNG APIs it exercises
// (mcg59 engine, gaussian distribution).
//
void print_example_banner() {
    static const char* const banner_lines[] = {
        "",
        "########################################################################",
        "# Generate normally distributed random numbers example: ",
        "# ",
        "# Using APIs:",
        "#   mcg59 gaussian",
        "# ",
        "########################################################################",
    };
    for (const char* text : banner_lines) {
        std::cout << text << std::endl;
    }
    std::cout << std::endl;
}

//
// main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//

int main() {
    print_example_banner();

    // handler to catch asynchronous exceptions
    auto exception_handler = [](sycl::exception_list exceptions) {
        for (std::exception_ptr const& e : exceptions) {
            try {
                std::rethrow_exception(e);
            }
            catch (sycl::exception const& e) {
                std::cout << "Caught asynchronous SYCL exception:\n" << e.what() << std::endl;
            }
        }
    };

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);
        if (my_dev_is_found) {
            std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

            sycl::queue q(my_dev, exception_handler);

            std::cout << "\n\tRunning with single precision real data type:" << std::endl;
            std::cout << "\n\tmcg59 generator" << std::endl;
            if (run_scalar_example<oneapi::mkl::rng::device::gaussian_method::box_muller2>(q, 10.0f, 20.0f) ||
                run_scalar_example<oneapi::mkl::rng::device::gaussian_method::icdf>(q, 11.0f, 21.0f) ||
                run_vector_example</*vec_size*/16, oneapi::mkl::rng::device::gaussian_method::box_muller2>(q, 12.0f, 22.0f) ||
                run_vector_example</*vec_size*/16, oneapi::mkl::rng::device::gaussian_method::icdf>(q, 13.0f, 23.0f)) {
                std::cout << "FAILED" << std::endl;
                return 1;
            }
            if (isDoubleSupported(my_dev)) {
                std::cout << "\n\tRunning with double precision real data type:" << std::endl;
                std::cout << "\n\tmcg59 generator" << std::endl;
                if (run_scalar_example<oneapi::mkl::rng::device::gaussian_method::box_muller2>(q, 10.0, 20.0) ||
                    run_scalar_example<oneapi::mkl::rng::device::gaussian_method::icdf>(q, 11.0, 21.0) ||
                    run_vector_example</*vec_size*/16, oneapi::mkl::rng::device::gaussian_method::box_muller2>(q, 12.0, 22.0) ||
                    run_vector_example</*vec_size*/16, oneapi::mkl::rng::device::gaussian_method::icdf>(q, 13.0, 23.0)) {
                    std::cout << "FAILED" << std::endl;
                    return 1;
                }
            }
            else {
                std::cout << "Double precision is not supported for this device" << std::endl;
            }
        }
        else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it]
                      << " devices found; Fail on missing devices is enabled.\n";
            std::cout << "FAILED" << std::endl;
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                      << sycl_device_names[*it] << " tests.\n";
#endif
        }
    }
    std::cout << "PASSED" << std::endl;
    return 0;
}
