package it.uniroma2.sag.kelp.learningalgorithm.classification.perceptron;

import com.fasterxml.jackson.annotation.JsonIgnore;
import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.OnlineLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryClassifier;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:it/uniroma2/sag/kelp/learningalgorithm/classification/perceptron/Perceptron.class */
/**
 * Base class for binary, Perceptron-style online learning algorithms.
 *
 * <p>The shared update rule implemented by {@link #learn(Example)} is: whenever either of
 * these holds,
 * <ul>
 *   <li>the absolute classifier score for {@link #label} is below {@link #margin}, or</li>
 *   <li>the predicted class disagrees with the example's gold membership in {@link #label},</li>
 * </ul>
 * the example is added to the model with weight {@code +alpha} (if it is an example of the
 * label) or {@code -alpha} (otherwise). Unless {@link #unbiased} is set, the same amount is
 * also added to the model bias.
 *
 * <p>This class never assigns {@link #classifier}; it must be set (presumably by a concrete
 * subclass) before any learning, prediction or label-setting method is called.
 */
public abstract class Perceptron implements ClassificationLearningAlgorithm, OnlineLearningAlgorithm, BinaryLearningAlgorithm {

    /** The classifier whose model is updated during learning; excluded from JSON serialization. */
    @JsonIgnore
    protected BinaryClassifier classifier;

    /** The single label learned by this binary algorithm. */
    protected Label label;

    /** Learning rate used as the update weight; constrained to (0, 1] by {@link #setAlpha(float)}. */
    protected float alpha = 1.0f;

    /** Scores whose absolute value is below this threshold trigger an update even if correct. */
    protected float margin = 1.0f;

    /** When {@code true}, the model bias is never modified by updates. */
    protected boolean unbiased = false;

    /**
     * Returns the learning rate.
     *
     * @return the current learning rate
     */
    public float getAlpha() {
        return this.alpha;
    }

    /**
     * Sets the learning rate.
     *
     * @param alpha the new learning rate; must lie in (0, 1]
     * @throws IllegalArgumentException if {@code alpha} is outside (0, 1] or is NaN
     */
    public void setAlpha(float alpha) {
        // Expressed as a negated in-range test so that NaN, for which every comparison
        // is false, is rejected as well instead of silently slipping through.
        if (!(alpha > 0.0f && alpha <= 1.0f)) {
            throw new IllegalArgumentException("Invalid learning rate for the perceptron algorithm: valid alphas in (0,1]");
        }
        this.alpha = alpha;
    }

    /**
     * Returns the margin threshold.
     *
     * @return the current margin
     */
    public float getMargin() {
        return this.margin;
    }

    /**
     * Sets the margin threshold; examples whose absolute score is below it trigger an update.
     *
     * @param margin the new margin (not validated)
     */
    public void setMargin(float margin) {
        this.margin = margin;
    }

    /**
     * Tells whether the bias is kept fixed during learning.
     *
     * @return {@code true} if updates never touch the model bias
     */
    public boolean isUnbiased() {
        return this.unbiased;
    }

    /**
     * Chooses whether updates should also modify the model bias.
     *
     * @param unbiased {@code true} to keep the bias fixed
     */
    public void setUnbiased(boolean unbiased) {
        this.unbiased = unbiased;
    }

    /**
     * Performs one online pass over the dataset, calling {@link #learn(Example)} for each
     * example in iteration order, then resets the dataset.
     *
     * @param dataset the training examples
     */
    @Override
    public void learn(Dataset dataset) {
        try {
            while (dataset.hasNextExample()) {
                learn(dataset.getNextExample());
            }
        } finally {
            // Reset even if an update throws, so the caller does not receive a dataset
            // left in a partially iterated state.
            dataset.reset();
        }
    }

    /**
     * Classifies {@code example} with the current model and, if the score is inside the
     * margin or the prediction is wrong, updates the model (and the bias, unless
     * {@link #unbiased}) with weight {@code +alpha} or {@code -alpha}.
     *
     * @param example the training example
     * @return the prediction computed <em>before</em> any update was applied
     */
    @Override
    public BinaryMarginClassifierOutput learn(Example example) {
        BinaryMarginClassifierOutput prediction = this.classifier.predict(example);
        float score = prediction.getScore(this.label).floatValue();
        boolean isPositive = example.isExampleOf(this.label);

        boolean insideMargin = Math.abs(score) < this.margin;
        if (insideMargin || prediction.isClassPredicted(this.label) != isPositive) {
            float weight = isPositive ? this.alpha : -this.alpha;
            this.classifier.getModel().addExample(weight, example);
            if (!this.unbiased) {
                this.classifier.getModel().setBias(this.classifier.getModel().getBias() + weight);
            }
        }
        return prediction;
    }

    /** Resets the underlying classifier. */
    @Override
    public void reset() {
        this.classifier.reset();
    }

    /**
     * Returns the classifier being trained.
     *
     * @return the underlying binary classifier
     */
    @Override
    public BinaryClassifier getPredictionFunction() {
        return this.classifier;
    }

    /**
     * Sets the label to learn and forwards it to the classifier.
     *
     * @param list a list containing exactly one label
     * @throws IllegalArgumentException if {@code list} does not contain exactly one label
     */
    @Override
    public void setLabels(List<Label> list) {
        if (list.size() != 1) {
            throw new IllegalArgumentException("The Perceptron algorithm is a binary method which can learn a single Label");
        }
        this.label = list.get(0);
        this.classifier.setLabels(list);
    }

    /**
     * Returns the learned label wrapped in a single-element list.
     *
     * @return a fixed-size list containing the current label
     */
    @Override
    public List<Label> getLabels() {
        return Arrays.asList(this.label);
    }

    /**
     * Returns the learned label.
     *
     * @return the current label
     */
    @Override
    public Label getLabel() {
        return this.label;
    }

    /**
     * Sets the label to learn; equivalent to {@code setLabels(Arrays.asList(label))}.
     *
     * @param label the label to learn
     */
    @Override
    public void setLabel(Label label) {
        setLabels(Arrays.asList(label));
    }
}
