package it.uniroma2.sag.kelp.utils;

import it.uniroma2.sag.kelp.data.dataset.Dataset;
import it.uniroma2.sag.kelp.data.dataset.SimpleDataset;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.learningalgorithm.LearningAlgorithm;
import it.uniroma2.sag.kelp.predictionfunction.Prediction;
import it.uniroma2.sag.kelp.predictionfunction.PredictionFunction;
import it.uniroma2.sag.kelp.utils.evaluation.Evaluator;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:it/uniroma2/sag/kelp/utils/ExperimentUtils.class */
public class ExperimentUtils {
    private static final Logger logger = LoggerFactory.getLogger(ExperimentUtils.class);

    /**
     * Applies a prediction function to every example of a dataset.
     * Each (example, prediction) pair is also recorded in the given evaluator.
     *
     * @param predictionFunction the function used to predict each example
     * @param evaluator          receives one {@code addCount(example, prediction)} call per example
     * @param dataset            the examples to predict
     * @return the predictions, in the same order in which {@code dataset.getExamples()} yields the examples
     */
    public static List<Prediction> test(PredictionFunction predictionFunction, Evaluator evaluator, Dataset dataset) {
        List<Prediction> predictions = new ArrayList<Prediction>();
        for (Example example : dataset.getExamples()) {
            Prediction prediction = predictionFunction.predict(example);
            evaluator.addCount(example, prediction);
            predictions.add(prediction);
        }
        return predictions;
    }

    /**
     * Runs n-fold cross validation of a learning algorithm on a dataset.
     * <p>
     * The dataset is split into {@code i} folds via
     * {@link SimpleDataset#nFoldingClassDistributionInvariant(int)} (presumably a split that preserves
     * the class distribution, per its name; confirm against that method). For each fold {@code k}:
     * <ol>
     *   <li>the algorithm is trained on the union of all the other folds;</li>
     *   <li>its prediction function is tested on fold {@code k}, using evaluator {@code k};</li>
     *   <li>the algorithm is reset.</li>
     * </ol>
     * <p>
     * Note: the evaluator {@code t} passed in is itself used for fold 0, so it is mutated.
     * Folds 1..i-1 use copies obtained from {@code t.duplicate()}, created before any testing starts.
     *
     * @param i                 number of folds
     * @param learningAlgorithm the algorithm to train; {@code reset()} is called after each fold
     * @param simpleDataset     the dataset to split into folds
     * @param t                 the evaluator for fold 0, and the prototype duplicated for the other folds
     * @param <T>               the concrete evaluator type
     * @return one evaluator per fold, where index {@code k} holds the results of testing on fold {@code k}
     */
    public static <T extends Evaluator> List<T> nFoldCrossValidation(int i, LearningAlgorithm learningAlgorithm, SimpleDataset simpleDataset, T t) {
        SimpleDataset[] folds = simpleDataset.nFoldingClassDistributionInvariant(i);

        List<T> evaluators = new ArrayList<T>();
        evaluators.add(t);
        for (int copyIndex = 1; copyIndex < i; copyIndex++) {
            // Unchecked: duplicate() is declared on Evaluator. NOTE(review): this assumes it returns an
            // instance of t's runtime type. The cast is erased, exactly like the original raw-list code.
            @SuppressWarnings("unchecked")
            T copy = (T) t.duplicate();
            evaluators.add(copy);
        }

        for (int fold = 0; fold < i; fold++) {
            SimpleDataset testSet = folds[fold];
            SimpleDataset trainingSet = getAllExcept(folds, fold);
            logger.info("start testing on fold={}", fold);
            learningAlgorithm.learn(trainingSet);
            test(learningAlgorithm.getPredictionFunction(), evaluators.get(fold), testSet);
            // Clear learned state so the next fold trains from scratch.
            learningAlgorithm.reset();
        }
        return evaluators;
    }

    /**
     * Builds a new dataset with the examples of every dataset in the array except one.
     *
     * @param datasetArr the datasets to merge
     * @param i          index of the dataset to leave out
     * @return a fresh {@link SimpleDataset} containing the examples of all datasets except {@code datasetArr[i]}
     */
    private static SimpleDataset getAllExcept(Dataset[] datasetArr, int i) {
        SimpleDataset merged = new SimpleDataset();
        for (int index = 0; index < datasetArr.length; index++) {
            if (index != i) {
                merged.addExamples(datasetArr[index]);
            }
        }
        return merged;
    }
}
