package it.uniroma2.sag.kelp.utils.evaluation;

import gnu.trove.map.hash.TObjectFloatHashMap;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.predictionfunction.Prediction;
import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:it/uniroma2/sag/kelp/utils/evaluation/MulticlassClassificationEvaluator.class */
/**
 * Evaluator for classification tasks over a fixed list of labels.
 *
 * <p>For every evaluated example it counts, per label, how many times the label was a gold
 * label ("to be predicted"), how many times it was predicted, and how many of those
 * predictions were correct. From these counts it derives:
 * <ul>
 *   <li>per-label precision, recall and F1;</li>
 *   <li>overall precision, recall and F1, obtained by pooling the counts of every label in the
 *       label list before dividing (micro-averaging);</li>
 *   <li>accuracy, defined as the number of correct predictions divided by the number of gold
 *       labels seen.</li>
 * </ul>
 *
 * <p>Measures are computed lazily: {@code addCount} marks them stale and every getter recomputes
 * them on demand. Any ratio whose denominator is zero is reported as 0.
 */
public class MulticlassClassificationEvaluator extends Evaluator {
    /** Labels the per-label and overall measures are computed over (stored by reference). */
    private List<Label> labels;
    /** Per label: number of predictions of that label that were correct. */
    protected TObjectFloatHashMap<Label> correctCounter = new TObjectFloatHashMap<>();
    /** Per label: number of times the label was predicted. */
    protected TObjectFloatHashMap<Label> predictedCounter = new TObjectFloatHashMap<>();
    /** Per label: number of times the label appeared among an example's gold labels. */
    protected TObjectFloatHashMap<Label> toBePredictedCounter = new TObjectFloatHashMap<>();
    /** Per-label precision; -1 until the measures have been computed. */
    private TObjectFloatHashMap<Label> precisions = new TObjectFloatHashMap<>();
    /** Per-label recall; -1 until the measures have been computed. */
    private TObjectFloatHashMap<Label> recalls = new TObjectFloatHashMap<>();
    /** Per-label F1; -1 until the measures have been computed. */
    private TObjectFloatHashMap<Label> f1s = new TObjectFloatHashMap<>();
    /** Number of gold labels seen across all evaluated examples. */
    protected int total;
    /** Number of correct predictions across all evaluated examples. */
    protected int correct;
    private float accuracy;
    private float overallPrecision;
    private float overallRecall;
    private float overallF1;

    /**
     * Creates an evaluator for the given labels.
     *
     * @param list the labels to evaluate; the list is kept by reference, not copied
     */
    public MulticlassClassificationEvaluator(List<Label> list) {
        this.labels = list;
        initializeCounters();
    }

    /**
     * Puts the evaluator in its initial state. Per-label counts are set to 0 and per-label
     * measures to -1 ("not computed"). Global totals and overall measures are set to 0.
     */
    private void initializeCounters() {
        for (Label label : this.labels) {
            this.correctCounter.put(label, 0.0f);
            this.toBePredictedCounter.put(label, 0.0f);
            this.predictedCounter.put(label, 0.0f);
            this.precisions.put(label, -1.0f);
            this.recalls.put(label, -1.0f);
            this.f1s.put(label, -1.0f);
        }
        this.total = 0;
        this.correct = 0;
        this.accuracy = 0.0f;
        this.overallPrecision = 0.0f;
        this.overallRecall = 0.0f;
        this.overallF1 = 0.0f;
        this.computed = false;
    }

    /** Recomputes the measures if counts were added since the last computation. */
    private void computeIfNeeded() {
        if (!this.computed) {
            compute();
        }
    }

    /**
     * Returns the per-label precision map (the internal map itself, not a copy).
     *
     * @return precision of every label in the label list
     */
    public TObjectFloatHashMap<Label> getPrecisions() {
        computeIfNeeded();
        return this.precisions;
    }

    /**
     * Returns the per-label recall map (the internal map itself, not a copy).
     *
     * @return recall of every label in the label list
     */
    public TObjectFloatHashMap<Label> getRecalls() {
        computeIfNeeded();
        return this.recalls;
    }

    /**
     * Returns the per-label F1 map (the internal map itself, not a copy).
     *
     * @return F1 of every label in the label list
     */
    public TObjectFloatHashMap<Label> getF1s() {
        computeIfNeeded();
        return this.f1s;
    }

