package it.uniroma2.sag.kelp.examples.demo.qc;

import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning;
import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassifier;
import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper;
import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator;
import java.io.File;
import java.util.Iterator;
import java.util.List;
import org.slf4j.impl.SimpleLogger;

/**
 * Demo: trains a one-vs-all question classifier whose learning algorithm is read from a JSON
 * specification, then reports its accuracy on the coarse-grained QC test set.
 */
public class QuestionClassificationLearningFromJson {

    /** Training dataset (coarse-grained question classification). */
    private static final String TRAIN_PATH = "src/main/resources/qc/train_5500.coarse.klp.gz";

    /** Test dataset (coarse-grained question classification). */
    private static final String TEST_PATH = "src/main/resources/qc/TREC_10.coarse.klp.gz";

    /** JSON specification deserialized into the {@link OneVsAllLearning} algorithm. */
    private static final String ALGORITHM_SPEC_PATH =
            "src/main/resources/qc/learningAlgorithmSpecification.klp";

    /**
     * Loads the datasets, prints per-label statistics, trains the classifier described by the
     * JSON specification and prints its accuracy on the test set.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        try {
            // Must be set before any SLF4J SimpleLogger is created to take effect.
            System.setProperty(SimpleLogger.DEFAULT_LOG_LEVEL_KEY, "WARN");

            SimpleDataset trainingSet = new SimpleDataset();
            trainingSet.populate(TRAIN_PATH);
            SimpleDataset testSet = new SimpleDataset();
            testSet.populate(TEST_PATH);

            List<Label> labels = trainingSet.getClassificationLabels();
            printStatistics(trainingSet, testSet, labels);

            OneVsAllLearning learning = (OneVsAllLearning) new JacksonSerializerWrapper()
                    .readValue(new File(ALGORITHM_SPEC_PATH), OneVsAllLearning.class);
            learning.setLabels(labels);
            learning.learn(trainingSet);
            OneVsAllClassifier classifier = learning.getPredictionFunction();

            MulticlassClassificationEvaluator evaluator = evaluate(classifier, testSet, labels);
            System.out.println("Accuracy: " + evaluator.getAccuracy());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Prints the training-set size and the positive/negative counts per label for both sets. */
    private static void printStatistics(
            SimpleDataset trainingSet, SimpleDataset testSet, List<Label> labels) {
        System.out.println("Training set statistics");
        System.out.print("Examples number ");
        System.out.println(trainingSet.getNumberOfExamples());
        for (Label label : labels) {
            System.out.println("Training Label " + label + " positive " + trainingSet.getNumberOfPositiveExamples(label));
            System.out.println("Training Label " + label + " negative " + trainingSet.getNumberOfNegativeExamples(label));
            System.out.println("Test Label " + label + " positive " + testSet.getNumberOfPositiveExamples(label));
            System.out.println("Test Label " + label + " negative " + testSet.getNumberOfNegativeExamples(label));
        }
    }

    /** Classifies every test example and accumulates the results in a new evaluator. */
    private static MulticlassClassificationEvaluator evaluate(
            OneVsAllClassifier classifier, SimpleDataset testSet, List<Label> labels) {
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator(labels);
        for (Example example : testSet.getExamples()) {
            // Predict on the very example whose gold labels are counted. The previous code used
            // testSet.getNextExample(), a separate dataset-level cursor not tied to this loop,
            // which could pair a prediction with the wrong gold example.
            evaluator.addCount(example, classifier.predict(example));
        }
        return evaluator;
    }
}
