package it.uniroma2.sag.kelp.learningalgorithm.classification.libsvm;

import com.fasterxml.jackson.annotation.JsonTypeName;
import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.kernel.Kernel;
import it.uniroma2.sag.kelp.learningalgorithm.KernelMethod;
import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.classification.libsvm.solver.LibNuSvmSolver;
import it.uniroma2.sag.kelp.learningalgorithm.classification.libsvm.solver.SvmSolution;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier;
import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryKernelMachineModel;

/**
 * Binary nu-SVM classifier (nu-SVC) trained with the libsvm-style {@link LibNuSvmSolver}.
 *
 * <p>Training examples for which {@code isExampleOf(label)} holds are treated as positive (+1); all
 * others are negative (-1). The parameter {@code nu} must lie in the open interval (0,1); any other
 * value is replaced by 0.5 and a warning is printed on standard error.
 */
@JsonTypeName("binaryNuSvmClassification")
public class BinaryNuSvmClassification extends LibNuSvmSolver implements ClassificationLearningAlgorithm, KernelMethod {

    /** Default nu, also used as the fallback when an out-of-range nu is supplied. */
    private static final float DEFAULT_NU = 0.5f;

    private float nu = DEFAULT_NU;
    private Label label;
    private BinaryKernelMachineClassifier classifier;

    /** Creates an unconfigured learner with nu = 0.5 and an empty kernel-machine classifier. */
    public BinaryNuSvmClassification() {
        initializeClassifier();
    }

    /**
     * Creates a learner for the given kernel, positive label and nu.
     *
     * @param kernel kernel used by the solver and by the learned model
     * @param label examples of this label are positives, all others negatives
     * @param nu nu parameter; must be in (0,1), otherwise 0.5 is used
     */
    public BinaryNuSvmClassification(Kernel kernel, Label label, float nu) {
        // NOTE(review): the two 1s are presumably the per-class upper bounds (Cp, Cn) that
        // libsvm's nu-SVC passes to its solver; TODO confirm against LibNuSvmSolver.
        super(kernel, 1, 1);
        this.label = label;
        this.nu = checkNu(nu);
        initializeClassifier();
        setLabel(label);
    }

    /** Returns {@code f} if it lies in (0,1), otherwise warns on stderr and returns 0.5. */
    private float checkNu(float f) {
        if (f > 0.0f && f < 1.0f) {
            return f;
        }
        System.err.println("Nu must be in the (0,1) interval. Nu is set to 0.5");
        return DEFAULT_NU;
    }

    /** Returns a new learner with the same kernel, label and nu (and a fresh, empty classifier). */
    @Override
    public BinaryNuSvmClassification duplicate() {
        return new BinaryNuSvmClassification(this.kernel, this.label, this.nu);
    }

    /** Returns the nu parameter. */
    public float getNu() {
        return this.nu;
    }

    /** Returns the classifier populated by {@link #learn(Dataset)}. */
    @Override
    public BinaryKernelMachineClassifier getPredictionFunction() {
        return this.classifier;
    }

    /** Creates a fresh classifier backed by an empty {@link BinaryKernelMachineModel}. */
    private void initializeClassifier() {
        BinaryKernelMachineModel model = new BinaryKernelMachineModel();
        this.classifier = new BinaryKernelMachineClassifier();
        this.classifier.setModel(model);
    }

    /**
     * Trains the nu-SVM on {@code dataset} and adds the resulting support vectors and bias to the
     * classifier's model. The model is not cleared first; call {@link #reset()} between trainings.
     *
     * @param dataset training examples
     */
    @Override
    public void learn(Dataset dataset) {
        int n = dataset.getNumberOfExamples();
        int[] y = new int[n];
        int positives = 0;
        for (int i = 0; i < n; i++) {
            if (dataset.getExamples().get(i).isExampleOf(this.label)) {
                y[i] = 1;
                positives++;
            } else {
                y[i] = -1;
            }
        }
        warnIfNuInfeasible(n, positives, n - positives);
        float[] initialAlphas = buildInitialAlphas(y);
        this.classifier.getModel().setKernel(this.kernel);
        train(dataset, y, initialAlphas);
    }

    /**
     * Warns when nu is infeasible for the class balance. libsvm rejects nu-SVC problems where
     * nu * l / 2 exceeds the size of the smaller class, because each class's alphas (each capped at
     * 1) must sum to nu * l / 2.
     */
    private void warnIfNuInfeasible(int n, int positives, int negatives) {
        if (this.nu * n / 2.0f > Math.min(positives, negatives)) {
            System.err.println("Nu = " + this.nu + " is infeasible for " + positives
                    + " positive and " + negatives + " negative examples; the solution may be unreliable");
        }
    }

    /**
     * Builds libsvm's nu-SVC starting point: each class receives a total alpha mass of nu * l / 2,
     * assigned greedily in chunks of at most 1 in example order.
     */
    private float[] buildInitialAlphas(int[] y) {
        float remainingPos = (this.nu * y.length) / 2.0f;
        float remainingNeg = remainingPos;
        float[] alphas = new float[y.length];
        for (int i = 0; i < y.length; i++) {
            if (y[i] == 1) {
                alphas[i] = Math.min(1.0f, remainingPos);
                remainingPos -= alphas[i];
            } else {
                alphas[i] = Math.min(1.0f, remainingNeg);
                remainingNeg -= alphas[i];
            }
        }
        return alphas;
    }

    /**
     * Runs the solver and stores the rescaled solution into the classifier's model.
     *
     * @param dataset training examples
     * @param y +1/-1 target of each example, aligned with {@code dataset.getExamples()}
     * @param initialAlphas feasible starting point for the solver
     */
    private void train(Dataset dataset, int[] y, float[] initialAlphas) {
        int n = dataset.getNumberOfExamples();
        this.l = n;
        // The linear term of the nu-SVC dual is zero; Java zero-initializes the array.
        float[] linearTerm = new float[n];
        SvmSolution solution = solve(n, dataset, linearTerm, y, initialAlphas);
        // The nu-SVC solution is rescaled by r; calculate_r() is provided by the solver.
        float r = calculate_r();
        float[] alphas = solution.getAlphas();
        for (int i = 0; i < n; i++) {
            if (alphas[i] != 0.0f) {
                this.classifier.getModel().addExample((y[i] * alphas[i]) / r, dataset.getExamples().get(i));
            }
        }
        float rho = solution.getRho() / r;
        // The model's bias is -rho.
        this.classifier.getModel().setBias(-rho);
        info("C = " + (1.0f / r));
        // NOTE(review): this logs r^2, not the objective value (libsvm logs obj / r^2);
        // TODO switch to the solver's objective once its accessor is confirmed.
        info("obj = " + (r * r));
        // Log rho itself; the previous code logged -rho (i.e. the bias) under the "rho" label.
        info("rho = " + rho);
    }

    /** Clears the learned model of the classifier. */
    @Override
    public void reset() {
        this.classifier.reset();
    }

    /** Sets nu; values outside (0,1) are replaced by 0.5 with a warning. */
    public void setNu(float f) {
        this.nu = checkNu(f);
    }

    /** Sets the kernel; it is propagated to the model on the next {@link #learn(Dataset)} call. */
    @Override
    public void setKernel(Kernel kernel) {
        this.kernel = kernel;
    }
}
