package it.uniroma2.sag.kelp.examples.demo.nystrom;

import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.dataset.selector.RandomExampleSelector;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.examples.demo.qc.QuestionClassification;
import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.MultiEpochLearning;
import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLearningAlgorithm;
import it.uniroma2.sag.kelp.learningalgorithm.classification.dcd.DCDLoss;
import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning;
import it.uniroma2.sag.kelp.learningalgorithm.classification.scw.SCWType;
import it.uniroma2.sag.kelp.learningalgorithm.classification.scw.SoftConfidenceWeightedClassification;
import it.uniroma2.sag.kelp.linearization.nystrom.NystromMethod;
import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassifier;
import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator;
import javassist.compiler.TokenId;

/* loaded from: input_file:it/uniroma2/sag/kelp/examples/demo/nystrom/NystromExampleMain.class */
/**
 * Demonstrates the Nystrom linearization on the KeLP question classification (QC) task.
 *
 * <p>The demo works in three steps:
 * <ul>
 *   <li>It samples a random set of landmark examples from the training set.</li>
 *   <li>It uses the landmarks, together with the QC kernel function, to build a
 *       {@link NystromMethod} that projects both the training and the test examples into a new
 *       representation named {@value #LINEAR_REP_NAME}.</li>
 *   <li>It trains two linear learners (one batch, one online) as one-vs-all multi-class
 *       classifiers on that representation and prints their accuracy on the test set.</li>
 * </ul>
 */
public class NystromExampleMain {
    /** Name of the representation that holds the Nystrom-linearized vectors. */
    public static final String LINEAR_REP_NAME = "lin";

    /** Training set of the coarse-grained question classification task. */
    private static final String TRAIN_FILE_PATH = "src/main/resources/qc/train_5500.coarse.klp.gz";

    /** Test set of the coarse-grained question classification task. */
    private static final String TEST_FILE_PATH = "src/main/resources/qc/TREC_10.coarse.klp.gz";

    /** Kernel identifier passed to {@link QuestionClassification#getQCKernelFunction}. */
    private static final String KERNEL_NAME = "csptk";

    /**
     * Number of landmarks sampled from the training set.
     *
     * <p>The decompiled code spelled this value as the unrelated compile-time constant
     * {@code javassist.compiler.TokenId.BadToken} (value 500). A plain named constant keeps the
     * same value without tying the demo to a bytecode library's lexer tokens.
     */
    private static final int NUMBER_OF_LANDMARKS = 500;

    /** Second argument to {@link RandomExampleSelector}; presumably its random seed, so check this. */
    private static final int LANDMARK_SELECTION_SEED = 0;

    /** Seed set on the linearized training set before learning, for reproducible runs. */
    private static final int TRAINING_SEED = 0;

    /**
     * Runs the demo: loads the QC datasets, linearizes them with the Nystrom method, then prints
     * the accuracy of a batch learner and of an online learner.
     *
     * @param args ignored
     * @throws Exception if a dataset cannot be loaded or a learning step fails
     */
    public static void main(String[] args) throws Exception {
        SimpleDataset trainingSet = loadDataset(TRAIN_FILE_PATH);
        SimpleDataset testSet = loadDataset(TEST_FILE_PATH);

        // Landmarks are drawn from the training set only. The kernel factory receives both
        // datasets, as in the original call.
        RandomExampleSelector landmarkSelector =
                new RandomExampleSelector(NUMBER_OF_LANDMARKS, LANDMARK_SELECTION_SEED);
        NystromMethod nystromMethod = new NystromMethod(
                landmarkSelector.select(trainingSet),
                QuestionClassification.getQCKernelFunction(trainingSet, testSet, KERNEL_NAME));

        // Project both datasets into the same explicit space, stored under LINEAR_REP_NAME.
        SimpleDataset linearizedTrainingSet = nystromMethod.getLinearizedDataset(trainingSet, LINEAR_REP_NAME);
        SimpleDataset linearizedTestSet = nystromMethod.getLinearizedDataset(testSet, LINEAR_REP_NAME);
        linearizedTrainingSet.setSeed(TRAINING_SEED);

        // Batch learner: Dual Coordinate Descent with the L2 loss, operating on the linear representation.
        LearningAlgorithm batchLearner =
                new DCDLearningAlgorithm(5.0d, 5.0d, DCDLoss.L2, false, 30, LINEAR_REP_NAME);
        System.out.println("Batch Learning Accuracy:\t"
                + evaluateClassifier(linearizedTrainingSet, linearizedTestSet, batchLearner));

        // Online learner: Soft Confidence-Weighted (SCW-II) run for 2 epochs.
        // NOTE(review): the label argument is null here; presumably OneVsAllLearning assigns a
        // label to each per-class copy. Verify this against the KeLP sources.
        LearningAlgorithm onlineLearner = new MultiEpochLearning(2,
                new SoftConfidenceWeightedClassification(null, SCWType.SCW_II, 0.9f, 2.0f, 2.0f, false, LINEAR_REP_NAME));
        System.out.println("Online Learning Accuracy:\t"
                + evaluateClassifier(linearizedTrainingSet, linearizedTestSet, onlineLearner));
    }

    /**
     * Creates an empty {@link SimpleDataset} and populates it from the given file.
     *
     * @param path path of the dataset file
     * @return the populated dataset
     * @throws Exception if the file cannot be read or parsed
     */
    private static SimpleDataset loadDataset(String path) throws Exception {
        SimpleDataset dataset = new SimpleDataset();
        dataset.populate(path);
        return dataset;
    }

    /**
     * Trains a one-vs-all multi-class classifier and returns its accuracy on a test set.
     *
     * <p>The classifier is trained on {@code trainingSet}, using {@code learningAlgorithm} as the
     * per-class base algorithm. The label set for both learning and evaluation comes from the
     * training set.
     *
     * @param trainingSet dataset used for learning
     * @param testSet dataset whose examples are classified and scored
     * @param learningAlgorithm binary learner used for each one-vs-all sub-problem
     * @return the accuracy reported by {@link MulticlassClassificationEvaluator} on {@code testSet}
     */
    private static float evaluateClassifier(SimpleDataset trainingSet, SimpleDataset testSet,
            LearningAlgorithm learningAlgorithm) {
        OneVsAllLearning oneVsAllLearning = new OneVsAllLearning();
        oneVsAllLearning.setBaseAlgorithm(learningAlgorithm);
        oneVsAllLearning.setLabels(trainingSet.getClassificationLabels());
        oneVsAllLearning.learn(trainingSet);

        MulticlassClassificationEvaluator evaluator =
                new MulticlassClassificationEvaluator(trainingSet.getClassificationLabels());
        OneVsAllClassifier classifier = oneVsAllLearning.getPredictionFunction();
        for (Example example : testSet.getExamples()) {
            evaluator.addCount(example, classifier.predict(example));
        }
        return evaluator.getAccuracy();
    }
}
