/*******************************************************************************
* Copyright 2020-2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This file contains helper utilities and checkers that compare sample
*       statistics of various RNG distributions with their theoretical values
*
*******************************************************************************/

#ifndef __COMMON_FOR_RNG_DEVICE_EXAMPLES_HPP__
#define __COMMON_FOR_RNG_DEVICE_EXAMPLES_HPP__

// stl includes
#include <iostream>
#include <vector>
#include <math.h>

#include <sycl/sycl.hpp>
#include "oneapi/mkl/rng/device.hpp"

// local includes
#include "common_for_examples.hpp"

// function to print rng output from buffer
//
// Prints up to n_print leading elements of the 1-D buffer r to std::cout.
// n_print is clamped to the range [0, r.size()] so that requesting more
// elements than the buffer holds never indexes past the end of the accessor;
// the header line reports the number of elements actually printed.
template<typename Type>
void print_output(sycl::buffer<Type, 1>& r, int n_print) {
    auto r_accessor = sycl::host_accessor(r, sycl::read_only);
    const std::size_t total = r.size();
    std::size_t count = (n_print > 0) ? static_cast<std::size_t>(n_print) : 0;
    if (count > total) {
        count = total;
    }
    std::cout << "First "<< count << " numbers of " << total << ": " << std::endl;
    for (std::size_t i = 0; i < count; i++) {
        std::cout << r_accessor[i] << " ";
    }
    std::cout << std::endl;
}

// function to compare theoretical moments and sample moments
//
// Compares the sample mean and the (biased, divide-by-n) sample variance of r
// with the theoretical mean tM and variance tD. tQ is the theoretical fourth
// central moment; it is used to compute the variance of the sample-variance
// estimator. Each sample moment is accepted when it lies within 3 standard
// errors of its theoretical value.
//
// Returns 0 when the moments agree and 1 otherwise; a message is printed in
// both cases. An empty sample, or one whose moments are not finite (e.g. it
// contains NaN/inf values), is reported as an error instead of passing.
template<typename Type, typename Allocator>
int compare_moments(const std::vector<Type, Allocator>& r, double tM, double tD, double tQ) {
    if (r.empty()) {
        std::cout << "Error: cannot compare moments of an empty sample" << std::endl;
        return 1;
    }

    const double n = static_cast<double>(r.size());

    // sample mean
    double sum = 0.0;
    for (const auto& x : r) {
        sum += static_cast<double>(x);
    }
    const double sM = sum / n;

    // sample variance; the two-pass form avoids the cancellation that
    // sum(x^2)/n - mean^2 suffers when the mean is large relative to the spread
    double sum_sq_dev = 0.0;
    for (const auto& x : r) {
        const double dev = static_cast<double>(x) - sM;
        sum_sq_dev += dev * dev;
    }
    const double sD = sum_sq_dev / n;

    // NaN compares false against any threshold, so it must be rejected explicitly
    if (!isfinite(sM) || !isfinite(sD)) {
        std::cout << "Error: sample moments (mean=" << sM << ", variance=" << sD
            << ") are not finite" << std::endl;
        return 1;
    }

    // variance of the biased sample-variance estimator
    const double tD2 = tD * tD;
    const double s = ((tQ - tD2) / n) - (2.0 * (tQ - 2.0 * tD2) / (n * n)) +
                     ((tQ - 3.0 * tD2) / (n * n * n));

    // standardized deviations of the sample moments from theory
    const double DeltaM = (tM - sM) / sqrt(tD / n);
    const double DeltaD = (tD - sD) / sqrt(s);
    if (fabs(DeltaM) > 3.0 || fabs(DeltaD) > 3.0) {
        std::cout << "Error: sample moments (mean=" << sM << ", variance=" << sD
            << ") disagree with theory (mean=" << tM << ", variance=" << tD <<
            ")" << std:: endl;
        return 1;
    }
    std::cout << "Success: sample moments (mean=" << sM << ", variance=" << sD
        << ") agree with theory (mean=" << tM << ", variance=" << tD <<
        ")" << std:: endl;
    return 0;
}

// structure is used to calculate theoretical moments of particular distribution
// and compare them with sample moments
template <typename Type, typename Distribution>
struct statistics {};

template<typename Type>
struct statistics<Type, oneapi::mkl::rng::device::uniform<Type>> {
    int check(const std::vector<Type>& r, const oneapi::mkl::rng::device::uniform<Type>& distr) {
        double tM, tD, tQ;
        double a = distr.a();
        double b = distr.b();
        if constexpr (std::is_integral<Type>::value) {
            // theoretical moments of uniform int (uint) distribution
            tM = (a + b - 1.0) / 2.0;
            tD = ((b - a)*(b - a) - 1.0) / 12.0;
            tQ = (((b - a) * (b - a)) * ((1.0 / 80.0)*(b - a)*(b - a) - (1.0 / 24.0))) + (7.0 / 240.0);
        } else {
            // theoretical moments of uniform real type distribution
            tM = (b + a) / 2.0;
            tD = ((b - a) * (b - a)) / 12.0;
            tQ = ((b - a)*(b - a)*(b - a)*(b - a)) / 80.0;
        }

        return compare_moments(r, tM, tD, tQ);
    }
};

