/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.training.loss.Loss;

/**
 * A {@link Loss} that computes a softmax cross-entropy over the class axis of the prediction and
 * multiplies it by a sequence mask before averaging.
 *
 * <p>The {@code labels} list passed to {@link #evaluate(NDList, NDList)} must hold two arrays:
 * the label itself at index 0 and, at index 1, the array handed to {@code sequenceMask} to build
 * the mask (presumably the valid length of each sequence in the batch; confirm against callers).
 */
public class MaskedSoftmaxCrossEntropyLoss extends Loss {
    /** Scalar multiplied into the loss; skipped when it equals {@code 1}. */
    private final float weight;
    /** Axis of the prediction that enumerates the classes; may be negative. */
    private final int classAxis;
    /** Whether the label holds class indices ({@code true}) or per-class values ({@code false}). */
    private final boolean sparseLabel;
    /** When {@code false}, {@code logSoftmax} is applied to the prediction before use. */
    private final boolean fromLogit;

    /**
     * Creates the loss named {@code "MaskedSoftmaxCrossEntropyLoss"} with the default settings of
     * {@link #MaskedSoftmaxCrossEntropyLoss(String)}.
     */
    public MaskedSoftmaxCrossEntropyLoss() {
        this("MaskedSoftmaxCrossEntropyLoss");
    }

    /**
     * Creates the loss with weight {@code 1}, class axis {@code -1}, sparse labels, and
     * {@code fromLogit} set to {@code false} (so {@code logSoftmax} is applied to predictions).
     *
     * @param name the name passed to the {@link Loss} constructor
     */
    public MaskedSoftmaxCrossEntropyLoss(String name) {
        this(name, 1.0f, -1, true, false);
    }

    /**
     * Creates the loss with explicit settings.
     *
     * @param name the name passed to the {@link Loss} constructor
     * @param weight scalar multiplied into the masked loss; {@code 1} leaves it unchanged
     * @param classAxis axis of the prediction that enumerates the classes; negative values are
     *     allowed
     * @param sparseLabel if {@code true}, the label is used as an index to pick one entry along
     *     the class axis; if {@code false}, the label is reshaped to the prediction's shape and
     *     multiplied element-wise with it
     * @param fromLogit if {@code false}, {@code logSoftmax} is applied to the prediction along
     *     {@code classAxis}; if {@code true}, the prediction is used as given. Note that this is
     *     the condition implemented here, whatever the parameter name may suggest.
     */
    public MaskedSoftmaxCrossEntropyLoss(String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) {
        super(name);
        this.weight = weight;
        this.classAxis = classAxis;
        this.sparseLabel = sparseLabel;
        this.fromLogit = fromLogit;
    }

    /**
     * Computes the masked loss between the labels and the prediction.
     *
     * @param labels list whose element 0 is the label and whose element 1 is the array used to
     *     build the sequence mask; both are required
     * @param predictions list that must contain exactly one array (enforced by
     *     {@code singletonOrThrow})
     * @return the masked, optionally weighted loss averaged over axis 1
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        NDArray label = labels.head();

        // Ones shaped like the label, with a trailing axis added, masked by the second label array.
        NDArray mask = label.onesLike().expandDims(-1).sequenceMask(labels.get(1));

        NDArray pred = predictions.singletonOrThrow();
        if (!this.fromLogit) {
            pred = pred.logSoftmax(this.classAxis);
        }

        NDArray loss;
        if (this.sparseLabel) {
            // floorMod maps a negative class axis (e.g. -1) to its non-negative position so that
            // the right number of "take all" dimensions precedes the pick dimension.
            int leadingDims = Math.floorMod(this.classAxis, pred.getShape().dimension());
            NDIndex pickIndex = new NDIndex().addAllDim(leadingDims).addPickDim(label);
            loss = pred.get(pickIndex).neg();
        } else {
            NDArray dense = label.reshape(pred.getShape());
            loss = pred.mul(dense).neg().sum(new int[]{this.classAxis}, true);
        }

        loss = loss.mul(mask);
        if (this.weight != 1.0f) {
            loss = loss.mul(this.weight);
        }
        return loss.mean(new int[]{1});
    }
}

