package it.uniroma2.sag.kelp.learningalgorithm.classification.pegasos;

import com.fasterxml.jackson.annotation.JsonTypeName;
import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.data.representation.Vector;
import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod;
import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryLinearClassifier;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.random.EmpiricalDistribution;

/**
 * Binary linear learning algorithm registered under the JSON type name {@code "pegasos"}.
 *
 * <p>Training ({@link #learn(Dataset)}) performs {@code iterations} mini-batch updates on the
 * hyperplane of a {@link BinaryLinearClassifier}. At iteration {@code t} it:
 * <ol>
 *   <li>draws {@code k} random examples from the dataset;</li>
 *   <li>keeps those whose signed score {@code y * score} is below 1 (margin violators);</li>
 *   <li>scales the hyperplane by {@code 1 - eta * lambda}, with {@code eta = 1 / (lambda * t)};</li>
 *   <li>adds every violator to the model with weight {@code y * eta / k};</li>
 *   <li>rescales the hyperplane when its norm exceeds {@code 1 / sqrt(lambda)}.</li>
 * </ol>
 *
 * <p>Defaults: {@code k = 1}, {@code iterations = 1000}, {@code lambda = 0.01}.
 */
@JsonTypeName("pegasos")
public class PegasosLearningAlgorithm implements LinearMethod, ClassificationLearningAlgorithm, BinaryLearningAlgorithm {
    /** Default number of update iterations performed by {@link #learn(Dataset)}. */
    private static final int DEFAULT_ITERATIONS = 1000;

    /** The label treated as the positive class; every other example counts as negative. */
    private Label label;
    /** Identifier of the vector representation the hyperplane is built on. */
    private String representation;
    /** Number of random examples drawn at each iteration. */
    private int k = 1;
    /** Number of update iterations executed by a single call to {@link #learn(Dataset)}. */
    private int iterations = DEFAULT_ITERATIONS;
    /** Regularization coefficient; must be strictly positive when learning. */
    private float lambda = 0.01f;
    /** The linear classifier whose model is updated during learning. */
    private BinaryLinearClassifier classifier = new BinaryLinearClassifier();

    /**
     * Creates an algorithm with the default hyper-parameters and a fresh {@link BinaryLinearModel}.
     * Label and representation are left unset.
     */
    public PegasosLearningAlgorithm() {
        this.classifier.setModel(new BinaryLinearModel());
    }

    /**
     * Creates a fully configured algorithm backed by a fresh {@link BinaryLinearModel}.
     *
     * @param k              number of random examples drawn at each iteration
     * @param lambda         regularization coefficient
     * @param iterations     number of update iterations per {@link #learn(Dataset)} call
     * @param representation identifier of the vector representation to learn on
     * @param label          the label to be learned (positive class)
     */
    public PegasosLearningAlgorithm(int k, float lambda, int iterations, String representation, Label label) {
        this.classifier.setModel(new BinaryLinearModel());
        setK(k);
        setLabel(label);
        setLambda(lambda);
        setRepresentation(representation);
        setIterations(iterations);
    }

    /**
     * @return the number of random examples drawn at each iteration
     */
    public int getK() {
        return this.k;
    }

    /**
     * @param k the number of random examples to draw at each iteration
     */
    public void setK(int k) {
        this.k = k;
    }

    /**
     * @return the number of update iterations per {@link #learn(Dataset)} call
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * @param iterations the number of update iterations per {@link #learn(Dataset)} call
     */
    public void setIterations(int iterations) {
        this.iterations = iterations;
    }

    /**
     * @return the regularization coefficient
     */
    public float getLambda() {
        return this.lambda;
    }

    /**
     * @param lambda the regularization coefficient; {@link #learn(Dataset)} rejects values
     *               that are not strictly positive
     */
    public void setLambda(float lambda) {
        this.lambda = lambda;
    }

