/*******************************************************************************
* Copyright 2018-2023 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
 *
 *  Content:
 *            demonstration of the VM and RNG device APIs:
 *            standard normal N(0, 1) random numbers are generated in two ways:
 *            a direct RNG Gaussian call, and the inverse-CDF (ICDF) method
 *            implemented by combining RNG Uniform and VM CdfNormInv calls
 *
 *******************************************************************************/

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <iomanip>
#include <iostream>
#include <list>
#include <memory>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include <sycl/sycl.hpp>
#include "oneapi/mkl/vm.hpp"
#include "oneapi/mkl/vm/device/vm.hpp"
#include "oneapi/mkl/rng/device.hpp"

#include "common_for_examples.hpp"

namespace {

using std::int64_t;
using std::uint32_t;
using std::uint64_t;

// Prints a short description of the execution environment to std::cout:
// the name and version of the platform that owns `dev`, followed by the
// device name and its driver version, one labelled line each.
void own_preamble (sycl::device & dev)
{
    const sycl::platform owner_platform = dev.get_platform();

    // One chained expression: since C++17 the operands of << are evaluated
    // left to right, so each query still runs right before its line is printed.
    std::cout << "\t        platform: " << owner_platform.get_info<sycl::info::platform::name>() << std::endl
              << "\tplatform_version: " << owner_platform.get_info<sycl::info::platform::version>() << std::endl
              << "\t          device: " << dev.get_info<sycl::info::device::name>() << std::endl
              << "\t  driver_version: " << dev.get_info<sycl::info::device::driver_version>() << std::endl;
}

// Asynchronous error handler intended for a sycl::queue.
// Reports every exception contained in `el` to std::cerr.
//
// The handler may be invoked from inside the SYCL runtime, for example during
// wait_and_throw() or queue destruction, where an escaping exception cannot be
// handled safely. It therefore catches everything it rethrows: sycl::exception
// (reported with its error code), any other std::exception, and unknown types.
void async_sycl_error(sycl::exception_list el) {
    std::cerr << "async exceptions caught: " << std::endl;

    for (const auto & eptr : el) {
        try {
            std::rethrow_exception(eptr);
        } catch (const sycl::exception & e) {
            std::cerr << "SYCL exception occured with code " << e.code().value() << " with " << e.what() << std::endl;
        } catch (const std::exception & e) {
            std::cerr << "non-SYCL exception occurred: " << e.what() << std::endl;
        } catch (...) {
            std::cerr << "unknown exception occurred" << std::endl;
        }
    }
}

// Statistical check of a sample mean.
// The standard error of the mean of n samples with standard deviation `sigma`
// is sigma / sqrt(n). Returns true when |mean - expected| is less than three
// of those standard errors.
bool check_mean(double mean, double expected, double sigma, int64_t n) {
    const double standard_error = sigma / std::sqrt(static_cast<double>(n));
    const double z_score        = std::fabs(mean - expected) / standard_error;
    return z_score < 3.0;
}

// Statistical check of a sample standard deviation.
// Uses sigma / sqrt(2 * (n - 1)) as the standard error of the sample standard
// deviation (the usual large-sample approximation for normally distributed
// data). Returns true when |std_dev - expected| is less than three of those
// standard errors.
bool check_stddev(double std_dev, double expected, double sigma, int64_t n) {
    const double degrees_of_freedom = static_cast<double>(n) - 1.0;
    const double standard_error     = sigma / std::sqrt(2.0 * degrees_of_freedom);
    const double z_score            = std::fabs(std_dev - expected) / standard_error;
    return z_score < 3.0;
}

// Computes the sample mean and sample standard deviation of the n values
// pointed to by `y`, and checks them against the N(0, 1) expectations
// (mean 0, std.dev 1) with check_mean / check_stddev. Prints a one-line report
// to std::cout and returns true only if both checks pass.
//
// `method` is printed verbatim between the type name and the statistics.
// The standard deviation uses a two-pass formula: the mean first, then the sum
// of squared deviations from it. This avoids the cancellation that the one-pass
// sum(y^2) - sum(y)^2 / n form suffers when |mean| is large relative to the
// spread. For n < 2 the statistics are undefined (0/0) and the function
// reports FAIL.
template<typename T>
bool print_results(const char * method, int64_t n, const T * y) {
    // Accumulate in double even for float input.
    const double mean   = std::accumulate(y, y + n, 0.0) / n;
    const double sq_dev = std::accumulate(y, y + n, 0.0,
        [mean](double acc, T v) { const double d = v - mean; return acc + d * d; });
    const double stddev = std::sqrt(sq_dev / (n - 1));

    const char * type_name = std::is_same<T, float>::value ? "float" : "double";

    const bool mean_ok   = check_mean(mean, 0.0, 1.0, n);
    const bool stddev_ok = check_stddev(stddev, 1.0, 1.0, n);

    std::cout << "\t" << std::setw(6) << type_name
              << method
              << "       n = " << std::setw(10) << n
              << " mean    = " << std::setw(16) << mean
              << std::setw(10) << (mean_ok ? "( PASS )" : "( FAIL )")
              << " std.dev = " << std::setw(16) << stddev
              << std::setw(10) << (stddev_ok ? "( PASS )" : "( FAIL )")
              << std::endl;

    return mean_ok && stddev_ok;
}

// Kernel name types for the parallel_for launches, one instantiation per
// scalar type T.
template <class T> class VsNormal;   // direct RNG gaussian kernel
template <class T> class VmNormal;   // RNG uniform + VM cdfnorminv kernel

// Fixed seed given to every work-item's engine, so the generated sequences are
// the same on every run.
constexpr uint64_t seed = 777777;

// ICDF demo: generates n normal values of type T on the device.
// Each work-item i does the following:
//   - constructs a philox4x32x10 engine from the shared `seed` and its index i;
//   - draws one value from the oneMKL RNG device uniform distribution with
//     bounds 0.0 and 1.0;
//   - maps that value through the oneMKL VM device cdfnorminv, the inverse of
//     the standard normal CDF.
// The results are copied back to the host and checked and printed by
// print_results.
//
// The host buffer (std::vector) and the device USM buffer (unique_ptr with a
// sycl::free deleter) are released even if kernel submission or the copy
// throws, e.g. a sycl::exception for a double kernel on a device without fp64.
//
// Returns true if both the mean and the standard-deviation checks pass.
// Throws std::runtime_error if the device allocation fails.
template <typename T>
bool run_vm_device(int64_t n, sycl::queue & queue)
{
    // Host copy of the results, NaN-filled so entries that were never written
    // are detectable.
    std::vector<T> y(static_cast<std::size_t>(n), static_cast<T>(std::nan("")));

    auto usm_free = [&queue](T * p) { sycl::free(p, queue); };
    std::unique_ptr<T, decltype(usm_free)> dev_owner(sycl::malloc_device<T>(n, queue), usm_free);
    if (n > 0 && dev_owner == nullptr) {
        throw std::runtime_error("run_vm_device: sycl::malloc_device failed");
    }
    // The kernel captures the raw device pointer by value, never the owner.
    T * dev_y = dev_owner.get();

    auto vm_kernel = [dev_y](sycl::id<1> id)
    {
        size_t i = id.get(0);

        // oneMKL RNG device API for uniform numbers generation
        oneapi::mkl::rng::device::philox4x32x10<> engine (seed, i);
        oneapi::mkl::rng::device::uniform<T>      distr  (0.0, 1.0);
        T u = oneapi::mkl::rng::device::generate (distr, engine);

        // oneMKL VM CdfNormInv: inverse of the standard normal cumulative
        // distribution function.
        // NOTE(review): if u can be exactly 0, the result is -inf and the mean
        // check fails; with the fixed seed the outcome is deterministic.
        oneapi::mkl::vm::device::cdfnorminv (&u, &dev_y[i], oneapi::mkl::vm::device::mode::la);
    };

    // Uniform + CdfNormInv device kernel execute
    auto gen_done = queue.parallel_for<class VmNormal<T>> (n, vm_kernel);
    // Transfer results from device once the kernel has finished, and wait for
    // the copy to complete.
    queue.memcpy (y.data(), dev_y, n * sizeof(T), gen_done).wait ();

    return print_results(" uniform+cdfnorminv : ", n, y.data());
}

// Direct demo: generates n normal values of type T on the device.
// Each work-item i constructs a philox4x32x10 engine from the shared `seed` and
// its index i, then draws one value from the oneMKL RNG device gaussian
// distribution with parameters (0.0, 1.0). The results are copied back to the
// host and checked and printed by print_results.
//
// The host buffer (std::vector) and the device USM buffer (unique_ptr with a
// sycl::free deleter) are released even if kernel submission or the copy
// throws, e.g. a sycl::exception for a double kernel on a device without fp64.
//
// Returns true if both the mean and the standard-deviation checks pass.
// Throws std::runtime_error if the device allocation fails.
template <typename T>
bool run_vs_device(int64_t n, sycl::queue & queue)
{
    // Host copy of the results, NaN-filled so entries that were never written
    // are detectable.
    std::vector<T> y(static_cast<std::size_t>(n), static_cast<T>(std::nan("")));

    auto usm_free = [&queue](T * p) { sycl::free(p, queue); };
    std::unique_ptr<T, decltype(usm_free)> dev_owner(sycl::malloc_device<T>(n, queue), usm_free);
    if (n > 0 && dev_owner == nullptr) {
        throw std::runtime_error("run_vs_device: sycl::malloc_device failed");
    }
    // The kernel captures the raw device pointer by value, never the owner.
    T * dev_y = dev_owner.get();

    auto rng_kernel = [dev_y](sycl::id<1> id)
    {
        size_t i = id.get(0);

        // oneMKL RNG device API for normal random values generation
        oneapi::mkl::rng::device::philox4x32x10<> engine (seed, i);
        oneapi::mkl::rng::device::gaussian<T>     distr  (0.0, 1.0);
        dev_y[i] = oneapi::mkl::rng::device::generate (distr, engine);
    };

    // RNG device kernel execute
    auto gen_done = queue.parallel_for<class VsNormal<T>> (n, rng_kernel);
    // Transfer results from device once the kernel has finished, and wait for
    // the copy to complete.
    queue.memcpy (y.data(), dev_y, n * sizeof(T), gen_done).wait ();

    return print_results(" gaussian           : ", n, y.data());
}



// Runs the normal-generation demos on `dev`:
//   - ICDF method (uniform + cdfnorminv) and direct gaussian, in float;
//   - the same two demos in double, only when the device reports
//     sycl::aspect::fp64. Double kernels cannot run on other devices, so they
//     are skipped with a message instead of aborting the whole device run.
// Prints the device description and one result line per demo.
// Returns 0 if every executed demo passed, -1 otherwise.
int own_run_on(sycl::device & dev)
{
    // Number of generated values per demo (one work-item per value)
    constexpr int64_t vector_usm_len = 10'000'000;

    own_preamble(dev);

    sycl::queue queue { dev, async_sycl_error };

    std::cout << std::fixed << std::setprecision(10);

    // &= (not &&) so every demo runs even after an earlier failure.
    bool pass = true;
    pass &= run_vm_device<float> (vector_usm_len, queue);
    pass &= run_vs_device<float> (vector_usm_len, queue);

    if (dev.has(sycl::aspect::fp64)) {
        pass &= run_vm_device<double>(vector_usm_len, queue);
        pass &= run_vs_device<double>(vector_usm_len, queue);
    } else {
        std::cout << "\tdouble precision is not supported by this device; skipping double tests" << std::endl;
    }

    std::cout << std::endl << std::endl;

    return pass ? 0 : -1;
} // int own_run_on(sycl::device & dev)
} // namespace

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU device
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU device
// -DSYCL_DEVICES_all (default) -- runs on all: CPU and GPU devices
//
//  For each selected device, the example runs the demos for all
//  supported data types
//
int main (int argc, char **argv)
{
    int ret = 0; // return status
    fprintf (stdout, "sycl vm_device_api_demo: started...\n"); fflush (stdout);

    // List of available devices
    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices (list_of_devices);

    // Loop by all available devices
    for (auto dev_type : list_of_devices)
    {
        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device (my_dev, my_dev_is_found, dev_type);

        // Run tests if the device is available
        if (my_dev_is_found)
        {
            fprintf (stdout, "Running tests on %s.\n", sycl_device_names[dev_type].c_str()); fflush (stdout);
            try {
                ret |= own_run_on (my_dev);
            } catch (sycl::exception const& e) {
                fprintf (stderr, "sycl::exception caught. %s\n", e.what());
                ret = 1;
            } catch (std::exception const& e) {
                fprintf (stderr, "std::exception caught. %s\n", e.what());
                ret = 1;
            }
        }
        else
        {
            fprintf (stderr, "No %s devices found; skipping %s tests.\n",
                sycl_device_names[dev_type].c_str(), sycl_device_names[dev_type].c_str());
        }
    }

    fflush (stdout); fprintf (stdout, "sycl vm_device_api_demo: %s\n\n", (ret != 0)?"FAIL":"PASS");
    return ret;
}

