/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.common.sgd;

import ai.onnx.proto.OnnxMl;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.common.sgd.AbstractSGDModel;
import org.tribuo.math.FeedForwardParameters;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.Matrix;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

public abstract class AbstractLinearSGDModel<T extends Output<T>>
extends AbstractSGDModel<T> {
    private static final long serialVersionUID = 1L;

    protected AbstractLinearSGDModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, LinearParameters parameters, boolean generatesProbabilities) {
        super(name, provenance, featureIDMap, outputIDInfo, (FeedForwardParameters)parameters, generatesProbabilities, true);
    }

    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        DenseMatrix baseWeights = (DenseMatrix)this.modelParameters.get()[0];
        int maxFeatures = n < 0 ? this.featureIDMap.size() + 1 : n;
        Comparator<Pair> comparator = Comparator.comparingDouble(p -> Math.abs((Double)p.getB()));
        int numClasses = baseWeights.getDimension1Size();
        int numFeatures = baseWeights.getDimension2Size() - 1;
        HashMap<String, List<Pair<String, Double>>> map = new HashMap<String, List<Pair<String, Double>>>();
        for (int i = 0; i < numClasses; ++i) {
            PriorityQueue<Pair> q = new PriorityQueue<Pair>(maxFeatures, comparator);
            for (int j = 0; j < numFeatures; ++j) {
                Pair curr = new Pair((Object)this.featureIDMap.get(j).getName(), (Object)baseWeights.get(i, j));
                if (q.size() < maxFeatures) {
                    q.offer(curr);
                    continue;
                }
                if (comparator.compare(curr, q.peek()) <= 0) continue;
                q.poll();
                q.offer(curr);
            }
            Pair curr = new Pair((Object)"BIAS", (Object)baseWeights.get(i, numFeatures));
            if (q.size() < maxFeatures) {
                q.offer(curr);
            } else if (comparator.compare(curr, q.peek()) > 0) {
                q.poll();
                q.offer(curr);
            }
            ArrayList<Pair> b = new ArrayList<Pair>();
            while (q.size() > 0) {
                b.add(q.poll());
            }
            Collections.reverse(b);
            map.put(this.getDimensionName(i), b);
        }
        return map;
    }

    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        DenseMatrix baseWeights = (DenseMatrix)this.modelParameters.get()[0];
        Prediction prediction = this.predict(example);
        HashMap weightMap = new HashMap();
        int numClasses = baseWeights.getDimension1Size();
        int numFeatures = baseWeights.getDimension2Size() - 1;
        for (int i = 0; i < numClasses; ++i) {
            ArrayList<Pair> classScores = new ArrayList<Pair>();
            for (Feature f : example) {
                int id = this.featureIDMap.getID(f.getName());
                if (id <= -1) continue;
                double score = baseWeights.get(i, id) * f.getValue();
                classScores.add(new Pair((Object)f.getName(), (Object)score));
            }
            classScores.add(new Pair((Object)"BIAS", (Object)baseWeights.get(i, numFeatures)));
            classScores.sort((o1, o2) -> ((Double)o2.getB()).compareTo((Double)o1.getB()));
            weightMap.put(this.getDimensionName(i), classScores);
        }
        return Optional.of(new Excuse(example, prediction, weightMap));
    }

    protected abstract String getDimensionName(int var1);

    public DenseMatrix getWeightsCopy() {
        return ((DenseMatrix)this.modelParameters.get()[0]).copy();
    }

    protected abstract ONNXNode onnxOutput(ONNXNode var1);

    protected abstract String onnxModelName();

    public ONNXNode writeONNXGraph(ONNXRef<?> input) {
        ONNXContext onnx = input.onnxContext();
        Matrix weightMatrix = (Matrix)this.modelParameters.get()[0];
        ONNXInitializer weights = onnx.floatTensor("linear_sgd_weights", Arrays.asList(this.featureIDMap.size(), this.outputIDInfo.size()), fb -> {
            for (int j = 0; j < weightMatrix.getDimension2Size() - 1; ++j) {
                for (int i = 0; i < weightMatrix.getDimension1Size(); ++i) {
                    fb.put((float)weightMatrix.get(i, j));
                }
            }
        });
        ONNXInitializer bias = onnx.floatTensor("linear_sgd_bias", Collections.singletonList(this.outputIDInfo.size()), fb -> {
            for (int i = 0; i < weightMatrix.getDimension1Size(); ++i) {
                fb.put((float)weightMatrix.get(i, weightMatrix.getDimension2Size() - 1));
            }
        });
        return this.onnxOutput(input.apply(ONNXOperators.GEMM, Arrays.asList(weights, bias)));
    }

    public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
        ONNXContext onnx = new ONNXContext();
        onnx.setName(this.onnxModelName());
        ONNXPlaceholder input = onnx.floatInput("input", this.featureIDMap.size());
        ONNXPlaceholder output = onnx.floatOutput("output", this.outputIDInfo.size());
        this.writeONNXGraph((ONNXRef<?>)input).assignTo((ONNXRef)output);
        return ONNXExportable.buildModel((ONNXContext)onnx, (String)domain, (long)modelVersion, (Provenancable)this);
    }
}

