/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp.generate;

import ai.djl.modality.nlp.generate.BatchTensorList;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/**
 * Batch decoding state for contrastive search.
 *
 * <p>Extends {@link BatchTensorList} with two extra tensors: the past hidden states and the
 * logits. When flattened to an {@link NDList} (by {@link #getList()}), or rebuilt from one (by
 * {@link #fromList(NDList, long[])}), the entries are laid out as:
 *
 * <ol start="0">
 *   <li>past output ids
 *   <li>past attention mask
 *   <li>past hidden states
 *   <li>logits
 *   <li>past key values (this and every following entry)
 * </ol>
 */
class ContrastiveBatchTensorList extends BatchTensorList {
    private NDArray pastHiddenStates;
    private NDArray logits;

    /**
     * Rebuilds the batch state from a flattened list in the layout described on this class.
     *
     * @param list flattened tensors; indices 0-3 are read directly and index 4 onward become the
     *     past key values, so fewer than four entries makes the indexing below fail
     * @param seqDimOrder forwarded unchanged to {@link BatchTensorList}
     */
    ContrastiveBatchTensorList(NDList list, long[] seqDimOrder) {
        super(list.get(0), list.get(1), list.subNDList(4), seqDimOrder);
        pastHiddenStates = list.get(2);
        logits = list.get(3);
    }

    /**
     * Creates the batch state from its individual tensors.
     *
     * @param pastOutputIds forwarded to {@link BatchTensorList}
     * @param pastAttentionMask forwarded to {@link BatchTensorList}
     * @param pastHiddenStates stored by this class and returned by {@link #getPastHiddenStates()}
     * @param logits stored by this class and returned by {@link #getLogits()}
     * @param pastKeyValues forwarded to {@link BatchTensorList}
     * @param seqDimOrder forwarded unchanged to {@link BatchTensorList}
     */
    ContrastiveBatchTensorList(
            NDArray pastOutputIds,
            NDArray pastAttentionMask,
            NDArray pastHiddenStates,
            NDArray logits,
            NDList pastKeyValues,
            long[] seqDimOrder) {
        super(pastOutputIds, pastAttentionMask, pastKeyValues, seqDimOrder);
        this.pastHiddenStates = pastHiddenStates;
        this.logits = logits;
    }

    /**
     * Creates an instance with no tensors assigned. It is mainly useful as a factory receiver for
     * {@link #fromList(NDList, long[])}.
     */
    public ContrastiveBatchTensorList() {}

    /**
     * Builds a new instance from a flattened list; this instance's own state is not consulted.
     *
     * @param inputList flattened tensors in the layout described on this class
     * @param seqDimOrder forwarded unchanged to {@link BatchTensorList}
     * @return a freshly constructed {@code ContrastiveBatchTensorList}
     */
    @Override
    public ContrastiveBatchTensorList fromList(NDList inputList, long[] seqDimOrder) {
        return new ContrastiveBatchTensorList(inputList, seqDimOrder);
    }

    /**
     * Flattens this state into a new {@link NDList} in the layout described on this class.
     *
     * <p>Every component is read through its getter, so overriding getters affect the result.
     *
     * @return a new list with the four head tensors followed by all past key values
     */
    @Override
    public NDList getList() {
        NDList flattened =
                new NDList(
                        getPastOutputIds(),
                        getPastAttentionMask(),
                        getPastHiddenStates(),
                        getLogits());
        return flattened.addAll(getPastKeyValues());
    }

    /**
     * Returns the stored past hidden states.
     *
     * @return the past hidden states (index 2 of the flattened layout)
     */
    public NDArray getPastHiddenStates() {
        return pastHiddenStates;
    }

    /**
     * Replaces the stored past hidden states.
     *
     * @param pastHiddenStates the new value
     */
    public void setPastHiddenStates(NDArray pastHiddenStates) {
        this.pastHiddenStates = pastHiddenStates;
    }

    /**
     * Returns the stored logits.
     *
     * @return the logits (index 3 of the flattened layout)
     */
    public NDArray getLogits() {
        return logits;
    }

    /**
     * Replaces the stored logits.
     *
     * @param logits the new value
     */
    public void setLogits(NDArray logits) {
        this.logits = logits;
    }
}

