package it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass;

import com.fasterxml.jackson.annotation.JsonTypeName;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput;
import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryModel;
import it.uniroma2.sag.kelp.predictionfunction.model.Model;
import it.uniroma2.sag.kelp.predictionfunction.model.MulticlassModel;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Multiclass classifier that combines a set of binary classifiers in a one-vs-one fashion.
 *
 * <p>Binary classifier {@code binaryClassifiers[i]} is paired with the label
 * {@code negativeLabelsForClassifier[i]}. At prediction time every binary classifier casts one
 * vote into a {@link OneVsOneClassificationOutput}. The vote goes to the binary classifier's
 * positive label when its score is non-negative, and to the paired negative label otherwise.
 */
@JsonTypeName("oneVsOneClassifier")
public class OneVsOneClassifier implements Classifier {
    private static final Logger LOGGER = LoggerFactory.getLogger(OneVsOneClassifier.class);

    /** Binary classifiers; index-aligned with {@link #negativeLabelsForClassifier}. */
    private Classifier[] binaryClassifiers;
    /** Label voted for when the binary classifier at the same index scores below zero. */
    private Label[] negativeLabelsForClassifier;
    private List<Label> labels;
    private MulticlassModel model = new MulticlassModel();

    /**
     * Returns the negative labels, where element {@code i} pairs with binary classifier {@code i}.
     *
     * @return the internal array (not a copy); may be {@code null} if never set
     */
    public Label[] getNegativeLabelsForClassifier() {
        return this.negativeLabelsForClassifier;
    }

    /**
     * Sets the negative labels. Element {@code i} is voted for when binary classifier {@code i}
     * returns a negative score.
     *
     * @param labelArr the negative labels, index-aligned with the binary classifiers
     */
    public void setNegativeLabelsForClassifier(Label[] labelArr) {
        this.negativeLabelsForClassifier = labelArr;
    }

    /**
     * Returns the binary classifiers combined by this classifier.
     *
     * @return the internal array (not a copy); may be {@code null} if never set
     */
    public Classifier[] getBinaryClassifiers() {
        return this.binaryClassifiers;
    }

    /**
     * Sets the binary classifiers and rebuilds the multiclass model's list of binary models
     * from them, in the same order.
     *
     * @param classifierArr the binary classifiers, index-aligned with the negative labels
     * @throws ClassCastException if a classifier's model is not a {@link BinaryModel}
     */
    public void setBinaryClassifiers(Classifier[] classifierArr) {
        this.binaryClassifiers = classifierArr;
        ArrayList<BinaryModel> binaryModels = new ArrayList<>(classifierArr.length);
        for (Classifier classifier : classifierArr) {
            binaryModels.add((BinaryModel) classifier.getModel());
        }
        this.model.setModels(binaryModels);
    }

    /**
     * Classifies {@code example} by letting every binary classifier cast one vote.
     *
     * @param example the example to classify
     * @return the accumulated votes of all binary classifiers
     * @throws IllegalStateException if the binary classifiers have not been set, or if there are
     *     fewer negative labels than binary classifiers
     */
    @Override
    public OneVsOneClassificationOutput predict(Example example) {
        ensureConfigured();
        OneVsOneClassificationOutput votes = new OneVsOneClassificationOutput();
        LOGGER.debug("----------------");
        for (int i = 0; i < this.binaryClassifiers.length; i++) {
            ClassificationOutput binaryOutput = this.binaryClassifiers[i].predict(example);
            // NOTE(review): assumes the binary output lists the classifier's positive label first.
            Label positiveLabel = binaryOutput.getAllClasses().get(0);
            Label negativeLabel = this.negativeLabelsForClassifier[i];
            Float score = binaryOutput.getScore(positiveLabel);
            // Parameterized form: no string is built unless debug logging is enabled.
            LOGGER.debug("{} vs {}: {}", positiveLabel, negativeLabel, score);
            float value = score.floatValue();
            // A score of exactly zero counts as a vote for the positive label.
            votes.addVotedPrediction(value >= 0.0f ? positiveLabel : negativeLabel, value);
        }
        return votes;
    }

    /**
     * Fails fast, with a descriptive message, when {@link #predict} could not index the
     * classifier and negative-label arrays consistently.
     */
    private void ensureConfigured() {
        if (this.binaryClassifiers == null) {
            throw new IllegalStateException("binaryClassifiers has not been set");
        }
        int required = this.binaryClassifiers.length;
        int available =
                this.negativeLabelsForClassifier == null ? 0 : this.negativeLabelsForClassifier.length;
        if (available < required) {
            throw new IllegalStateException(
                    "expected at least " + required + " negative labels (one per binary classifier) but found "
                            + available);
        }
    }

    /**
     * Resets every binary classifier. This is a no-op when no binary classifiers have been set.
     */
    @Override
    public void reset() {
        if (this.binaryClassifiers == null) {
            return;
        }
        for (Classifier classifier : this.binaryClassifiers) {
            classifier.reset();
        }
    }

    /**
     * Sets the labels handled by this classifier.
     *
     * @param list the labels; stored by reference
     */
    @Override
    public void setLabels(List<Label> list) {
        this.labels = list;
    }

    /**
     * Returns the labels handled by this classifier.
     *
     * @return the stored label list (not a copy); may be {@code null} if never set
     */
    @Override
    public List<Label> getLabels() {
        return this.labels;
    }

    /**
     * Returns the multiclass model holding the binary classifiers' models.
     *
     * @return the current multiclass model
     */
    @Override
    public MulticlassModel getModel() {
        return this.model;
    }

    /**
     * Replaces the multiclass model reference.
     *
     * <p>NOTE(review): this only swaps the aggregate model. The binary classifiers keep the
     * models they already hold.
     *
     * @param model the new model
     * @throws ClassCastException if {@code model} is not a {@link MulticlassModel}
     */
    @Override
    public void setModel(Model model) {
        this.model = (MulticlassModel) model;
    }
}
