/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.example;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.tribuo.Example;
import org.tribuo.Output;
import org.tribuo.classification.Label;
import org.tribuo.classification.example.DemoLabelDataSource;
import org.tribuo.impl.ArrayExample;

/**
 * A demo data source which samples two dimensional points from two Gaussians,
 * one per class label.
 * <p>
 * The first {@code numSamples / 2} examples are drawn from the first Gaussian and
 * labelled {@code FIRST_CLASS}; the remaining examples are drawn from the second
 * Gaussian and labelled {@code SECOND_CLASS}.
 * <p>
 * Each covariance matrix is supplied as a 4 element array
 * {@code [c0, c1, c2, c3]} where {@code c1} and {@code c2} must be equal, every
 * element must be non-negative, and the determinant {@code c0 * c3 - c1 * c1}
 * must be non-negative.
 */
public final class GaussianLabelDataSource
extends DemoLabelDataSource {
    @Config(mandatory=true, description="2d mean of the first Gaussian.")
    private double[] firstMean;
    @Config(mandatory=true, description="4 element covariance matrix of the first Gaussian.")
    private double[] firstCovarianceMatrix;
    @Config(mandatory=true, description="2d mean of the second Gaussian.")
    private double[] secondMean;
    @Config(mandatory=true, description="4 element covariance matrix of the second Gaussian.")
    private double[] secondCovarianceMatrix;
    /** Lower triangular Cholesky factor of the first covariance, stored as {l00, l10, l11}. */
    private double[] firstCholesky;
    /** Lower triangular Cholesky factor of the second covariance, stored as {l00, l10, l11}. */
    private double[] secondCholesky;

    /**
     * No-arg constructor; leaves the configurable fields unset.
     * NOTE(review): presumably used by the OLCUT configuration system, which
     * populates the {@code @Config} fields and then calls {@link #postConfig()} - confirm.
     */
    private GaussianLabelDataSource() {
    }

    /**
     * Builds the data source from explicit parameters and validates them via
     * {@link #postConfig()}.
     *
     * @param numSamples The number of samples, forwarded to the superclass.
     * @param seed The RNG seed, forwarded to the superclass.
     * @param firstMean The 2 element mean of the first Gaussian.
     * @param firstCovarianceMatrix The 4 element covariance matrix of the first Gaussian.
     * @param secondMean The 2 element mean of the second Gaussian.
     * @param secondCovarianceMatrix The 4 element covariance matrix of the second Gaussian.
     * @throws PropertyException If any mean or covariance matrix is invalid.
     */
    public GaussianLabelDataSource(int numSamples, long seed, double[] firstMean, double[] firstCovarianceMatrix, double[] secondMean, double[] secondCovarianceMatrix) {
        super(numSamples, seed);
        this.firstMean = firstMean;
        this.firstCovarianceMatrix = firstCovarianceMatrix;
        this.secondMean = secondMean;
        this.secondCovarianceMatrix = secondCovarianceMatrix;
        this.postConfig();
    }

    /**
     * Validates the means and covariance matrices, computes the Cholesky factors
     * used for sampling, then delegates to the superclass {@code postConfig}.
     *
     * @throws PropertyException If a mean is not of length 2, a covariance matrix is
     * not of length 4, has a negative entry, is not symmetric, or has a negative
     * (or NaN) determinant.
     */
    @Override
    public void postConfig() {
        if (firstMean.length != 2) {
            throw new PropertyException("", "firstMean", "firstMean is not the right length");
        }
        if (secondMean.length != 2) {
            throw new PropertyException("", "secondMean", "secondMean is not the right length");
        }
        if (firstCovarianceMatrix.length != 4) {
            throw new PropertyException("", "firstCovarianceMatrix", "firstCovarianceMatrix is not the right length");
        }
        if (secondCovarianceMatrix.length != 4) {
            throw new PropertyException("", "secondCovarianceMatrix", "secondCovarianceMatrix is not the right length");
        }
        // Both matrices have length 4 at this point, so a single index walks both.
        for (int i = 0; i < firstCovarianceMatrix.length; i++) {
            if (firstCovarianceMatrix[i] < 0.0) {
                throw new PropertyException("", "firstCovarianceMatrix", "First covariance matrix is not positive semi-definite");
            }
            if (secondCovarianceMatrix[i] < 0.0) {
                throw new PropertyException("", "secondCovarianceMatrix", "Second covariance matrix is not positive semi-definite");
            }
        }
        if (firstCovarianceMatrix[1] != firstCovarianceMatrix[2]) {
            throw new PropertyException("", "firstCovarianceMatrix", "First covariance matrix is not a covariance matrix");
        }
        if (secondCovarianceMatrix[1] != secondCovarianceMatrix[2]) {
            throw new PropertyException("", "secondCovarianceMatrix", "Second covariance matrix is not a covariance matrix");
        }
        firstCholesky = choleskyFactor(firstCovarianceMatrix, "firstCovarianceMatrix", "First");
        secondCholesky = choleskyFactor(secondCovarianceMatrix, "secondCovarianceMatrix", "Second");
        // Must run after the Cholesky factors exist.
        // NOTE(review): the superclass postConfig presumably triggers generate() - confirm.
        super.postConfig();
    }

    /**
     * Computes the lower triangular Cholesky factor of a symmetric 2x2 matrix
     * given as {@code [c0, c1, c1, c3]}.
     *
     * @param covariance The 4 element covariance matrix, already checked for length and symmetry.
     * @param fieldName The configuration field name reported in any exception.
     * @param displayName The prefix ("First" or "Second") used in the exception message.
     * @return The factor as {@code {l00, l10, l11}}.
     * @throws PropertyException If the determinant is negative or NaN, as the
     * factor would otherwise silently contain NaN.
     */
    private static double[] choleskyFactor(double[] covariance, String fieldName, String displayName) {
        double c0 = covariance[0];
        double c1 = covariance[1];
        double c3 = covariance[3];
        double determinant = c3 * c0 - c1 * c1;
        // Negated comparison so that a NaN determinant is rejected as well.
        if (!(determinant >= 0.0)) {
            throw new PropertyException("", fieldName, displayName + " covariance matrix is not positive semi-definite");
        }
        double[] factor = new double[3];
        if (c0 == 0.0) {
            // Degenerate first dimension: a non-negative determinant forces c1 * c1 == 0,
            // so the factor is diagonal. Dividing by sqrt(c0) here would produce NaN.
            factor[2] = Math.sqrt(c3);
            return factor;
        }
        double rootC0 = Math.sqrt(c0);
        factor[0] = rootC0;
        factor[1] = c1 / rootC0;
        factor[2] = Math.sqrt(determinant) / rootC0;
        return factor;
    }

    /**
     * Generates the examples: the first {@code numSamples / 2} from the first
     * Gaussian, the rest from the second Gaussian.
     *
     * @return The generated examples, first class followed by second class.
     */
    @Override
    protected List<Example<Label>> generate() {
        List<Example<Label>> examples = new ArrayList<>();
        int firstCount = numSamples / 2;
        for (int i = 0; i < firstCount; i++) {
            double[] sample = sampleGaussian(rng, firstMean, firstCholesky);
            examples.add(new ArrayExample<>(FIRST_CLASS, FEATURE_NAMES, sample));
        }
        // The second class receives the extra sample when numSamples is odd.
        for (int i = firstCount; i < numSamples; i++) {
            double[] sample = sampleGaussian(rng, secondMean, secondCholesky);
            examples.add(new ArrayExample<>(SECOND_CLASS, FEATURE_NAMES, sample));
        }
        return examples;
    }

    /**
     * Draws a single 2d point by transforming two standard normal draws with the
     * supplied Cholesky factor and shifting by the means.
     *
     * @param rng The random number generator; two Gaussian values are consumed per call.
     * @param means The 2 element mean vector.
     * @param cholesky The Cholesky factor as {@code {l00, l10, l11}}.
     * @return A freshly allocated 2 element sample.
     */
    private static double[] sampleGaussian(Random rng, double[] means, double[] cholesky) {
        double[] sample = new double[2];
        double first = rng.nextGaussian();
        sample[0] = means[0] + first * cholesky[0];
        double second = rng.nextGaussian();
        sample[1] = means[1] + first * cholesky[1] + second * cholesky[2];
        return sample;
    }

    /**
     * Describes this data source by its sample count, seed, means and covariance matrices.
     *
     * @return A human readable description.
     */
    @Override
    public String toString() {
        return "GaussianGenerator(numSamples=" + numSamples
                + ",seed=" + seed
                + ",firstMean=" + Arrays.toString(firstMean)
                + ",firstCovarianceMatrix=" + Arrays.toString(firstCovarianceMatrix)
                + ",secondMean=" + Arrays.toString(secondMean)
                + ",secondCovarianceMatrix=" + Arrays.toString(secondCovarianceMatrix) + ')';
    }
}

