package it.uniroma2.sag.kelp.learningalgorithm.classification.dcd;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonTypeName;
import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.data.representation.Vector;
import it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.LinearMethod;
import it.uniroma2.sag.kelp.learningalgorithm.classification.ClassificationLearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryLinearClassifier;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryLinearModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/**
 * Binary linear classifier trained with a dual coordinate descent (DCD) procedure.
 *
 * <p>One dual variable is kept per training example. At every pass the examples are visited in a
 * random order (driven by {@link #getSeed() seed}) and each dual variable is updated in turn while
 * the primal hyperplane is kept in sync. Two loss variants are supported through {@link DCDLoss}:
 * with {@code L1} the dual variables are boxed in {@code [0, C]}; otherwise they are only bounded
 * below by zero and {@code 1 / (2C)} is added to the diagonal term.
 *
 * <p>{@code cp} and {@code cn} are the C values applied to positive and negative examples of the
 * target {@link Label}, respectively.
 */
@JsonTypeName("dcd")
public class DCDLearningAlgorithm implements LinearMethod, ClassificationLearningAlgorithm, BinaryLearningAlgorithm {
    /** Label whose examples are treated as the positive class. */
    private Label label;

    /** Prediction function updated by {@link #learn(Dataset)}; not serialized. */
    @JsonIgnore
    private BinaryLinearClassifier classifier;
    /** Name of the vector representation read from each example. */
    private String representation;
    private DCDLoss dcdLoss;
    private boolean useBias;
    private int maxIterations;
    private boolean fairness;
    /** C value applied to positive examples. */
    private double cp;
    /** C value applied to negative examples. */
    private double cn;
    /** Seed of the random generator used to shuffle the visiting order. */
    private long seed;

    /**
     * Creates an algorithm with seed 0, {@link DCDLoss#L2} loss, no bias and an empty
     * {@link BinaryLinearModel}. All other parameters keep their Java default values.
     */
    public DCDLearningAlgorithm() {
        this.seed = 0L;
        this.dcdLoss = DCDLoss.L2;
        this.useBias = false;
        this.classifier = new BinaryLinearClassifier();
        this.classifier.setModel(new BinaryLinearModel());
    }

    /**
     * Creates a fully configured algorithm bound to a target label.
     *
     * @param label the label to be learned (positive class)
     * @param cp the C value for positive examples
     * @param cn the C value for negative examples
     * @param dcdLoss the loss variant
     * @param useBias whether a bias term must be learned
     * @param maxIterations the number of passes over the dataset
     * @param representation the name of the vector representation to use
     */
    public DCDLearningAlgorithm(Label label, double cp, double cn, DCDLoss dcdLoss, boolean useBias,
            int maxIterations, String representation) {
        this(cp, cn, dcdLoss, useBias, maxIterations, representation);
        this.label = label;
        // Keep the classifier aligned with the algorithm, as setLabels(...) does.
        if (label != null) {
            this.classifier.setLabels(Arrays.asList(label));
        }
    }

    /**
     * Creates a configured algorithm without a target label.
     *
     * @param cp the C value for positive examples
     * @param cn the C value for negative examples
     * @param dcdLoss the loss variant
     * @param useBias whether a bias term must be learned
     * @param maxIterations the number of passes over the dataset
     * @param representation the name of the vector representation to use
     */
    public DCDLearningAlgorithm(double cp, double cn, DCDLoss dcdLoss, boolean useBias, int maxIterations,
            String representation) {
        this();
        this.maxIterations = maxIterations;
        this.cp = cp;
        this.cn = cn;
        this.dcdLoss = dcdLoss;
        this.useBias = useBias;
        this.representation = representation;
        // Keep the model aligned with the algorithm, as setRepresentation(...) does.
        this.classifier.getModel().setRepresentation(representation);
    }

    /**
     * Creates an algorithm with the default loss ({@link DCDLoss#L2}) and no bias.
     *
     * @param cp the C value for positive examples
     * @param cn the C value for negative examples
     * @param maxIterations the number of passes over the dataset
     * @param representation the name of the vector representation to use
     */
    public DCDLearningAlgorithm(double cp, double cn, int maxIterations, String representation) {
        this(cp, cn, DCDLoss.L2, false, maxIterations, representation);
    }

