package it.uniroma2.sag.kelp.main;

import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.predictionfunction.Prediction;
import it.uniroma2.sag.kelp.predictionfunction.classifier.Classifier;
import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper;
import it.uniroma2.sag.kelp.utils.evaluation.BinaryClassificationEvaluator;
import it.uniroma2.sag.kelp.utils.evaluation.Evaluator;
import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator;
import java.io.File;
import java.io.PrintWriter;
import java.util.List;

/* loaded from: input_file:it/uniroma2/sag/kelp/main/Classify.class */
/**
 * Command-line entry point that classifies every example of a dataset with a serialized
 * {@link Classifier}, writes the per-label scores of each prediction to a file (one line per
 * example, {@code label:score} pairs separated by tabs) and prints the accuracy measured on the
 * dataset.
 *
 * <p>Usage: {@code Classify datasetPath modelPath predictionsPath}
 */
public class Classify {

    /**
     * Loads the dataset and the model, classifies every example and reports the accuracy.
     *
     * @param args {@code args[0]} path of the dataset to classify, {@code args[1]} path of the
     *     serialized classifier, {@code args[2]} path of the predictions file to write (UTF-8)
     * @throws Exception if the dataset or the model cannot be read, or the predictions file cannot
     *     be written
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 3) {
            System.err.println("USAGE: datasetPath modelPath predictionsPath");
            System.exit(1);
            return; // not reached normally; guards against an intercepted System.exit
        }
        String datasetPath = args[0];
        String modelPath = args[1];
        String predictionsPath = args[2];

        SimpleDataset dataset = new SimpleDataset();
        dataset.populate(datasetPath);
        Classifier classifier =
                (Classifier) new JacksonSerializerWrapper().readValue(new File(modelPath), Classifier.class);

        List<Label> labels = classifier.getLabels();
        if (labels == null || labels.isEmpty()) {
            // Without at least one label neither evaluator can be built nor any score printed.
            System.err.println("The model " + modelPath + " declares no labels");
            System.exit(1);
            return;
        }
        Evaluator evaluator = createEvaluator(labels);

        // try-with-resources flushes and closes the file even if a prediction throws.
        try (PrintWriter out = new PrintWriter(predictionsPath, "utf8")) {
            for (Example example : dataset.getExamples()) {
                Prediction prediction = classifier.predict(example);
                evaluator.addCount(example, prediction);
                out.println(formatScores(prediction, labels));
            }
        }
        System.out.println("Accuracy on test set: "
                + evaluator.getPerformanceMeasure("Accuracy", new Object[0]));
    }

    /**
     * Chooses the evaluator matching the label set.
     *
     * <p>With more than one label a {@link MulticlassClassificationEvaluator} is used. Otherwise
     * a {@link BinaryClassificationEvaluator} is built on the single label.
     *
     * @param labels the non-empty labels declared by the model
     * @return the evaluator used to accumulate prediction counts
     */
    private static Evaluator createEvaluator(List<Label> labels) {
        if (labels.size() > 1) {
            return new MulticlassClassificationEvaluator(labels);
        }
        return new BinaryClassificationEvaluator(labels.get(0));
    }

    /**
     * Renders the score of every label as {@code label:score}, separated by tab characters.
     *
     * @param prediction the prediction whose scores are printed
     * @param labels the labels to report, in output order
     * @return the tab-separated line (without a trailing tab)
     */
    private static String formatScores(Prediction prediction, List<Label> labels) {
        StringBuilder line = new StringBuilder();
        for (Label label : labels) {
            if (line.length() > 0) {
                line.append('\t');
            }
            line.append(label).append(':').append(prediction.getScore(label));
        }
        return line.toString();
    }
}