    /**
     * Updates the counters with one evaluated example.
     *
     * <p>Every gold label of {@code example} increments its "to be predicted" count and
     * {@code total}. Every predicted label increments its "predicted" count. When
     * {@code example.isExampleOf(label)} also holds, it increments that label's "correct"
     * count and {@code correct}.
     *
     * @param example the evaluated example, carrying its gold labels
     * @param prediction the prediction for {@code example}; must be a {@link ClassificationOutput}
     *     (a different type causes a {@link ClassCastException})
     */
    @Override // it.uniroma2.sag.kelp.utils.evaluation.Evaluator
    public void addCount(Example example, Prediction prediction) {
        ClassificationOutput classificationOutput = (ClassificationOutput) prediction;
        for (Label gold : example.getClassificationLabels()) {
            this.toBePredictedCounter.put(gold, this.toBePredictedCounter.get(gold) + 1.0f);
            this.total++;
        }
        List<Label> predictedClasses = classificationOutput.getPredictedClasses();
        for (Label label : predictedClasses) {
            this.predictedCounter.put(label, this.predictedCounter.get(label) + 1.0f);
            if (example.isExampleOf(label)) {
                this.correctCounter.put(label, this.correctCounter.get(label) + 1.0f);
                this.correct++;
            }
        }
        this.computed = false;
    }

    /**
     * Derives all measures from the current counters.
     *
     * <p>A label's precision, recall and F1 are 0 unless its correct, predicted and
     * to-be-predicted counts are all non-zero. The overall measures pool the counts of the
     * labels in the label list. Divisions use floating point; a zero denominator yields 0
     * instead of throwing or producing NaN.
     */
    @Override // it.uniroma2.sag.kelp.utils.evaluation.Evaluator
    protected void compute() {
        double correctSum = 0.0;
        double predictedSum = 0.0;
        double toBePredictedSum = 0.0;
        for (Label label : this.labels) {
            float correctCount = this.correctCounter.get(label);
            float predictedCount = this.predictedCounter.get(label);
            float toBePredictedCount = this.toBePredictedCounter.get(label);
            float precision = 0.0f;
            float recall = 0.0f;
            float f1 = 0.0f;
            if (correctCount != 0.0f && predictedCount != 0.0f && toBePredictedCount != 0.0f) {
                precision = correctCount / predictedCount;
                recall = correctCount / toBePredictedCount;
                f1 = harmonicMean(precision, recall);
            }
            correctSum += correctCount;
            predictedSum += predictedCount;
            toBePredictedSum += toBePredictedCount;
            this.precisions.put(label, precision);
            this.recalls.put(label, recall);
            this.f1s.put(label, f1);
        }
        // Floating-point division: the previous int division truncated these ratios to 0 or 1
        // and threw ArithmeticException on an empty denominator.
        this.overallPrecision = safeRatio(correctSum, predictedSum);
        this.overallRecall = safeRatio(correctSum, toBePredictedSum);
        this.overallF1 = harmonicMean(this.overallPrecision, this.overallRecall);
        this.accuracy = safeRatio(this.correct, this.total);
        this.computed = true;
    }

    /** Returns {@code numerator / denominator} as a float, or 0 when the denominator is 0. */
    private static float safeRatio(double numerator, double denominator) {
        return denominator == 0.0 ? 0.0f : (float) (numerator / denominator);
    }

    /** Returns the harmonic mean {@code 2ab / (a + b)}, or 0 when {@code a + b} is 0. */
    private static float harmonicMean(float a, float b) {
        float sum = a + b;
        return sum == 0.0f ? 0.0f : (2.0f * a * b) / sum;
    }

    /**
     * Returns the precision of a label.
     *
     * @param label the label to look up
     * @return its precision, or -1 if the label has no stored precision
     */
    public float getPrecisionFor(Label label) {
        computeIfNeeded();
        if (this.precisions.containsKey(label)) {
            return this.precisions.get(label);
        }
        return -1.0f;
    }

    /**
     * Returns the recall of a label.
     *
     * @param label the label to look up
     * @return its recall, or -1 if the label has no stored recall
     */
    public float getRecallFor(Label label) {
        computeIfNeeded();
        if (this.recalls.containsKey(label)) {
            return this.recalls.get(label);
        }
        return -1.0f;
    }