    /**
     * Returns a new, untrained algorithm carrying the same hyper-parameters (representation, C
     * values, fairness, iterations, bias, loss and seed). The target label is not copied.
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public DCDLearningAlgorithm duplicate() {
        DCDLearningAlgorithm copy = new DCDLearningAlgorithm();
        copy.setRepresentation(this.representation);
        copy.setCp(this.cp);
        copy.setCn(this.cn);
        copy.setFairness(this.fairness);
        copy.setMaxIterations(this.maxIterations);
        copy.setUseBias(this.useBias);
        copy.setDcdLoss(this.dcdLoss);
        // Without the seed a duplicate would shuffle differently from its source.
        copy.setSeed(this.seed);
        return copy;
    }

    /** @return the C value applied to negative examples */
    public double getCn() {
        return this.cn;
    }

    /** @return the C value applied to positive examples */
    public double getCp() {
        return this.cp;
    }

    /**
     * Returns the term added to the diagonal for the given example: 0 for L1 loss, otherwise
     * {@code 1 / (2C)} with C chosen according to the example's class.
     */
    private double getD(Example example) {
        if (this.dcdLoss == DCDLoss.L1) {
            return 0.0d;
        }
        double c = example.isExampleOf(this.label) ? this.cp : this.cn;
        return 1.0d / (2.0d * c);
    }

    /** @return the loss variant in use */
    public DCDLoss getDcdLoss() {
        return this.dcdLoss;
    }

    @Override // it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm
    public Label getLabel() {
        return this.label;
    }

    /** @return a single-element list holding the target label */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public List<Label> getLabels() {
        return Arrays.asList(this.label);
    }

    /** @return the number of passes over the dataset performed by {@link #learn(Dataset)} */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public BinaryLinearClassifier getPredictionFunction() {
        return this.classifier;
    }

    @Override // it.uniroma2.sag.kelp.learningalgorithm.LinearMethod
    public String getRepresentation() {
        return this.representation;
    }

    /** @return the seed used to shuffle the visiting order */
    public long getSeed() {
        return this.seed;
    }

    /**
     * Returns the upper bound of the dual variable of the given example: the class-specific C for
     * L1 loss, positive infinity otherwise.
     */
    private double getU(Example example) {
        if (this.dcdLoss != DCDLoss.L1) {
            return Double.POSITIVE_INFINITY;
        }
        return example.isExampleOf(this.label) ? this.cp : this.cn;
    }

    /** Returns +1 for examples of the target label, -1 otherwise. */
    private float getY(Example example) {
        return example.isExampleOf(this.label) ? 1.0f : -1.0f;
    }

    /** @return whether {@code cp} is rebalanced on the class distribution when {@code cp == cn} */
    public boolean isFairness() {
        return this.fairness;
    }

    /** @return whether a bias term is learned */
    public boolean isUseBias() {
        return this.useBias;
    }

