package edu.stanford.nlp.optimization;

import edu.stanford.nlp.math.ArrayMath;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.apache.xpath.XPath;

/**
 * Base class for differentiable functions that can be evaluated stochastically, i.e. on a
 * mini-batch of data items instead of the whole data set.
 *
 * <p>Subclasses implement {@link #calculateStochastic(double[], double[], int[])}. This class
 * picks the batch indices (according to {@link #sampleMethod}), calls
 * {@code calculateStochastic}, and afterwards reads the inherited {@code value} and
 * {@code derivative} fields, so a subclass is expected to fill those in. The Hessian-vector
 * methods return {@link #HdotV}; presumably a subclass fills it in when {@link #method} is not
 * {@code ExternalFiniteDifference} (confirm against implementations).
 *
 * <p>No synchronization is used anywhere in this class; instances should not be shared across
 * threads without external locking.
 */
public abstract class AbstractStochasticCachingDiffFunction extends AbstractCachingDiffFunction {
    /** Cleared by {@link #incrementBatch}; while false, {@code stochasticEnsure} may reuse cached results. */
    public boolean hasNewVals = true;
    /** When true (and the batch size matches), {@code stochasticEnsure} re-evaluates on {@link #lastBatch}. */
    public boolean recalculatePrevBatch = false;
    /** One-shot flag: the next {@code stochasticEnsure} call returns immediately and clears it. */
    public boolean returnPreviousValues = false;
    /** Batch size used by the most recent evaluation. */
    protected int lastBatchSize = 0;
    /** Copy of the batch indices used by the most recent evaluation. */
    protected int[] lastBatch = null;
    /** Batch indices for the current evaluation, filled by {@link #getBatch(int)}. */
    protected int[] thisBatch = null;
    /** Copy of the point {@code x} used by the most recent evaluation. */
    protected double[] lastXBatch = null;
    /** Copy of the direction {@code v} used by the most recent evaluation (only copied when non-null). */
    protected double[] lastVBatch = null;
    /** Value of {@link #curElement} at the end of the most recent {@code stochasticEnsure}. */
    protected int lastElement = 0;
    /** Hessian-vector product buffer returned by the {@code HdotVAt} methods. */
    protected double[] HdotV = null;
    /** Allocated by the finite-difference helper; NOTE(review): not read anywhere in this class. */
    protected double[] gradPerturbed = null;
    /** Scratch buffer holding {@code x + h * v} for the finite-difference Hessian-vector product. */
    protected double[] xPerturbed = null;
    /** Cursor into the data for the Ordered, Shuffled and RandomWithoutReplacement sampling modes. */
    protected int curElement = 0;
    /** Permutation of {@code 0..dataDimension()-1} used by RandomWithoutReplacement. */
    protected List<Integer> allIndices = null;
    /** Source of randomness for batch sampling; fixed seed for reproducible runs. */
    protected Random randGenerator = new Random(1);
    /** If true, value and derivative are multiplied by {@code dataDimension() / batchSize}. */
    protected boolean scaleUp = false;
    /** Index order used by the Shuffled sampling mode. */
    private int[] shuffledArray = null;
    /** How the Hessian-vector product is computed. */
    public StochasticCalculateMethods method = StochasticCalculateMethods.ExternalFiniteDifference;
    /** How batch indices are drawn by {@link #getBatch(int)}. */
    public SamplingMethod sampleMethod = SamplingMethod.RandomWithoutReplacement;
    /** Step size {@code h} for the finite-difference Hessian-vector product. */
    protected double finiteDifferenceStepSize = 1.0E-4d;

    /** Strategies for choosing the data indices that make up a batch. */
    public enum SamplingMethod {
        NoneSpecified,
        RandomWithReplacement,
        RandomWithoutReplacement,
        Ordered,
        Shuffled
    }