// Theoretical moments of the gaussian distribution N(mean, stddev^2).
template<typename Type, typename Method>
struct statistics<Type, oneapi::mkl::rng::device::gaussian<Type, Method>> {
    // Compares the sample r with the theoretical moments of distr and returns
    // the result of compare_moments (0 = agree, 1 = disagree).
    // The defaulted Allocator also lets a plain std::vector<Type> or a braced
    // list be passed without explicit template arguments.
    template<typename Allocator = typename std::vector<Type>::allocator_type>
    int check(const std::vector<Type, Allocator>& r, const oneapi::mkl::rng::device::gaussian<Type, Method>& distr) {
        const double a = distr.mean();
        const double sigma = distr.stddev();

        const double tM = a;
        const double tD = sigma * sigma;
        // the fourth central moment of a normal distribution is 3 * sigma^4;
        // an inflated value here would silently loosen the variance check
        const double tQ = 3.0 * tD * tD;

        return compare_moments(r, tM, tD, tQ);
    }
};

template<typename Type>
struct statistics<Type, oneapi::mkl::rng::device::lognormal<Type>> {
    int check(const std::vector<Type>& r, const oneapi::mkl::rng::device::lognormal<Type>& distr) {
        double tM, tD, tQ;
        double a = distr.m();
        double b = distr.displ();
        double sigma = distr.s();
        double beta = distr.scale();

        // theoretical moments of lognormal distribution
        tM = b + beta * std::exp(a + sigma * sigma * 0.5);
        tD = beta * beta * std::exp(2.0 * a + sigma * sigma) * (std::exp(sigma * sigma) - 1.0);
        tQ = beta * beta * beta * beta * std::exp(4.0 * a + 2.0 * sigma * sigma) *
            (std::exp(6.0 * sigma * sigma) - 4.0 * std::exp(3.0 * sigma * sigma) + 6.0 * std::exp(sigma * sigma) - 3.0);

        return compare_moments(r, tM, tD, tQ);
    }
};

template <typename Type>
struct statistics<Type, oneapi::mkl::rng::device::exponential<Type>> {
    int check(const std::vector<Type>& r, const oneapi::mkl::rng::device::exponential<Type>& distr) {
        double tM, tD, tQ;
        double a = distr.a();
        double beta = distr.beta();

        tM = a + beta;
        tD = beta * beta;
        tQ = 9.0 * beta * beta * beta * beta;

        return compare_moments(r, tM, tD, tQ);
    }
};

template <typename Type>
struct statistics<Type, oneapi::mkl::rng::device::beta<Type>> {
    int check(const std::vector<Type>& r, const oneapi::mkl::rng::device::beta<Type>& distr) {
        double tM, tD, tQ;
        double b, c, d, e, e2, b2, sum_pq;
        double p = distr.p();
        double q = distr.q();
        double a = distr.a();
        double beta = distr.b();

        b2 = beta * beta;
        sum_pq = p + q;
        b = (p + 1.0) / (sum_pq + 1.0);
        c = (p + 2.0) / (sum_pq + 2.0);
        d = (p + 3.0) / (sum_pq + 3.0);
        e = p / sum_pq;
        e2 = e * e;

        tM = a + e * beta;
        tD = b2 * p * q / (sum_pq * sum_pq * (sum_pq + 1.0));
        tQ = b2 * b2 * (e * b * c * d - 4.0 * e2 * b * c + 6.0 * e2 * e * b - 3.0 * e2 * e2);

        return compare_moments(r, tM, tD, tQ);
    }
};

template <typename Type>
struct statistics<Type, oneapi::mkl::rng::device::gamma<Type>> {
    int check(const std::vector<Type>& r, const oneapi::mkl::rng::device::gamma<Type>& distr) {
        double tM, tD, tQ;
        double a = distr.a();
        double alpha = distr.alpha();
        double beta = distr.beta();

        tM = a + beta * alpha;
        tD = beta * beta * alpha;
        tQ = beta * beta * beta * beta * 3.0 * alpha * (alpha + 2.0);

        return compare_moments(r, tM, tD, tQ);
    }
};

template <typename Type>
struct statistics<Type, oneapi::mkl::rng::device::poisson<Type>> {
    int check(const std::vector<Type>& r,
               const oneapi::mkl::rng::device::poisson<Type>& distr) {
        double tM, tD, tQ;
        double lambda = distr.lambda();

        tM = lambda;
        tD = lambda;
        tQ = 4.0 * lambda * lambda + lambda;

        return compare_moments(r, tM, tD, tQ);
    }
};

template <typename Type>
struct statistics<Type, oneapi::mkl::rng::device::bernoulli<Type>> {
    int check(const std::vector<Type>& r,
               const oneapi::mkl::rng::device::bernoulli<Type>& distr) {
        double tM, tD, tQ;
        double p = distr.p();

        tM = p;
        tD = p * (1.0 - p);
        tQ = p * (1.0 - 4.0 * p + 6.0 * p * p - 3.0 * p * p * p);

        return compare_moments(r, tM, tD, tQ);
    }
};

#endif // __COMMON_FOR_RNG_DEVICE_EXAMPLES_HPP__
