/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sgd.kernel;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.HashMap;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.Util;
import org.tribuo.classification.sgd.kernel.KernelSVMModel;
import org.tribuo.math.kernel.Kernel;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

/**
 * Trainer for a multiclass kernel SVM.
 * <p>
 * Training makes {@code epochs} passes over the data, optionally shuffling before each pass.
 * For each example it scores every label against the current support vectors. It then subtracts
 * a unit margin from the true label's score. If the highest remaining score belongs to a
 * different label, the example is added as a support vector and its weight is added to the true
 * label's coefficient for that example.
 * <p>
 * Each call to {@code train} splits a fresh {@link SplittableRandom} off this trainer's RNG while
 * holding the instance lock. The shuffle order therefore depends on both the seed and the number
 * of previous {@code train} invocations.
 */
public class KernelSVMTrainer implements Trainer<Label>, WeightedExamples {
    private static final Logger logger = Logger.getLogger(KernelSVMTrainer.class.getName());

    /** Kernel used to compare an example against each support vector. */
    @Config(mandatory=true, description="SVM kernel.")
    private Kernel kernel;

    /**
     * Configured step size.
     * NOTE(review): the training loop in this class never reads this value; here it only
     * appears in {@link #toString()}. Confirm whether it was meant to scale the kernel scores.
     */
    @Config(mandatory=true, description="Step size.")
    private double lambda;

    /** Number of passes over the training data. */
    @Config(description="Number of SGD epochs.")
    private int epochs = 5;

    /** Log the running loss every this many updates; {@code -1} disables the periodic log. */
    @Config(description="Log values after this many updates.")
    private int loggingInterval = -1;

    /** Seed for {@link #rng}. */
    @Config(mandatory=true, description="Seed for the RNG used to shuffle elements.")
    private long seed;

    /** Whether to shuffle the examples before each epoch. */
    @Config(description="Shuffle the data before each epoch. Only turn off for debugging.")
    private boolean shuffle = true;

    /** Trainer-level RNG; each train call splits a child RNG from it. Guarded by {@code this}. */
    private SplittableRandom rng;

    /** Number of train invocations so far. Written while holding {@code this}. */
    private int trainInvocationCounter;

    /**
     * Constructs a kernel SVM trainer.
     *
     * @param kernel          the kernel used to compare examples with support vectors
     * @param lambda          the step size (see the note on the field)
     * @param epochs          the number of passes over the data
     * @param loggingInterval log the loss every this many updates, or {@code -1} to disable
     * @param seed            the RNG seed used for shuffling
     */
    public KernelSVMTrainer(Kernel kernel, double lambda, int epochs, int loggingInterval, long seed) {
        this.kernel = kernel;
        this.lambda = lambda;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.seed = seed;
        postConfig();
    }

    /**
     * Constructs a kernel SVM trainer that logs the loss every 1000 updates.
     *
     * @param kernel the kernel used to compare examples with support vectors
     * @param lambda the step size (see the note on the field)
     * @param epochs the number of passes over the data
     * @param seed   the RNG seed used for shuffling
     */
    public KernelSVMTrainer(Kernel kernel, double lambda, int epochs, long seed) {
        this(kernel, lambda, epochs, 1000, seed);
    }

    /**
     * No-arg constructor. Presumably used by the configuration system, which injects the
     * {@code @Config} fields and then calls {@link #postConfig()} — verify against OLCUT usage.
     */
    private KernelSVMTrainer() {
    }

