package it.uniroma2.sag.kelp.examples.demo.qc;

import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.data.label.Label;
import it.uniroma2.sag.kelp.data.manipulator.LexicalStructureElementManipulator;
import it.uniroma2.sag.kelp.data.manipulator.Manipulator;
import it.uniroma2.sag.kelp.data.representation.structure.similarity.LexicalStructureElementSimilarity;
import it.uniroma2.sag.kelp.data.representation.structure.similarity.compositional.sum.CompositionalNodeSimilaritySum;
import it.uniroma2.sag.kelp.kernel.Kernel;
import it.uniroma2.sag.kelp.kernel.cache.FixIndexKernelCache;
import it.uniroma2.sag.kelp.kernel.cache.FixIndexSquaredNormCache;
import it.uniroma2.sag.kelp.kernel.standard.NormalizationKernel;
import it.uniroma2.sag.kelp.kernel.tree.PartialTreeKernel;
import it.uniroma2.sag.kelp.kernel.tree.SmoothedPartialTreeKernel;
import it.uniroma2.sag.kelp.kernel.tree.SubSetTreeKernel;
import it.uniroma2.sag.kelp.kernel.vector.LinearKernel;
import it.uniroma2.sag.kelp.learningalgorithm.classification.libsvm.BinaryCSvmClassification;
import it.uniroma2.sag.kelp.learningalgorithm.classification.multiclassification.OneVsAllLearning;
import it.uniroma2.sag.kelp.predictionfunction.classifier.ClassificationOutput;
import it.uniroma2.sag.kelp.predictionfunction.classifier.multiclass.OneVsAllClassifier;
import it.uniroma2.sag.kelp.utils.JacksonSerializerWrapper;
import it.uniroma2.sag.kelp.utils.evaluation.MulticlassClassificationEvaluator;
import it.uniroma2.sag.kelp.wordspace.Wordspace;
import java.io.IOException;
import org.slf4j.impl.SimpleLogger;

/* loaded from: input_file:it/uniroma2/sag/kelp/examples/demo/qc/QuestionClassification.class */
/**
 * Question classification demo: trains a one-vs-all C-SVM on the coarse-grained
 * question classification training set, prints per-question predictions on the
 * test set and reports the overall accuracy.
 *
 * <p>The kernel is selected by name through {@link #getQCKernelFunction}. By default
 * {@code "csptk"} is used; another kernel name may be passed as the first
 * command-line argument.
 */
public class QuestionClassification {

    /** Kernel used when no kernel name is given on the command line. */
    private static final String DEFAULT_KERNEL = "csptk";

    private static final String TRAINING_SET_PATH = "src/main/resources/qc/train_5500.coarse.klp.gz";
    private static final String TEST_SET_PATH = "src/main/resources/qc/TREC_10.coarse.klp.gz";
    private static final String WORDSPACE_PATH = "src/main/resources/wordspace/wordspace_qc.txt.gz";
    private static final String LEARNING_ALGORITHM_PATH =
            "src/main/resources/qc/learningAlgorithmSpecificationFromJavaCode.klp";
    private static final String CLASSIFIER_PATH = "src/main/resources/qc/classificationAlgorithm.klp";

