package it.uniroma2.sag.kelp.learningalgorithm.classification.passiveaggressive;

import com.fasterxml.jackson.annotation.JsonTypeName;
import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.kernel.Kernel;
import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.budgetedAlgorithm.BudgetedLearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryKernelMachineClassifier;
import it.uniroma2.sag.kelp.predictionfunction.classifier.BinaryMarginClassifierOutput;
import it.uniroma2.sag.kelp.predictionfunction.model.BinaryKernelMachineModel;
import it.uniroma2.sag.kelp.predictionfunction.model.SupportVector;
import java.util.List;
import org.ejml.alg.dense.mult.VectorVectorMult;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Budgeted Passive-Aggressive (BPA) learner for binary classification with a kernel machine.
 *
 * <p>An update is triggered only by examples whose hinge loss is positive. In
 * {@link #predictAndLearnWithAvailableBudget(Example)} the example is simply added to the model
 * with the PA-I weight {@code y * min(C, loss / ||x||^2)}. In
 * {@link #predictAndLearnWithFullBudget(Example)} the example instead replaces one existing support
 * vector, chosen to minimize the PA objective under the configured {@link DeletingPolicy}; if no
 * candidate beats the cost of not updating ({@code C * loss}) the model is left unchanged. Which of
 * the two methods is invoked is decided by {@link BudgetedLearningAlgorithm} (presumably based on
 * whether the budget is exhausted).
 *
 * <p>{@code C} is {@code cp} for examples of the target label and {@code cn} otherwise. With
 * {@code fairness} enabled, {@link #learn(Dataset)} rescales {@code cp} by the ratio of negative to
 * positive examples.
 */
@JsonTypeName("budgetedPA")
public class BudgetedPassiveAggressiveClassification extends BudgetedLearningAlgorithm {
    private DeletingPolicy deletingPolicy;
    private Kernel kernel;
    private BinaryKernelMachineClassifier classifier;
    private boolean fairness;
    /** Aggressiveness (upper bound on the update step) for examples of the target label. */
    private float cp;
    /** Aggressiveness (upper bound on the update step) for all other examples. */
    private float cn;

    // Nearest-neighbor cache used by BPA_1NN: nearestNeighbors[j] is the index of the support
    // vector with the highest kernel similarity to support vector j (excluding j itself), and
    // nearestNeighborsSimilarity[j] caches that similarity. Lazily (re)built when areNnComputed
    // is false.
    private boolean areNnComputed;
    private int[] nearestNeighbors;
    private float[] nearestNeighborsSimilarity;

    // Scratch matrices reused by every BPA_1NN candidate evaluation to avoid allocations.
    private final DenseMatrix64F krVector = new DenseMatrix64F(2, 1);
    private final DenseMatrix64F beta = new DenseMatrix64F(2, 1);
    private final DenseMatrix64F kMatrix = new DenseMatrix64F(2, 2);
    private final DenseMatrix64F ktVector = new DenseMatrix64F(2, 1);
    private final DenseMatrix64F invKkr = new DenseMatrix64F(2, 1);
    private final DenseMatrix64F invKkt = new DenseMatrix64F(2, 1);

    /** Strategy used to make room for a new example once the budget is exhausted. */
    public enum DeletingPolicy {
        /** Replace one support vector with the new example; only the new example's weight is set. */
        BPA_S,
        /**
         * Replace one support vector with the new example and additionally adjust the weight of the
         * replaced vector's nearest neighbor.
         */
        BPA_1NN
    }

    /** Creates a learner using BPA_S, {@code cp = cn = 1}, fairness disabled and an empty model. */
    public BudgetedPassiveAggressiveClassification() {
        this.deletingPolicy = DeletingPolicy.BPA_S;
        this.fairness = false;
        this.cp = 1.0f;
        this.cn = 1.0f;
        this.areNnComputed = false;
        this.nearestNeighbors = null;
        this.nearestNeighborsSimilarity = null;
        this.classifier = new BinaryKernelMachineClassifier();
        this.classifier.setModel(new BinaryKernelMachineModel());
    }

    /**
     * Creates a learner with separate aggressiveness parameters per class.
     *
     * @param budget the support-vector budget
     * @param kernel the kernel used to compare examples
     * @param cp aggressiveness for examples of {@code label}
     * @param cn aggressiveness for all other examples
     * @param deletingPolicy strategy used when the budget is full
     * @param label the target (positive) label
     */
    public BudgetedPassiveAggressiveClassification(int budget, Kernel kernel, float cp, float cn,
            DeletingPolicy deletingPolicy, Label label) {
        this();
        setDeletingPolicy(deletingPolicy);
        setCn(cn);
        setCp(cp);
        setKernel(kernel);
        setBudget(budget);
        setLabel(label);
    }

    /**
     * Creates a learner with a single aggressiveness parameter and optional class rebalancing.
     *
     * @param budget the support-vector budget
     * @param kernel the kernel used to compare examples
     * @param c aggressiveness used for both classes (cp may later be rebalanced, see fairness)
     * @param fairness whether {@link #learn(Dataset)} rescales cp by the class ratio
     * @param deletingPolicy strategy used when the budget is full
     * @param label the target (positive) label
     */
    public BudgetedPassiveAggressiveClassification(int budget, Kernel kernel, float c, boolean fairness,
            DeletingPolicy deletingPolicy, Label label) {
        this();
        setDeletingPolicy(deletingPolicy);
        setCn(c);
        setCp(c);
        setFairness(fairness);
        setKernel(kernel);
        setBudget(budget);
        setLabel(label);
    }

    /** @return whether cp is rebalanced by the negative/positive ratio in {@link #learn(Dataset)} */
    public boolean isFairness() {
        return this.fairness;
    }

    /** @param fairness whether cp is rebalanced by the negative/positive ratio */
    public void setFairness(boolean fairness) {
        this.fairness = fairness;
    }

    /** @return the aggressiveness used for examples of the target label */
    public float getCp() {
        return this.cp;
    }

    /** @param cp the aggressiveness used for examples of the target label */
    public void setCp(float cp) {
        this.cp = cp;
    }

    /** @return the aggressiveness used for examples not of the target label */
    public float getCn() {
        return this.cn;
    }

    /** @param cn the aggressiveness used for examples not of the target label */
    public void setCn(float cn) {
        this.cn = cn;
    }

    /** Sets both cp and cn to {@code c}. */
    public void setC(float c) {
        this.cn = c;
        this.cp = c;
    }

    public DeletingPolicy getDeletingPolicy() {
        return this.deletingPolicy;
    }

    public void setDeletingPolicy(DeletingPolicy deletingPolicy) {
        this.deletingPolicy = deletingPolicy;
    }

    /**
     * Returns a new, untrained learner with the same hyper-parameters. Note that if fairness is
     * enabled and {@link #learn(Dataset)} already ran, the copy receives the rebalanced cp.
     */
    @Override
    public LearningAlgorithm duplicate() {
        BudgetedPassiveAggressiveClassification copy = new BudgetedPassiveAggressiveClassification(
                this.budget, this.kernel, this.cp, this.cn, this.deletingPolicy, this.label);
        copy.setFairness(this.fairness);
        return copy;
    }

    /** Clears the learned model and invalidates the nearest-neighbor cache. */
    @Override
    public void reset() {
        this.classifier.reset();
        this.areNnComputed = false;
    }

    @Override
    public BinaryKernelMachineClassifier getPredictionFunction() {
        return this.classifier;
    }

    @Override
    public Kernel getKernel() {
        return this.kernel;
    }

    /** Sets the kernel on this learner and on the underlying model. */
    @Override
    public void setKernel(Kernel kernel) {
        this.kernel = kernel;
        getPredictionFunction().getModel().setKernel(kernel);
    }

    /**
     * Standard PA-I step: if the example incurs a positive hinge loss it is added to the model with
     * weight {@code y * min(C, loss / ||x||^2)}.
     *
     * @param example the incoming example
     * @return the prediction made before the update
     */
    @Override
    public BinaryMarginClassifierOutput predictAndLearnWithAvailableBudget(Example example) {
        BinaryMarginClassifierOutput prediction = this.classifier.predict(example);
        float loss = evaluateLoss(prediction.getScore(this.label).floatValue(), example);
        if (loss > 0.0f) {
            float tau = computeWeight(loss, this.classifier.getModel().getSquaredNorm(example),
                    aggressiveness(example));
            getPredictionFunction().getModel().addExample(labelSign(example) * tau, example);
        }
        return prediction;
    }

    /** PA-I step size: {@code min(c, loss / squaredNorm)}. */
    private float computeWeight(float loss, float squaredNorm, float c) {
        float tau = loss / squaredNorm;
        if (tau > c) {
            tau = c;
        }
        return tau;
    }

    /**
     * Budgeted step: if the example incurs a positive hinge loss, one support vector may be
     * replaced by it according to the configured {@link DeletingPolicy}.
     *
     * @param example the incoming example
     * @return the prediction made before the update
     */
    @Override
    public BinaryMarginClassifierOutput predictAndLearnWithFullBudget(Example example) {
        BinaryMarginClassifierOutput prediction = this.classifier.predict(example);
        float score = prediction.getScore(this.label).floatValue();
        float loss = evaluateLoss(score, example);
        if (loss > 0.0f) {
            float c = aggressiveness(example);
            switch (this.deletingPolicy) {
                case BPA_S:
                    bpaSDeletingPolicy(example, c, loss, score);
                    break;
                case BPA_1NN:
                    bpa1NnDeletingPolicy(example, c, loss, score);
                    break;
            }
        }
        return prediction;
    }

    /** @return +1 if {@code example} belongs to the target label, -1 otherwise */
    private float labelSign(Example example) {
        return example.isExampleOf(this.label) ? 1.0f : -1.0f;
    }

    /** @return cp for examples of the target label, cn otherwise */
    private float aggressiveness(Example example) {
        return example.isExampleOf(this.label) ? this.cp : this.cn;
    }

    /** Hinge loss {@code max(0, 1 - y * score)} written out by case. */
    private float evaluateLoss(float score, Example example) {
        float loss = 0.0f;
        if ((score > 0.0f) != example.isExampleOf(this.label)) {
            // misclassified (a zero score counts as negative)
            loss = 1.0f + Math.abs(score);
        } else if (Math.abs(score) < 1.0f) {
            // correct but inside the margin
            loss = 1.0f - Math.abs(score);
        }
        return loss;
    }

    /**
     * BPA_S: for every support vector r, evaluate replacing it with the example using weight
     * {@code beta = alpha_r * k(r,t) / k(t,t) + y * tau}, and keep the replacement that minimizes
     * the PA objective, provided it is below the no-update cost {@code c * loss}.
     */
    private void bpaSDeletingPolicy(Example example, float c, float loss, float score) {
        float exampleSquaredNorm = this.kernel.squaredNorm(example);
        // The PA-I step for the example does not depend on the candidate: compute it once.
        float signedTau = labelSign(example) * computeWeight(loss, exampleSquaredNorm, c);
        float bestObjective = c * loss;
        float bestWeight = 0.0f;
        int bestIndex = -1;
        int index = 0;
        for (SupportVector candidate : this.classifier.getModel().getSupportVectors()) {
            float alpha = candidate.getWeight();
            float similarity = this.kernel.innerProduct(example, candidate.getInstance());
            float newWeight = alpha * similarity / exampleSquaredNorm + signedTau;
            float objective = evaluateObjectiveFunctionInBpaS(example, newWeight, score, alpha,
                    this.kernel.squaredNorm(candidate.getInstance()), exampleSquaredNorm, similarity, c);
            if (objective < bestObjective) {
                bestIndex = index;
                bestWeight = newWeight;
                bestObjective = objective;
            }
            index++;
        }
        if (bestIndex >= 0) {
            getPredictionFunction().getModel().substituteSupportVector(bestIndex, example, bestWeight);
        }
    }

    /**
     * BPA_1NN: for every support vector r with nearest neighbor n, compute
     * {@code beta = alpha_r * K^-1 k_r + y * tau * K^-1 k_t}, where K is the Gram matrix of
     * {n, t}, {@code k_r = [k(r,n), k(r,t)]} and {@code k_t = [k(n,t), k(t,t)]}. Replacing r with t
     * (weight beta[1]) and adding beta[0] to n's weight is accepted for the candidate that minimizes
     * the PA objective, provided it is below the no-update cost {@code c * loss}.
     */
    private void bpa1NnDeletingPolicy(Example example, float c, float loss, float score) {
        float exampleSquaredNorm = this.kernel.squaredNorm(example);
        float sign = labelSign(example);
        float bestObjective = c * loss;
        float bestExampleWeight = 0.0f;
        float bestNeighborDelta = 0.0f;
        int bestIndex = -1;
        int bestNeighborIndex = -1;
        List<SupportVector> supportVectors = getPredictionFunction().getModel().getSupportVectors();
        int index = -1;
        for (SupportVector candidate : supportVectors) {
            index++;
            int neighborIndex = getNearestNeighborIndex(index);
            if (neighborIndex < 0) {
                // No other support vector exists (budget of one): nothing to redistribute onto.
                continue;
            }
            Example candidateInstance = candidate.getInstance();
            Example neighbor = supportVectors.get(neighborIndex).getInstance();
            float alpha = candidate.getWeight();
            float neighborSquaredNorm = this.kernel.squaredNorm(neighbor);
            float neighborExampleSim = this.kernel.innerProduct(neighbor, example);
            float candidateNeighborSim = this.kernel.innerProduct(candidateInstance, neighbor);
            float candidateExampleSim = this.kernel.innerProduct(candidateInstance, example);

            this.kMatrix.set(0, 0, neighborSquaredNorm);
            this.kMatrix.set(0, 1, neighborExampleSim);
            this.kMatrix.set(1, 0, neighborExampleSim);
            this.kMatrix.set(1, 1, exampleSquaredNorm);
            this.krVector.set(0, 0, candidateNeighborSim);
            this.krVector.set(1, 0, candidateExampleSim);
            this.ktVector.set(0, 0, neighborExampleSim);
            this.ktVector.set(1, 0, exampleSquaredNorm);
            CommonOps.invert(this.kMatrix); // in place: kMatrix now holds K^-1
            CommonOps.mult(this.kMatrix, this.krVector, this.invKkr);
            CommonOps.mult(this.kMatrix, this.ktVector, this.invKkt);

            float tau = (1.0f - sign * ((score - alpha * candidateExampleSim)
                    + alpha * (float) VectorVectorMult.innerProd(this.invKkr, this.ktVector)))
                    / (float) VectorVectorMult.innerProd(this.invKkt, this.ktVector);
            if (tau < 0.0f) {
                tau = 0.0f;
            } else if (tau > c) {
                tau = c;
            }
            // beta = alpha * K^-1 k_r + (tau * y) * K^-1 k_t
            CommonOps.add(alpha, this.invKkr, tau * sign, this.invKkt, this.beta);
            float neighborDelta = (float) this.beta.get(0, 0);
            float exampleWeight = (float) this.beta.get(1, 0);
            float objective = evaluateObjectiveFunctionInBpa1nn(example, exampleWeight, score, alpha,
                    neighborDelta, this.kernel.squaredNorm(candidateInstance), exampleSquaredNorm,
                    neighborSquaredNorm, candidateExampleSim, candidateNeighborSim, neighborExampleSim, c);
            if (objective < bestObjective) {
                bestIndex = index;
                bestNeighborIndex = neighborIndex;
                bestExampleWeight = exampleWeight;
                bestNeighborDelta = neighborDelta;
                bestObjective = objective;
            }
        }
        if (bestIndex >= 0) {
            getPredictionFunction().getModel().substituteSupportVector(bestIndex, example, bestExampleWeight);
            SupportVector neighborSv = supportVectors.get(bestNeighborIndex);
            neighborSv.setWeight(bestNeighborDelta + neighborSv.getWeight());
            // The instance stored at bestIndex changed, so that is the slot whose neighbors are stale.
            updateNearestNeighbors(bestIndex);
        }
    }

    /**
     * BPA_S objective: {@code 0.5 * ||beta*phi(t) - alpha*phi(r)||^2 + c * loss}, where the loss is
     * evaluated on the score the model would give to t after the replacement.
     */
    private float evaluateObjectiveFunctionInBpaS(Example example, float newWeight, float score,
            float alpha, float candidateSquaredNorm, float exampleSquaredNorm,
            float candidateExampleSim, float c) {
        float distance = alpha * alpha * candidateSquaredNorm
                + newWeight * newWeight * exampleSquaredNorm
                - 2.0f * alpha * newWeight * candidateExampleSim;
        float newScore = score + newWeight * exampleSquaredNorm - alpha * candidateExampleSim;
        return 0.5f * distance + c * evaluateLoss(newScore, example);
    }

    /**
     * BPA_1NN objective: {@code 0.5 * ||beta_t*phi(t) + beta_n*phi(n) - alpha*phi(r)||^2 + c * loss},
     * where the loss is evaluated on the score the model would give to t after the update.
     */
    private float evaluateObjectiveFunctionInBpa1nn(Example example, float exampleWeight, float score,
            float alpha, float neighborDelta, float candidateSquaredNorm, float exampleSquaredNorm,
            float neighborSquaredNorm, float candidateExampleSim, float candidateNeighborSim,
            float neighborExampleSim, float c) {
        float distance = alpha * alpha * candidateSquaredNorm
                + exampleWeight * exampleWeight * exampleSquaredNorm
                + neighborDelta * neighborDelta * neighborSquaredNorm
                - 2.0f * alpha * exampleWeight * candidateExampleSim
                - 2.0f * alpha * neighborDelta * candidateNeighborSim
                + 2.0f * exampleWeight * neighborDelta * neighborExampleSim;
        float newScore = score + exampleWeight * exampleSquaredNorm - alpha * candidateExampleSim
                + neighborDelta * neighborExampleSim;
        return 0.5f * distance + c * evaluateLoss(newScore, example);
    }

    /**
     * Learns from the whole dataset. With fairness enabled, cp is first set to
     * {@code cn * #negatives / #positives}; it is left untouched if there are no positives.
     */
    @Override
    public void learn(Dataset dataset) {
        if (this.fairness) {
            float positives = dataset.getNumberOfPositiveExamples(this.label);
            if (positives > 0) {
                this.cp = (this.cn * dataset.getNumberOfNegativeExamples(this.label)) / positives;
            }
        }
        super.learn(dataset);
    }

    @Override
    public BinaryMarginClassifierOutput learn(Example example) {
        return (BinaryMarginClassifierOutput) super.learn(example);
    }

    /** Rebuilds the nearest-neighbor cache for every current support vector (O(n^2) kernel calls). */
    private void computeNearestNeighbors() {
        List<SupportVector> supportVectors = getPredictionFunction().getModel().getSupportVectors();
        int capacity = Math.max(this.budget, supportVectors.size());
        if (this.nearestNeighbors == null || this.nearestNeighbors.length < capacity) {
            this.nearestNeighbors = new int[capacity];
            this.nearestNeighborsSimilarity = new float[capacity];
        }
        for (int index = 0; index < supportVectors.size(); index++) {
            refreshNearestNeighbor(index, supportVectors);
        }
        this.areNnComputed = true;
    }

    /**
     * Recomputes from scratch the nearest neighbor (and its similarity) of the support vector at
     * {@code index}, excluding itself; stores -1 / -Infinity if there is no other support vector.
     */
    private void refreshNearestNeighbor(int index, List<SupportVector> supportVectors) {
        Example instance = supportVectors.get(index).getInstance();
        int bestIndex = -1;
        float bestSimilarity = Float.NEGATIVE_INFINITY;
        int other = 0;
        for (SupportVector sv : supportVectors) {
            if (other != index) {
                float similarity = this.kernel.innerProduct(instance, sv.getInstance());
                if (similarity > bestSimilarity) {
                    bestSimilarity = similarity;
                    bestIndex = other;
                }
            }
            other++;
        }
        this.nearestNeighbors[index] = bestIndex;
        this.nearestNeighborsSimilarity[index] = bestSimilarity;
    }

    /** @return the cached nearest-neighbor index of support vector {@code index}, building the cache if needed */
    private int getNearestNeighborIndex(int index) {
        if (!this.areNnComputed) {
            computeNearestNeighbors();
        }
        return this.nearestNeighbors[index];
    }

    /**
     * Incrementally updates the nearest-neighbor cache after the instance stored at
     * {@code changedIndex} has been replaced.
     */
    private void updateNearestNeighbors(int changedIndex) {
        List<SupportVector> supportVectors = getPredictionFunction().getModel().getSupportVectors();
        Example changed = supportVectors.get(changedIndex).getInstance();
        int bestIndex = -1;
        float bestSimilarity = Float.NEGATIVE_INFINITY;
        for (int j = 0; j < supportVectors.size(); j++) {
            if (j == changedIndex) {
                continue;
            }
            float similarity = this.kernel.innerProduct(changed, supportVectors.get(j).getInstance());
            if (similarity > bestSimilarity) {
                bestSimilarity = similarity;
                bestIndex = j;
            }
            if (this.nearestNeighbors[j] == changedIndex) {
                // j's cached neighbor was the old instance: its similarity is stale, recompute.
                refreshNearestNeighbor(j, supportVectors);
            } else if (similarity > this.nearestNeighborsSimilarity[j]) {
                this.nearestNeighborsSimilarity[j] = similarity;
                this.nearestNeighbors[j] = changedIndex;
            }
        }
        this.nearestNeighbors[changedIndex] = bestIndex;
        this.nearestNeighborsSimilarity[changedIndex] = bestSimilarity;
    }
}