    /** (Re)initialises the trainer RNG from {@link #seed}. */
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    /**
     * Enables or disables shuffling before each epoch.
     *
     * @param shuffle {@code true} to shuffle (the default)
     */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    /**
     * Trains a model using the trainer's current RNG state.
     *
     * @param examples      the training data
     * @param runProvenance extra provenance recorded in the model
     * @return the trained model
     */
    public KernelSVMModel train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
        return train(examples, runProvenance, -1);
    }

    /**
     * Trains a kernel SVM model.
     *
     * @param examples        the training data; must not contain unknown outputs
     * @param runProvenance   extra provenance recorded in the model
     * @param invocationCount if not {@code -1}, the RNG is first reset as if {@code train} had
     *                        been called this many times
     * @return the trained model
     * @throws IllegalArgumentException if the dataset contains unknown outputs, or if
     *                                  {@code invocationCount} is negative and not {@code -1}
     */
    public KernelSVMModel train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            if (invocationCount != -1) {
                setInvocationCount(invocationCount);
            }
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }

        ImmutableOutputInfo<Label> labelIDMap = examples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        SparseVector[] sgdFeatures = new SparseVector[examples.size()];
        int[] sgdLabels = new int[examples.size()];
        double[] weights = new double[examples.size()];
        // indices[j] keeps each example's original position, so it stays identifiable
        // after shuffleInPlace permutes the parallel arrays.
        int[] indices = new int[examples.size()];
        int n = 0;
        for (Example<Label> example : examples) {
            weights[n] = example.getWeight();
            sgdFeatures[n] = SparseVector.createSparseVector(example, featureIDMap, true);
            sgdLabels[n] = labelIDMap.getID(example.getOutput());
            indices[n] = n;
            n++;
        }
        logger.info(String.format("Training Kernel SVM with %d examples", n));
        logger.info(labelIDMap.toReadableString());

        double loss = 0.0;
        int iteration = 0;
        // Keyed by original example index; alphas[label][originalIndex] is that example's coefficient.
        Map<Integer, SparseVector> supportVectors = new HashMap<>();
        double[][] alphas = new double[labelIDMap.size()][examples.size()];
        for (int epoch = 0; epoch < epochs; epoch++) {
            if (shuffle) {
                Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, indices, localRNG);
            }
            for (int j = 0; j < sgdFeatures.length; j++) {
                int trueLabel = sgdLabels[j];
                SGDVector pred = predict(sgdFeatures[j], supportVectors, alphas);
                // Apply a unit margin: the true label must win by 1 to avoid an update.
                pred.add(trueLabel, -1.0);
                int predIndex = pred.indexOfMax();
                if (trueLabel != predIndex) {
                    // predIndex holds the maximum score, so this difference is >= 0.
                    // The original code subtracted the other way round and logged a negative loss.
                    loss += (pred.get(predIndex) - pred.get(trueLabel)) * weights[j];
                    supportVectors.putIfAbsent(indices[j], sgdFeatures[j]);
                    alphas[trueLabel][indices[j]] += weights[j];
                }
                iteration++;
                if (loggingInterval != -1 && iteration % loggingInterval == 0) {
                    logger.info("At iteration " + iteration + ", average loss = " + loss / (double) loggingInterval + " with " + supportVectors.size() + " support vectors.");
                    loss = 0.0;
                }
            }
            logger.fine("Finished epoch " + epoch);
        }

        // Compact the support vectors and their coefficients in ascending original-index order,
        // so column c of alphaMatrix corresponds to supportArray[c].
        int numSupportVectors = supportVectors.size();
        DenseMatrix alphaMatrix = new DenseMatrix(alphas.length, numSupportVectors);
        SparseVector[] supportArray = new SparseVector[numSupportVectors];
        int column = 0;
        for (int i = 0; i < sgdFeatures.length; i++) {
            SparseVector supportVector = supportVectors.get(i);
            if (supportVector == null) {
                continue;
            }
            supportArray[column] = supportVector;
            for (int label = 0; label < alphas.length; label++) {
                alphaMatrix.set(label, column, alphas[label][i]);
            }
            column++;
        }

        ModelProvenance provenance = new ModelProvenance(KernelSVMModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        return new KernelSVMModel("kernel-model", provenance, featureIDMap, labelIDMap, kernel, supportArray, alphaMatrix);
    }

    /**
     * Returns the number of times {@code train} has been invoked.
     *
     * @return the invocation count
     */
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    /**
     * Resets the RNG to the state it would have after {@code invocationCount} train calls.
     * It does this by reseeding and then splitting once per invocation, as {@code train} does.
     *
     * @param invocationCount the number of prior invocations to simulate
     * @throws IllegalArgumentException if {@code invocationCount} is negative
     */
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        rng = new SplittableRandom(seed);
        for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++) {
            rng.split();
        }
    }

    @Override
    public String toString() {
        return "KernelSVMTrainer(kernel=" + kernel.toString() + ",lambda=" + lambda + ",epochs=" + epochs + ",seed=" + seed + ")";
    }

    /**
     * Scores an example against every label.
     * Each label's score is the sum over support vectors of that label's coefficient times
     * {@code kernel.similarity(features, supportVector)}.
     *
     * @param features the example to score
     * @param sv       support vectors keyed by original example index
     * @param alphas   per-label coefficients, indexed {@code [label][originalIndex]}
     * @return a dense vector of per-label scores
     */
    private SGDVector predict(SparseVector features, Map<Integer, SparseVector> sv, double[][] alphas) {
        double[] score = new double[alphas.length];
        for (Map.Entry<Integer, SparseVector> e : sv.entrySet()) {
            double similarity = kernel.similarity(features, e.getValue());
            int svIndex = e.getKey();
            for (int label = 0; label < alphas.length; label++) {
                score[label] += alphas[label][svIndex] * similarity;
            }
        }
        return DenseVector.createDenseVector(score);
    }

    /**
     * Returns the provenance of this trainer.
     *
     * @return a {@link TrainerProvenanceImpl} describing this trainer
     */
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}