    /**
     * @return the identifier of the representation this algorithm learns on
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LinearMethod
    public String getRepresentation() {
        return this.representation;
    }

    /**
     * Sets the representation identifier on this algorithm and on the classifier's model.
     *
     * @param representation the identifier of the representation to learn on
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LinearMethod
    public void setRepresentation(String representation) {
        this.representation = representation;
        this.classifier.getModel().setRepresentation(representation);
    }

    /**
     * Runs {@code iterations} mini-batch updates on the classifier's hyperplane using random
     * examples from {@code dataset}. If the model has no hyperplane yet, it is initialised with
     * the dataset's zero vector for the configured representation; otherwise learning continues
     * from the existing hyperplane.
     *
     * @param dataset the training data
     * @throws IllegalStateException if {@code lambda} is not strictly positive (including NaN),
     *                               since the step size {@code 1 / (lambda * t)} and the norm
     *                               bound {@code 1 / sqrt(lambda)} are undefined in that case
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public void learn(Dataset dataset) {
        // Written as a negated comparison so that NaN is rejected as well.
        if (!(this.lambda > 0.0f)) {
            throw new IllegalStateException(
                    "Pegasos requires a strictly positive lambda, but lambda is " + this.lambda);
        }
        if (getPredictionFunction().getModel().getHyperplane() == null) {
            getPredictionFunction().getModel().setHyperplane(dataset.getZeroVector(this.representation));
        }
        for (int t = 1; t <= this.iterations; t++) {
            List<Example> batch = dataset.getRandExamples(this.k);
            List<Example> violators = new ArrayList<Example>();
            List<Float> violatorSigns = new ArrayList<Float>();
            // Step size decays with the iteration counter.
            float learningRate = 1.0f / (this.lambda * t);
            Vector hyperplane = getPredictionFunction().getModel().getHyperplane();

            // Collect the examples scored with a margin below 1 by the current hyperplane.
            // All predictions are made before the hyperplane is modified.
            for (Example example : batch) {
                BinaryMarginClassifierOutput prediction = this.classifier.predict(example);
                float sign = example.isExampleOf(this.label) ? 1.0f : -1.0f;
                if (prediction.getScore(this.label).floatValue() * sign < 1.0f) {
                    violators.add(example);
                    violatorSigns.add(Float.valueOf(sign));
                }
            }

            // Regularization shrink: w <- (1 - eta * lambda) * w.
            hyperplane.scale(1.0f - (learningRate * this.lambda));

            // Move towards each violator; the step is averaged over the batch size k,
            // not over the number of violators.
            float step = learningRate / this.k;
            for (int j = 0; j < violators.size(); j++) {
                getPredictionFunction().getModel().addExample(violatorSigns.get(j).floatValue() * step, violators.get(j));
            }

            // Shrink the hyperplane back when its norm exceeds 1 / sqrt(lambda).
            // A zero norm gives an infinite factor, which fails the test and is skipped.
            float projection = (float) ((1.0d / Math.sqrt(this.lambda)) / Math.sqrt(hyperplane.getSquaredNorm()));
            if (projection < 1.0f) {
                hyperplane.scale(projection);
            }
        }
    }

    /**
     * Creates a new algorithm with the same {@code k}, {@code lambda}, {@code iterations} and
     * representation. The label and the learned model are not copied: the copy starts with a
     * fresh model and no label.
     *
     * @return the new, untrained algorithm
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public PegasosLearningAlgorithm duplicate() {
        PegasosLearningAlgorithm copy = new PegasosLearningAlgorithm();
        copy.setK(this.k);
        copy.setLambda(this.lambda);
        copy.setIterations(this.iterations);
        copy.setRepresentation(this.representation);
        return copy;
    }

    /**
     * Resets the underlying classifier; the hyper-parameters are left untouched.
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public void reset() {
        this.classifier.reset();
    }

    /**
     * @return the classifier updated by {@link #learn(Dataset)}
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public BinaryLinearClassifier getPredictionFunction() {
        return this.classifier;
    }

    /**
     * Sets the single label to be learned, on both this algorithm and its classifier.
     *
     * @param list a list containing exactly one label
     * @throws IllegalArgumentException if {@code list} does not contain exactly one element
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public void setLabels(List<Label> list) {
        if (list.size() != 1) {
            throw new IllegalArgumentException("Pegasos algorithm is a binary method which can learn a single Label");
        }
        this.label = list.get(0);
        this.classifier.setLabels(list);
    }

    /**
     * @return a fixed-size, single-element list holding the label to be learned (the element is
     *         {@code null} when no label has been set)
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public List<Label> getLabels() {
        return Arrays.asList(this.label);
    }

    /**
     * @return the label to be learned, or {@code null} when none has been set
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm
    public Label getLabel() {
        return this.label;
    }

    /**
     * Sets the label to be learned; delegates to {@link #setLabels(List)} with a one-element list.
     *
     * @param label the label to be learned
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm
    public void setLabel(Label label) {
        setLabels(Arrays.asList(label));
    }
}