    /**
     * Trains the classifier on the given dataset.
     *
     * <p>Nothing happens on an empty dataset. When fairness is enabled and {@code cp == cn},
     * {@code cp} is first rescaled to {@code cn * #negatives / #positives}; the rescaling is
     * skipped when either class is missing, since it would otherwise turn {@code cp} into 0 or
     * infinity. The hyperplane already stored in the model is reused if present, otherwise a zero
     * vector is created; dual variables and bias always start from zero.
     *
     * @param dataset the training examples
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public void learn(Dataset dataset) {
        if (dataset.getNumberOfExamples() == 0) {
            return;
        }
        if (isFairness() && this.cp == this.cn) {
            double positives = dataset.getNumberOfPositiveExamples(this.label);
            double negatives = dataset.getNumberOfNegativeExamples(this.label);
            if (positives > 0 && negatives > 0) {
                this.cp = (this.cn * negatives) / positives;
            }
        }

        List<Example> examples = dataset.getExamples();
        int size = examples.size();
        float[] alphas = new float[size];
        float[] targets = new float[size];
        double[] diagonal = new double[size];
        double[] upperBounds = new double[size];
        double[] diagonalShifts = new double[size];
        // Representations are looked up once instead of at every visit of every pass.
        Vector[] vectors = new Vector[size];
        for (int i = 0; i < size; i++) {
            Example example = examples.get(i);
            Vector vector = (Vector) example.getRepresentation(this.representation);
            vectors[i] = vector;
            targets[i] = getY(example);
            upperBounds[i] = getU(example);
            diagonalShifts[i] = getD(example);
            diagonal[i] = vector.innerProduct(vector) + diagonalShifts[i];
            if (this.useBias) {
                // The bias behaves like an extra constant feature equal to 1.
                diagonal[i] += 1.0d;
            }
        }

        if (getPredictionFunction().getModel().getHyperplane() == null) {
            getPredictionFunction().getModel().setHyperplane(dataset.getZeroVector(this.representation));
        }
        Vector hyperplane = getPredictionFunction().getModel().getHyperplane();

        List<Integer> visitOrder = new ArrayList<Integer>(size);
        for (int i = 0; i < size; i++) {
            visitOrder.add(i);
        }

        float bias = 0.0f;
        Random random = new Random(this.seed);
        for (int iteration = 0; iteration < this.maxIterations; iteration++) {
            Collections.shuffle(visitOrder, random);
            for (int index : visitOrder) {
                Vector vector = vectors[index];
                double gradient = ((targets[index] * (hyperplane.innerProduct(vector) + bias)) - 1.0f)
                        + (diagonalShifts[index] * alphas[index]);
                if (projectedGradient(alphas[index], upperBounds[index], gradient) == 0.0d) {
                    continue;
                }
                float previousAlpha = alphas[index];
                double unclipped = previousAlpha - (gradient / diagonal[index]);
                alphas[index] = (float) Math.min(Math.max(unclipped, 0.0d), upperBounds[index]);
                float step = (alphas[index] - previousAlpha) * targets[index];
                hyperplane.add(step, vector);
                if (this.useBias) {
                    bias += step;
                }
            }
        }

        this.classifier.getModel().setHyperplane(hyperplane);
        this.classifier.getModel().setRepresentation(this.representation);
        this.classifier.getModel().setBias(this.useBias ? bias : 0.0f);
    }

    /**
     * Projects the gradient on the feasible directions of a dual variable: at the lower bound only
     * negative gradients count, at the upper bound only positive ones.
     *
     * <p>The dual variable is stored as a float, hence the upper bound is rounded to float before
     * the comparison; comparing against the double value would never match for bounds that are not
     * exactly representable as floats.
     */
    private static double projectedGradient(float alpha, double upperBound, double gradient) {
        if (alpha == 0.0f) {
            return Math.min(gradient, 0.0d);
        }
        if (alpha == (float) upperBound) {
            return Math.max(gradient, 0.0d);
        }
        return gradient;
    }

    /** Resets the underlying prediction function. */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public void reset() {
        this.classifier.reset();
    }

    /** @param cn the C value applied to negative examples */
    public void setCn(double cn) {
        this.cn = cn;
    }

    /** @param cp the C value applied to positive examples */
    public void setCp(double cp) {
        this.cp = cp;
    }

    /** @param dcdLoss the loss variant to use */
    public void setDcdLoss(DCDLoss dcdLoss) {
        this.dcdLoss = dcdLoss;
    }

    /** @param fairness whether {@code cp} must be rebalanced on the class distribution */
    public void setFairness(boolean fairness) {
        this.fairness = fairness;
    }

    @Override // it.uniroma2.sag.kelp.learningalgorithm.BinaryLearningAlgorithm
    public void setLabel(Label label) {
        setLabels(Arrays.asList(label));
    }

    /**
     * Sets the target label on both the algorithm and its classifier.
     *
     * @param labels a list holding exactly one label
     * @throws IllegalArgumentException if the list does not contain exactly one label
     */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm
    public void setLabels(List<Label> labels) {
        if (labels.size() != 1) {
            throw new IllegalArgumentException("LibLinear algorithm is a binary method which can learn a single Label");
        }
        this.label = labels.get(0);
        this.classifier.setLabels(labels);
    }

    /** @param maxIterations the number of passes over the dataset */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /** Sets the representation name on both the algorithm and the classifier's model. */
    @Override // it.uniroma2.sag.kelp.learningalgorithm.LinearMethod
    public void setRepresentation(String representation) {
        this.representation = representation;
        this.classifier.getModel().setRepresentation(representation);
    }

    /** @param seed the seed used to shuffle the visiting order */
    public void setSeed(long seed) {
        this.seed = seed;
    }

    /** @param useBias whether a bias term must be learned */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }
}
