/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.evaluation;

import com.oracle.labs.mlrg.olcut.util.MutableDouble;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;

/**
 * Per-dimension sufficient statistics computed from a list of regression predictions.
 * <p>
 * For every output dimension named in the supplied domain this class accumulates the
 * weighted sum of absolute errors and the weighted sum of squared errors, and records
 * the raw predicted and ground-truth values indexed by prediction number. The
 * per-prediction weights and their total are retained as well.
 * <p>
 * All state is populated by the constructor; this class does not modify it afterwards.
 */
public final class RegressionSufficientStatistics {
    /** Number of predictions that were tabulated. */
    final int n;
    /** The output domain supplied at construction. */
    final ImmutableOutputInfo<Regressor> domain;
    /** Weighted sum of {@code |truth - prediction|}, keyed by dimension name. */
    final Map<String, MutableDouble> sumAbsoluteError = new LinkedHashMap<>();
    /** Weighted sum of {@code (truth - prediction)^2}, keyed by dimension name. */
    final Map<String, MutableDouble> sumSquaredError = new LinkedHashMap<>();
    /** Predicted values keyed by dimension name; each array has length {@link #n}. */
    final Map<String, double[]> predictedValues = new LinkedHashMap<>();
    /** Ground-truth values keyed by dimension name; each array has length {@link #n}. */
    final Map<String, double[]> trueValues = new LinkedHashMap<>();
    /** Weight applied to each prediction; all {@code 1.0f} when example weights are not used. */
    final float[] weights;
    /** Sum of all entries in {@link #weights}. */
    final float weightSum;

    /**
     * Tabulates the supplied predictions against the ground truth stored in their examples.
     *
     * @param domain            the output domain; one set of accumulators is created per element,
     *                          keyed by the first name of that element
     * @param predictions       the predictions to tabulate
     * @param useExampleWeights if {@code true} each prediction is weighted by its example's weight,
     *                          otherwise every prediction has weight {@code 1.0f}
     * @throws IllegalArgumentException if a ground-truth or predicted output equals
     *                                  {@code RegressionFactory.UNKNOWN_REGRESSOR}, if a ground-truth
     *                                  dimension name is absent from the domain, or if a prediction
     *                                  has fewer dimensions than its ground truth
     */
    public RegressionSufficientStatistics(ImmutableOutputInfo<Regressor> domain, List<Prediction<Regressor>> predictions, boolean useExampleWeights) {
        this.domain = domain;
        this.n = predictions.size();
        this.weights = RegressionSufficientStatistics.initWeights(predictions, useExampleWeights);
        // The four maps are always populated together, so they share one key set.
        for (Regressor e : domain.getDomain()) {
            String name = e.getNames()[0];
            this.sumAbsoluteError.put(name, new MutableDouble());
            this.sumSquaredError.put(name, new MutableDouble());
            this.predictedValues.put(name, new double[this.n]);
            this.trueValues.put(name, new double[this.n]);
        }
        this.weightSum = this.tabulate(predictions);
    }

    /**
     * Fills the per-dimension accumulators and value arrays from the predictions.
     *
     * @param predictions the predictions to tabulate
     * @return the sum of the weights of all tabulated predictions
     * @throws IllegalArgumentException see the constructor
     */
    private float tabulate(List<Prediction<Regressor>> predictions) {
        // Accumulate in double and narrow once at the end: a float accumulator stops
        // growing once it reaches 2^24 when unit weights are added to it, and it rounds
        // on every addition for non-unit weights.
        double total = 0.0;
        int i = 0;
        // Iterate rather than call get(i) so non-random-access lists are not quadratic.
        for (Prediction<Regressor> prediction : predictions) {
            float weight = this.weights[i];
            total += weight;
            Regressor pred = prediction.getOutput();
            Regressor truth = prediction.getExample().getOutput();
            if (truth.equals(RegressionFactory.UNKNOWN_REGRESSOR)) {
                throw new IllegalArgumentException("The sentinel Unknown Regressor was used as a ground truth output at prediction number " + i);
            }
            if (pred.equals(RegressionFactory.UNKNOWN_REGRESSOR)) {
                throw new IllegalArgumentException("The sentinel Unknown Regressor was predicted by the model at prediction number " + i);
            }
            int dims = truth.size();
            String[] names = truth.getNames();
            double[] truthValues = truth.getValues();
            double[] predValues = pred.getValues();
            // Without this check the loop below fails with a bare ArrayIndexOutOfBoundsException.
            if (predValues.length < dims) {
                throw new IllegalArgumentException("The prediction has " + predValues.length + " dimensions but the ground truth has " + dims + " at prediction number " + i);
            }
            for (int j = 0; j < dims; ++j) {
                String name = names[j];
                MutableDouble absoluteError = this.sumAbsoluteError.get(name);
                // Checking one map is enough because all four share the same keys;
                // without this check the lookup result is dereferenced as null.
                if (absoluteError == null) {
                    throw new IllegalArgumentException("The ground truth dimension '" + name + "' is not present in the output domain at prediction number " + i);
                }
                double trueValue = truthValues[j];
                double predValue = predValues[j];
                double diff = trueValue - predValue;
                absoluteError.increment((double) weight * Math.abs(diff));
                this.sumSquaredError.get(name).increment((double) weight * diff * diff);
                this.trueValues.get(name)[i] = trueValue;
                this.predictedValues.get(name)[i] = predValue;
            }
            ++i;
        }
        return (float) total;
    }

    /**
     * Builds the per-prediction weight array.
     *
     * @param predictions       the predictions being evaluated
     * @param useExampleWeights if {@code true} copy each example's weight, otherwise use {@code 1.0f}
     * @return an array with one weight per prediction, in list order
     */
    private static float[] initWeights(List<Prediction<Regressor>> predictions, boolean useExampleWeights) {
        float[] weights = new float[predictions.size()];
        if (useExampleWeights) {
            int i = 0;
            for (Prediction<Regressor> prediction : predictions) {
                weights[i++] = prediction.getExample().getWeight();
            }
        } else {
            Arrays.fill(weights, 1.0f);
        }
        return weights;
    }
}