    /**
     * Returns the F1 of a label.
     *
     * @param label the label to look up
     * @return its F1, or -1 if the label has no stored F1
     */
    public float getF1For(Label label) {
        computeIfNeeded();
        if (this.f1s.containsKey(label)) {
            return this.f1s.get(label);
        }
        return -1.0f;
    }

    /**
     * Returns the accuracy: correct predictions divided by gold labels seen.
     *
     * @return the accuracy, or 0 if no gold label has been seen
     */
    public float getAccuracy() {
        computeIfNeeded();
        return this.accuracy;
    }

    /**
     * Returns the precision obtained by pooling the counts of all labels in the label list.
     *
     * @return the overall precision, or 0 if nothing was predicted
     */
    public float getOverallPrecision() {
        computeIfNeeded();
        return this.overallPrecision;
    }

    /**
     * Returns the recall obtained by pooling the counts of all labels in the label list.
     *
     * @return the overall recall, or 0 if no gold label was seen
     */
    public float getOverallRecall() {
        computeIfNeeded();
        return this.overallRecall;
    }

    /**
     * Returns the harmonic mean of the overall precision and overall recall.
     *
     * @return the overall F1, or 0 when both overall precision and recall are 0
     */
    public float getOverallF1() {
        computeIfNeeded();
        return this.overallF1;
    }

    /**
     * Returns the arithmetic mean of the per-label F1 over the whole label list.
     *
     * @return the mean F1 (NaN if the label list is empty)
     */
    public float getMeanF1() {
        computeIfNeeded();
        // Copy instead of casting: the label list is not necessarily a java.util.ArrayList
        // (e.g. Arrays.asList or List.of), and the old cast threw ClassCastException.
        return getMeanF1For(new ArrayList<>(this.labels));
    }

    /**
     * Returns the arithmetic mean of the per-label F1 over the given labels.
     *
     * <p>Labels with no stored F1 contribute the map's no-entry value, which is 0 for a
     * default-constructed Trove map.
     *
     * @param arrayList the labels to average over
     * @return the mean F1 (NaN if {@code arrayList} is empty)
     */
    public float getMeanF1For(ArrayList<Label> arrayList) {
        computeIfNeeded();
        float sum = 0.0f;
        for (Label label : arrayList) {
            sum += this.f1s.get(label);
        }
        return sum / arrayList.size();
    }

    /**
     * Discards all accumulated counts and measures and returns the evaluator to the state it
     * had right after construction. This includes the global {@code total} and {@code correct}
     * counters, which were previously left untouched.
     */
    @Override // it.uniroma2.sag.kelp.utils.evaluation.Evaluator
    public void clear() {
        this.correctCounter.clear();
        this.predictedCounter.clear();
        this.toBePredictedCounter.clear();
        this.precisions.clear();
        this.recalls.clear();
        this.f1s.clear();
        initializeCounters();
    }

    /** Prints every label followed by its correct/predicted/to-be-predicted counts. */
    private void printCounters() {
        for (Label label : this.labels) {
            System.out.println(label);
            System.out.print("\t");
            printCounters(label);
        }
    }

    /**
     * Prints the correct, predicted and to-be-predicted counts of a label on one line.
     *
     * @param label the label whose counters are printed
     */
    public void printCounters(Label label) {
        System.out.println(this.correctCounter.get(label) + " " + this.predictedCounter.get(label) + " " + this.toBePredictedCounter.get(label));
    }

    /**
     * Renders one line per label: label, precision, recall and F1, tab-separated. Measures are
     * recomputed first if counts were added since the last computation.
     */
    @Override
    public String toString() {
        computeIfNeeded();
        StringBuilder sb = new StringBuilder();
        for (Label label : this.labels) {
            sb.append(label + "\t" + this.precisions.get(label) + "\t" + this.recalls.get(label) + "\t" + this.f1s.get(label) + "\n");
        }
        return sb.toString().trim();
    }

    /**
     * Creates a fresh evaluator with no counts over the same label list. The list instance is
     * shared, not copied.
     *
     * @return a new, empty evaluator
     */
    @Override // it.uniroma2.sag.kelp.utils.evaluation.Evaluator
    public MulticlassClassificationEvaluator duplicate() {
        return new MulticlassClassificationEvaluator(this.labels);
    }
}