    /**
     * Advances {@link #randGenerator} by discarding {@code numTimes} draws of
     * {@code nextInt(dataDimension())}.
     *
     * @param numTimes number of random draws to discard
     */
    public void incrementRandom(int numTimes) {
        System.err.println("incrementing random " + numTimes + " times.");
        for (int k = 0; k < numTimes; k++) {
            this.randGenerator.nextInt(dataDimension());
        }
    }

    /**
     * Enables or disables rescaling of batch results to the full data set.
     *
     * @param scaleUp if true, value and derivative are multiplied by {@code dataDimension() / batchSize}
     */
    public void scaleUp(boolean scaleUp) {
        this.scaleUp = scaleUp;
    }

    /**
     * Evaluates the function on the given batch of data items. Implementations are expected to
     * set the inherited {@code value} and {@code derivative} fields, which callers in this class
     * read afterwards.
     *
     * @param x point at which to evaluate
     * @param v direction vector; {@code null} when only value/derivative are requested
     * @param batch indices of the data items to evaluate, each in {@code [0, dataDimension())}
     */
    public abstract void calculateStochastic(double[] x, double[] v, int[] batch);

    /**
     * Returns the number of data items that batches are drawn from.
     *
     * @return data set size; batch indices lie in {@code [0, dataDimension())}
     */
    public abstract int dataDimension();

    /**
     * Invalidates the cached points by writing NaN into their first entry, so they no longer
     * compare equal (via {@link Arrays#equals(double[], double[])}) to a point whose first entry
     * is not NaN.
     */
    @Override // edu.stanford.nlp.optimization.AbstractCachingDiffFunction
    protected void clearCache() {
        if (this.lastX != null) {
            this.lastX[0] = Double.NaN;
        }
        if (this.lastXBatch != null) {
            this.lastXBatch[0] = Double.NaN;
        }
        if (this.lastVBatch != null) {
            this.lastVBatch[0] = Double.NaN;
        }
    }

    /**
     * Returns the all-zeros starting point.
     *
     * @return a new array of length {@code domainDimension()} filled with 0.0
     */
    @Override // edu.stanford.nlp.optimization.AbstractCachingDiffFunction, edu.stanford.nlp.optimization.HasInitial
    public double[] initial() {
        // Java zero-initializes new arrays, so no explicit fill is needed.
        return new double[domainDimension()];
    }

    /**
     * Moves the data cursor back by {@code batchSize}, clamping at 0.
     *
     * @param batchSize number of positions to move back
     */
    public void decrementBatch(int batchSize) {
        this.curElement -= batchSize;
        if (this.curElement < 0) {
            this.curElement = 0;
        }
    }

    /**
     * Moves the data cursor forward by {@code batchSize} (without wrapping) and clears the
     * {@link #hasNewVals}, {@link #recalculatePrevBatch} and {@link #returnPreviousValues} flags.
     *
     * @param batchSize number of positions to move forward
     */
    public void incrementBatch(int batchSize) {
        this.curElement += batchSize;
        this.hasNewVals = false;
        this.recalculatePrevBatch = false;
        this.returnPreviousValues = false;
    }