    /**
     * Runs the demo end to end: loads the training and test sets, prints label
     * statistics, builds the kernel, trains and serializes the classifier, and
     * evaluates it on the test set.
     *
     * @param strArr command-line arguments; if present, {@code strArr[0]} selects the
     *     kernel ({@code bow}, {@code stk}, {@code ptk}, {@code sptk} or {@code csptk}),
     *     otherwise {@code csptk} is used
     */
    public static void main(String[] strArr) {
        try {
            // Configure the SLF4J SimpleLogger level before any KeLP code runs.
            System.setProperty(SimpleLogger.DEFAULT_LOG_LEVEL_KEY, "INFO");

            SimpleDataset trainingSet = new SimpleDataset();
            trainingSet.populate(TRAINING_SET_PATH);
            SimpleDataset testSet = new SimpleDataset();
            testSet.populate(TEST_SET_PATH);

            System.out.println("Training set statistics");
            System.out.print("Examples number ");
            System.out.println(trainingSet.getNumberOfExamples());
            for (Label label : trainingSet.getClassificationLabels()) {
                // Per label and dataset: positive count first, then negative count.
                System.out.println("Training Label " + label.toString() + " " + trainingSet.getNumberOfPositiveExamples(label));
                System.out.println("Training Label " + label.toString() + " " + trainingSet.getNumberOfNegativeExamples(label));
                System.out.println("Test Label " + label.toString() + " " + testSet.getNumberOfPositiveExamples(label));
                System.out.println("Test Label " + label.toString() + " " + testSet.getNumberOfNegativeExamples(label));
            }

            String kernelName = (strArr != null && strArr.length > 0) ? strArr[0] : DEFAULT_KERNEL;
            Kernel kernel = getQCKernelFunction(trainingSet, testSet, kernelName);

            // One-vs-all over the training labels, each binary problem solved by a C-SVM.
            JacksonSerializerWrapper serializer = new JacksonSerializerWrapper();
            BinaryCSvmClassification svm = new BinaryCSvmClassification();
            svm.setKernel(kernel);
            svm.setCn(3.0f);
            svm.setFairness(true);
            OneVsAllLearning oneVsAllLearning = new OneVsAllLearning();
            oneVsAllLearning.setBaseAlgorithm(svm);
            oneVsAllLearning.setLabels(trainingSet.getClassificationLabels());
            // The learning algorithm specification is serialized before training.
            serializer.writeValueOnFile(oneVsAllLearning, LEARNING_ALGORITHM_PATH);

            oneVsAllLearning.learn(trainingSet);
            OneVsAllClassifier classifier = oneVsAllLearning.getPredictionFunction();
            serializer.writeValueOnFile(classifier, CLASSIFIER_PATH);

            MulticlassClassificationEvaluator evaluator =
                    new MulticlassClassificationEvaluator(trainingSet.getClassificationLabels());
            for (Example example : testSet.getExamples()) {
                // Classify the iterated example itself so the prediction and the gold
                // labels given to the evaluator always refer to the same question.
                ClassificationOutput prediction = classifier.predict(example);
                evaluator.addCount(example, prediction);
                System.out.println("Question:\t" + example.getRepresentation("quest"));
                System.out.println("Original class:\t" + example.getClassificationLabels());
                System.out.println("Predicted class:\t" + prediction.getPredictedClasses());
                System.out.println();
            }
            System.out.println("Accuracy: " + evaluator.getAccuracy());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Builds the normalized kernel identified by {@code str}.
     *
     * <p>Supported names (case-insensitive):
     * <ul>
     *   <li>{@code bow}: linear kernel on the {@code bow} representation;</li>
     *   <li>{@code stk}: subset tree kernel on {@code grct};</li>
     *   <li>{@code ptk}: partial tree kernel on {@code grct};</li>
     *   <li>{@code sptk}: smoothed partial tree kernel on {@code lct}, using a
     *       word-space based lexical similarity;</li>
     *   <li>{@code csptk}: smoothed partial tree kernel on {@code clct}, using a
     *       compositional node similarity.</li>
     * </ul>
     * For {@code sptk} and {@code csptk}, both datasets are modified in place by
     * dataset manipulators before the kernel is built. The squared-norm cache and
     * the kernel cache are sized to the total number of examples in both datasets.
     *
     * @param simpleDataset the training set
     * @param simpleDataset2 the test set
     * @param str the kernel name (must not be {@code null})
     * @return the selected kernel wrapped in a {@link NormalizationKernel} that has a
     *     {@link FixIndexKernelCache}
     * @throws IOException if loading the word space fails ({@code sptk}, {@code csptk})
     * @throws IllegalArgumentException if {@code str} is not a supported kernel name
     */
    public static Kernel getQCKernelFunction(SimpleDataset simpleDataset, SimpleDataset simpleDataset2, String str) throws IOException {
        int numberOfExamples = simpleDataset.getNumberOfExamples() + simpleDataset2.getNumberOfExamples();
        NormalizationKernel normalizationKernel;
        if (str.equalsIgnoreCase("bow")) {
            LinearKernel linearKernel = new LinearKernel("bow");
            linearKernel.setSquaredNormCache(new FixIndexSquaredNormCache(numberOfExamples));
            normalizationKernel = new NormalizationKernel(linearKernel);
        } else if (str.equalsIgnoreCase("stk")) {
            SubSetTreeKernel subSetTreeKernel = new SubSetTreeKernel(0.4f, "grct");
            subSetTreeKernel.setSquaredNormCache(new FixIndexSquaredNormCache(numberOfExamples));
            normalizationKernel = new NormalizationKernel(subSetTreeKernel);
        } else if (str.equalsIgnoreCase("ptk")) {
            PartialTreeKernel partialTreeKernel = new PartialTreeKernel(0.4f, 0.4f, 5.0f, "grct");
            partialTreeKernel.setSquaredNormCache(new FixIndexSquaredNormCache(numberOfExamples));
            normalizationKernel = new NormalizationKernel(partialTreeKernel);
        } else if (str.equalsIgnoreCase("sptk")) {
            Wordspace wordspace = loadWordspaceAndEnrich(simpleDataset, simpleDataset2, "lct");
            SmoothedPartialTreeKernel smoothedPartialTreeKernel = new SmoothedPartialTreeKernel(0.4f, 0.4f, 0.2f, 0.01f, new LexicalStructureElementSimilarity(wordspace), "lct");
            smoothedPartialTreeKernel.setSquaredNormCache(new FixIndexSquaredNormCache(numberOfExamples));
            normalizationKernel = new NormalizationKernel(smoothedPartialTreeKernel);
        } else if (str.equalsIgnoreCase("csptk")) {
            Wordspace wordspace = loadWordspaceAndEnrich(simpleDataset, simpleDataset2, "clct");
            CompositionalNodeSimilaritySum compositionalNodeSimilaritySum = new CompositionalNodeSimilaritySum();
            compositionalNodeSimilaritySum.setWordspace(wordspace);
            compositionalNodeSimilaritySum.setRepresentationToBeEnriched("clct");
            // The same object is applied to both datasets as a manipulator and then
            // used as the kernel's node similarity.
            simpleDataset.manipulate(compositionalNodeSimilaritySum);
            simpleDataset2.manipulate(compositionalNodeSimilaritySum);
            SmoothedPartialTreeKernel smoothedPartialTreeKernel = new SmoothedPartialTreeKernel(0.4f, 0.4f, 1.0f, 0.01f, compositionalNodeSimilaritySum, "clct");
            smoothedPartialTreeKernel.setSquaredNormCache(new FixIndexSquaredNormCache(numberOfExamples));
            normalizationKernel = new NormalizationKernel(smoothedPartialTreeKernel);
        } else {
            // Fail fast instead of dereferencing a null kernel below.
            throw new IllegalArgumentException("The kernel " + str + " has not been defined.");
        }
        normalizationKernel.setKernelCache(new FixIndexKernelCache(numberOfExamples));
        return normalizationKernel;
    }

    /**
     * Loads the question classification word space and applies a
     * {@link LexicalStructureElementManipulator} for {@code representation} to the
     * training set and then to the test set.
     *
     * @param trainingSet the training set, modified in place
     * @param testSet the test set, modified in place
     * @param representation name of the tree representation to manipulate
     * @return the loaded word space, for reuse by the kernel's node similarity
     * @throws IOException if loading the word space fails
     */
    private static Wordspace loadWordspaceAndEnrich(SimpleDataset trainingSet, SimpleDataset testSet, String representation) throws IOException {
        Wordspace wordspace = new Wordspace(WORDSPACE_PATH);
        Manipulator manipulator = new LexicalStructureElementManipulator(wordspace, representation);
        trainingSet.manipulate(manipulator);
        testSet.manipulate(manipulator);
        return wordspace;
    }
}
