package it.uniroma2.sag.kelp.kernel;

import com.fasterxml.jackson.annotation.JsonIdentityInfo;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.ObjectIdGenerators;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.annotation.JsonTypeIdResolver;
import it.uniroma2.sag.kelp.data.example.Example;
import it.uniroma2.sag.kelp.kernel.cache.KernelCache;
import it.uniroma2.sag.kelp.kernel.cache.SquaredNormCache;
import it.uniroma2.sag.kelp.utils.FileUtils;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;

/**
 * Abstract base class for kernel functions.
 *
 * <p>Concrete kernels implement {@link #kernelComputation(Example, Example)}. This class wraps
 * that computation with two optional caches: a {@link KernelCache} for pairwise kernel values
 * and a {@link SquaredNormCache} for self-similarities. It also keeps counters of requested
 * kernel computations and of cache hits.
 *
 * <p>Kernels are serialized polymorphically with Jackson. The concrete type is written to the
 * {@code kernelType} property and resolved by {@link KernelTypeResolver}. Object identity is
 * tracked through the {@code kernelID} property.
 *
 * <p>NOTE(review): the counters are plain {@code long} fields updated with {@code ++} and no
 * synchronization. Do not share an instance across threads without external locking.
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.CUSTOM, include = JsonTypeInfo.As.PROPERTY, property = "kernelType")
@JsonTypeIdResolver(KernelTypeResolver.class)
@JsonIdentityInfo(generator = ObjectIdGenerators.IntSequenceGenerator.class, property = "kernelID")
public abstract class Kernel {

    /** Number of kernel evaluations requested (including those served by a cache). */
    @JsonIgnore
    private long numberOfKernelComputations = 0;

    /** Number of evaluations that were answered from a cache. */
    @JsonIgnore
    private long numberOfHits = 0;

    /** Optional cache of squared norms; {@code null} when disabled. */
    @JsonIgnore
    private SquaredNormCache normCache = null;

    /** Optional cache of pairwise kernel values; {@code null} when disabled. */
    @JsonIgnore
    private KernelCache cache = null;

    /**
     * Evaluates the kernel function between two examples, using the configured caches.
     *
     * <p>If both examples have the same id and a squared norm cache is set, the call is
     * delegated to {@link #squaredNorm(Example)}. Otherwise the pair is ordered by ascending id
     * so that {@code K(a,b)} and {@code K(b,a)} share a single kernel-cache entry. The value is
     * then read from the kernel cache when present, or computed (and stored if a cache is set).
     *
     * @param example the first example
     * @param example2 the second example
     * @return the kernel value between the two examples
     */
    public final float innerProduct(Example example, Example example2) {
        if (example.getId() == example2.getId() && this.normCache != null) {
            return squaredNorm(example);
        }
        this.numberOfKernelComputations++;

        // Canonical ordering: the example with the smaller id goes first, so symmetric pairs
        // map to the same cache key.
        Example first = example;
        Example second = example2;
        if (example.getId() > example2.getId()) {
            first = example2;
            second = example;
        }

        if (this.cache == null) {
            return kernelComputation(first, second);
        }
        Float cached = this.cache.getKernelValue(first, second);
        if (cached != null) {
            this.numberOfHits++;
            return cached.floatValue();
        }
        float value = kernelComputation(first, second);
        this.cache.setKernelValue(first, second, value);
        return value;
    }

    /**
     * Sets the cache used for squared norms.
     *
     * @param squaredNormCache the cache to use, or {@code null} to disable it
     */
    public void setSquaredNormCache(SquaredNormCache squaredNormCache) {
        this.normCache = squaredNormCache;
    }

    /**
     * Returns the cache used for squared norms.
     *
     * @return the squared norm cache, or {@code null} if none is set
     */
    public SquaredNormCache getSquaredNormCache() {
        return this.normCache;
    }

    /**
     * Sets the cache used for pairwise kernel values.
     *
     * @param kernelCache the cache to use, or {@code null} to disable it
     */
    public void setKernelCache(KernelCache kernelCache) {
        this.cache = kernelCache;
    }

    /**
     * Returns the cache used for pairwise kernel values.
     *
     * @return the kernel cache, or {@code null} if none is set
     */
    public KernelCache getKernelCache() {
        return this.cache;
    }

    /**
     * Computes the raw kernel value between two examples, without any caching.
     *
     * <p>When called through {@link #innerProduct(Example, Example)}, the examples arrive ordered
     * by ascending id.
     *
     * @param example the first example
     * @param example2 the second example
     * @return the kernel value
     */
    protected abstract float kernelComputation(Example example, Example example2);

    /**
     * Returns the kernel of an example with itself, i.e. the squared norm of its implicit
     * feature-space representation.
     *
     * <p>Without a squared norm cache, the call is delegated to
     * {@code innerProduct(example, example)}, which may still use the kernel cache. With a
     * squared norm cache, the value is read from it when present, or computed and stored
     * otherwise.
     *
     * @param example the example
     * @return the squared norm of {@code example}
     */
    public float squaredNorm(Example example) {
        if (this.normCache == null) {
            return innerProduct(example, example);
        }
        this.numberOfKernelComputations++;
        Float cached = this.normCache.getSquaredNorm(example);
        if (cached != null) {
            this.numberOfHits++;
            return cached.floatValue();
        }
        float value = kernelComputation(example, example);
        this.normCache.setSquaredNormValue(example, value);
        return value;
    }

    /**
     * Returns the squared distance between two examples in the kernel-induced space.
     *
     * <p>The value is {@code K(a,a) + K(b,b) - 2 K(a,b)}. Because of float rounding, the result
     * can be slightly negative for (near-)identical examples.
     *
     * @param example the first example
     * @param example2 the second example
     * @return the squared norm of the difference
     */
    public float squaredNormOfTheDifference(Example example, Example example2) {
        return (squaredNorm(example) + squaredNorm(example2)) - (2.0f * innerProduct(example, example2));
    }

    /** Removes both the kernel cache and the squared norm cache. */
    public void disableCache() {
        this.cache = null;
        this.normCache = null;
    }

    /**
     * Returns the number of kernel evaluations requested since the last {@link #reset()},
     * including those answered by a cache.
     *
     * @return the number of kernel computations
     */
    @JsonIgnore
    public long getKernelComputations() {
        return this.numberOfKernelComputations;
    }

    /**
     * Returns the number of kernel evaluations answered by a cache since the last
     * {@link #reset()}.
     *
     * @return the number of cache hits
     */
    @JsonIgnore
    public long getNumberOfHits() {
        return this.numberOfHits;
    }

    /**
     * Returns the number of kernel evaluations not answered by a cache since the last
     * {@link #reset()}.
     *
     * @return computations minus hits
     */
    @JsonIgnore
    public long getNumberOfMisses() {
        return this.numberOfKernelComputations - this.numberOfHits;
    }

    /** Resets the computation and hit counters. The caches' contents are left untouched. */
    public void reset() {
        this.numberOfHits = 0L;
        this.numberOfKernelComputations = 0L;
    }

    /**
     * Serializes a kernel as UTF-8 JSON to the given path.
     *
     * <p>The stream is opened with {@code FileUtils.createOutputStream}. It is closed even if
     * serialization fails.
     *
     * @param kernel the kernel to save
     * @param str the output file path
     * @throws FileNotFoundException if the file cannot be opened for writing
     * @throws IOException if writing or serialization fails
     */
    public static void save(Kernel kernel, String str) throws FileNotFoundException, IOException {
        ObjectMapper objectMapper = new ObjectMapper();
        // try-with-resources guarantees the underlying stream is released on any failure;
        // closing twice is harmless if Jackson already auto-closed the writer.
        try (OutputStreamWriter writer =
                new OutputStreamWriter(FileUtils.createOutputStream(str), StandardCharsets.UTF_8)) {
            objectMapper.writeValue(writer, kernel);
        }
    }

    /**
     * Deserializes a kernel from the UTF-8 JSON file at the given path.
     *
     * <p>The stream is opened with {@code FileUtils.createInputStream}. It is closed even if
     * deserialization fails.
     *
     * @param str the input file path
     * @return the deserialized kernel
     * @throws FileNotFoundException if the file cannot be opened
     * @throws IOException if reading or deserialization fails
     */
    public static Kernel load(String str) throws FileNotFoundException, IOException {
        ObjectMapper objectMapper = new ObjectMapper();
        try (InputStreamReader reader =
                new InputStreamReader(FileUtils.createInputStream(str), StandardCharsets.UTF_8)) {
            return objectMapper.readValue(reader, Kernel.class);
        }
    }
}