    /**
     * Fills {@link #thisBatch} with {@code batchSize} data indices chosen by
     * {@link #sampleMethod}. For the cursor-based modes, it also advances {@link #curElement}
     * modulo the data size.
     *
     * @param batchSize number of indices to draw
     * @throws IllegalStateException if no supported sampling method is selected
     */
    public void getBatch(int batchSize) {
        if (this.thisBatch == null || this.thisBatch.length != batchSize) {
            this.thisBatch = new int[batchSize];
        }
        int n = dataDimension();

        if (this.sampleMethod == SamplingMethod.Shuffled) {
            // NOTE(review): shuffledArray is built by ArrayMath.range(0, n) (presumably the identity
            // 0..n-1) and is never shuffled here, so this mode appears to behave like Ordered.
            // Rebuild when the data size changes so indexing below stays in bounds.
            if (this.shuffledArray == null || this.shuffledArray.length != n) {
                this.shuffledArray = ArrayMath.range(0, n);
            }
            for (int k = 0; k < batchSize; k++) {
                this.thisBatch[k] = this.shuffledArray[(this.curElement + k) % n];
            }
            this.curElement = (this.curElement + batchSize) % n;
            return;
        }
        if (this.sampleMethod == SamplingMethod.RandomWithReplacement) {
            for (int k = 0; k < batchSize; k++) {
                this.thisBatch[k] = this.randGenerator.nextInt(n);
            }
            return;
        }
        if (this.sampleMethod == SamplingMethod.Ordered) {
            for (int k = 0; k < batchSize; k++) {
                this.thisBatch[k] = (this.curElement + k) % n;
            }
            this.curElement = (this.curElement + batchSize) % n;
            return;
        }
        if (this.sampleMethod != SamplingMethod.RandomWithoutReplacement) {
            throw new IllegalStateException("NO SAMPLING METHOD SELECTED");
        }

        // RandomWithoutReplacement: walk a random permutation of all indices.
        if (this.allIndices == null || this.allIndices.size() != n) {
            this.allIndices = new ArrayList<>(n);
            for (int idx = 0; idx < n; idx++) {
                this.allIndices.add(idx);
            }
            Collections.shuffle(this.allIndices, this.randGenerator);
        }
        int size = this.allIndices.size();
        for (int k = 0; k < batchSize; k++) {
            this.thisBatch[k] = this.allIndices.get((this.curElement + k) % size);
        }
        // Once this batch reaches the end of the permutation, draw a fresh permutation so the next
        // epoch visits the data in a new order. The list itself must be shuffled; wrapping it in
        // another collection would shuffle only the wrapper.
        if (this.curElement + batchSize >= size) {
            Collections.shuffle(this.allIndices, this.randGenerator);
        }
        this.curElement = (this.curElement + batchSize) % size;
    }

    /**
     * Makes sure {@code value}, {@code derivative} (and whatever {@code calculateStochastic} sets)
     * reflect an evaluation at {@code x} (and {@code v}) on a batch of {@code batchSize} items.
     *
     * <p>Returns without recomputing when {@link #returnPreviousValues} is set (clearing it).
     * It also returns early when all of the following hold: {@link #hasNewVals} is false, the
     * cursor differs from {@link #lastElement}, and the batch size, {@code x}, {@code v} and the
     * batch indices all match the cached ones. Otherwise it draws a new batch, or reuses
     * {@link #lastBatch} when {@link #recalculatePrevBatch} is set and the size matches, then
     * evaluates.
     *
     * @param x point at which to evaluate
     * @param v direction vector, or {@code null}
     * @param batchSize number of data items in the batch
     */
    void stochasticEnsure(double[] x, double[] v, int batchSize) {
        if (this.lastXBatch == null) {
            this.lastXBatch = new double[domainDimension()];
            System.err.println("Setting previous position (x).");
        }
        if (this.lastVBatch == null) {
            this.lastVBatch = new double[domainDimension()];
            System.err.println("Setting previous gain (v)");
        }
        if (this.derivative == null) {
            this.derivative = new double[domainDimension()];
            System.err.println("Setting Derivative.");
        }
        if (this.HdotV == null) {
            this.HdotV = new double[domainDimension()];
            System.err.println("Setting HdotV.");
        }
        if (this.lastBatch == null) {
            this.lastBatch = new int[batchSize];
            System.err.println("Setting last batch");
        }

        if (this.recalculatePrevBatch && batchSize == this.lastBatch.length) {
            // Re-evaluate on exactly the previous batch (thisBatch now aliases lastBatch).
            this.thisBatch = this.lastBatch;
        } else {
            if (this.returnPreviousValues) {
                this.returnPreviousValues = false;
                return;
            }
            if (!this.hasNewVals && this.lastElement != this.curElement && this.lastBatchSize == batchSize
                    && Arrays.equals(x, this.lastXBatch) && Arrays.equals(v, this.lastVBatch)
                    && Arrays.equals(this.thisBatch, this.lastBatch)) {
                return;
            }
            getBatch(batchSize);
        }

        // Remember the inputs of this evaluation for the cache check above.
        copy(this.lastXBatch, x);
        if (this.lastBatch.length != batchSize) {
            this.lastBatch = new int[batchSize];
        }
        System.arraycopy(this.thisBatch, 0, this.lastBatch, 0, this.thisBatch.length);
        if (v != null) {
            copy(this.lastVBatch, v);
        }
        this.lastBatchSize = batchSize;

        calculateStochastic(x, v, this.thisBatch);

        if (this.scaleUp) {
            // Rescale the batch estimate to the full data set. Use floating-point division: integer
            // division would truncate (e.g. 10 / 4 would give 2 instead of 2.5).
            double scale = (double) dataDimension() / batchSize;
            for (int k = 0; k < this.derivative.length; k++) {
                this.derivative[k] *= scale;
            }
            this.value *= scale;
        }

        // NOTE(review): for Ordered, Shuffled and RandomWithoutReplacement, getBatch has already
        // advanced curElement by batchSize (mod dataDimension()); incrementBatch advances it by
        // batchSize again. HdotVAt(x, v) relies on this increment (it undoes it via decrementBatch).
        // Confirm the double advance is intended.
        incrementBatch(batchSize);
        this.lastElement = this.curElement;
    }

    /**
     * Returns the function value on a batch of {@code batchSize} items at {@code x}.
     *
     * @param x point at which to evaluate
     * @param batchSize number of data items in the batch
     * @return the (possibly rescaled) batch value
     */
    public double valueAt(double[] x, int batchSize) {
        stochasticEnsure(x, null, batchSize);
        return this.value;
    }

    /**
     * Returns the gradient on a batch of {@code batchSize} items at {@code x}.
     *
     * @param x point at which to evaluate
     * @param batchSize number of data items in the batch
     * @return the internal derivative array (not a copy)
     */
    public double[] derivativeAt(double[] x, int batchSize) {
        stochasticEnsure(x, null, batchSize);
        return this.derivative;
    }

    /**
     * Returns the function value on a batch at {@code x}, passing direction {@code v} to the subclass.
     *
     * @param x point at which to evaluate
     * @param v direction vector
     * @param batchSize number of data items in the batch
     * @return the (possibly rescaled) batch value
     */
    public double valueAt(double[] x, double[] v, int batchSize) {
        stochasticEnsure(x, v, batchSize);
        return this.value;
    }

    /**
     * Returns the gradient on a batch at {@code x}, passing direction {@code v} to the subclass.
     *
     * @param x point at which to evaluate
     * @param v direction vector
     * @param batchSize number of data items in the batch
     * @return the internal derivative array (not a copy)
     */
    public double[] derivativeAt(double[] x, double[] v, int batchSize) {
        stochasticEnsure(x, v, batchSize);
        return this.derivative;
    }

    /**
     * Approximates the Hessian-vector product on the current batch ({@link #thisBatch}) as
     * {@code (grad(x + h*v) - currentDerivative) / h}. It then restores {@code derivative} (from
     * {@code currentDerivative}) and {@code value} to their pre-call state.
     *
     * <p>NOTE(review): the perturbed gradient is not rescaled even when {@link #scaleUp} is set.
     * If {@code currentDerivative} was rescaled, the difference mixes scales; confirm with callers.
     *
     * @param x current point
     * @param v direction vector
     * @param currentDerivative gradient at {@code x} on the current batch
     */
    private void getHdotVFiniteDifference(double[] x, double[] v, double[] currentDerivative) {
        double h = this.finiteDifferenceStepSize;
        double invH = 1.0d / h;
        if (this.gradPerturbed == null) {
            this.gradPerturbed = new double[x.length];
            System.err.println("Setting approximate gradient.");
        }
        if (this.xPerturbed == null) {
            this.xPerturbed = new double[x.length];
            System.err.println("Setting perturbed.");
        }
        if (this.HdotV == null) {
            this.HdotV = new double[x.length];
            System.err.println("Setting H dot V.");
        }
        for (int k = 0; k < x.length; k++) {
            this.xPerturbed[k] = x[k] + (h * v[k]);
        }
        double savedValue = this.value;
        this.recalculatePrevBatch = true;
        calculateStochastic(this.xPerturbed, null, this.thisBatch);
        for (int k = 0; k < x.length; k++) {
            this.HdotV[k] = invH * (this.derivative[k] - currentDerivative[k]);
        }
        // Undo the side effects of the perturbed evaluation on the cached results.
        System.arraycopy(currentDerivative, 0, this.derivative, 0, this.derivative.length);
        this.value = savedValue;
        this.hasNewVals = false;
        this.recalculatePrevBatch = false;
        this.returnPreviousValues = false;
    }

    /**
     * Returns the Hessian-vector product on a batch of {@code batchSize} items.
     *
     * @param x current point
     * @param v direction vector
     * @param batchSize number of data items in the batch
     * @return the internal HdotV array (not a copy)
     * @throws IllegalStateException if {@link #method} is {@code ExternalFiniteDifference}, which
     *     needs the current derivative
     */
    public double[] HdotVAt(double[] x, double[] v, int batchSize) {
        if (this.method == StochasticCalculateMethods.ExternalFiniteDifference) {
            System.err.println("Attempt to use ExternalFiniteDifference without passing currentDerivative");
            throw new IllegalStateException("Attempt to use ExternalFiniteDifference without passing currentDerivative");
        }
        stochasticEnsure(x, v, batchSize);
        return this.HdotV;
    }

    /**
     * Returns the Hessian-vector product. With {@code ExternalFiniteDifference} it is
     * approximated on the current batch from {@code currentDerivative}, and {@code batchSize} is
     * ignored. Otherwise it is computed on a batch of {@code batchSize} items.
     *
     * @param x current point
     * @param v direction vector
     * @param currentDerivative gradient at {@code x} on the current batch
     * @param batchSize number of data items in the batch
     * @return the internal HdotV array (not a copy)
     */
    public double[] HdotVAt(double[] x, double[] v, double[] currentDerivative, int batchSize) {
        if (this.method == StochasticCalculateMethods.ExternalFiniteDifference) {
            getHdotVFiniteDifference(x, v, currentDerivative);
        } else {
            stochasticEnsure(x, v, batchSize);
        }
        return this.HdotV;
    }

    /**
     * Returns the Hessian-vector product using a batch of size {@code dataDimension()}. It then
     * moves the cursor back by that amount, undoing the advance made by {@link #incrementBatch}.
     *
     * @param x current point
     * @param v direction vector
     * @return the internal HdotV array (not a copy)
     * @throws IllegalStateException if {@link #method} is {@code ExternalFiniteDifference}
     */
    public double[] HdotVAt(double[] x, double[] v) {
        if (this.method == StochasticCalculateMethods.ExternalFiniteDifference) {
            System.err.println("Attempt to use ExternalFiniteDifference without passing currentDerivative");
            throw new IllegalStateException("Attempt to use ExternalFiniteDifference without passing currentDerivative");
        }
        stochasticEnsure(x, v, dataDimension());
        decrementBatch(dataDimension());
        return this.HdotV;
    }

    /**
     * Returns the derivative from the most recent evaluation.
     *
     * @return the internal derivative array (not a copy)
     */
    public double[] lastDerivative() {
        return this.derivative;
    }

    /**
     * Returns the value from the most recent evaluation.
     *
     * @return the cached value
     */
    @Override // edu.stanford.nlp.optimization.AbstractCachingDiffFunction
    public double lastValue() {
        return this.value;
    }
}
