package cc.mallet.fst.semi_supervised.pr.constraints;

import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntObjectHashMap;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;

/**
 * Posterior-regularization constraints on the label distribution of individual input features,
 * with an independent L2 penalty for every constrained (feature, label) pair.
 *
 * <p>Every call to {@link #addConstraint(int, int, double, double)} registers one (feature,
 * label) pair and assigns it the next position of the parameter / expectation vectors that the
 * {@link PRConstraint} methods read and write, so {@link #numDimensions()} equals the number of
 * registered pairs. When {@code normalized} is true, expectations and scores for a feature are
 * divided by the number of times that feature occurred in the data passed to
 * {@link #preProcess(InstanceList)}.
 */
public class OneLabelL2IndPRConstraints implements PRConstraint {
    /** If true, expectations and scores are divided by the feature's occurrence count. */
    protected boolean normalized;
    /** Number of registered (feature, label) pairs, i.e. the parameter-vector length. */
    protected int numDimensions;
    /** Input feature index to the constraint state for that feature. */
    protected TIntObjectHashMap<OneLabelL2IndPRConstraint> constraints;
    /** Used to translate the destination index of a transition into a label index. */
    protected StateLabelMap map;
    /** Constrained feature indices found by the last {@link #preProcess(FeatureVector)} call. */
    protected TIntArrayList cache;

    /**
     * Constraint state for a single input feature: the labels constrained for it, their
     * targets and weights, where each one sits in the global parameter vector, and the
     * expectations accumulated for them.
     */
    protected class OneLabelL2IndPRConstraint {
        /** Number of labels constrained so far; also the next free local slot. */
        protected int index = 0;
        /** Occurrences of this feature counted by {@code preProcess(InstanceList)}. */
        protected double count = 0.0d;
        /** Constrained label index for each local slot. */
        protected ArrayList<Integer> labels = new ArrayList<>();
        /** Global parameter-vector position for each local slot. */
        protected ArrayList<Integer> paramIndices = new ArrayList<>();
        /** Target value for each local slot. */
        protected ArrayList<Double> targets = new ArrayList<>();
        /** L2 weight for each local slot. */
        protected ArrayList<Double> weights = new ArrayList<>();
        /** Label index to local slot. */
        protected HashMap<Integer, Integer> labelMap = new HashMap<>();
        /** Accumulated, unnormalized expectation per local slot (allocated by zeroExpectation/copy). */
        protected double[] expectation;

        public OneLabelL2IndPRConstraint() {
        }

        /**
         * Returns a copy with the same constraint definition and count and a fresh all-zero
         * expectation array.
         *
         * <p>The definition collections are copied rather than shared: with sharing, adding a
         * constraint through one instance would silently extend the other's lists and push its
         * parameter indices past that instance's {@code numDimensions}.
         */
        public OneLabelL2IndPRConstraint copy() {
            OneLabelL2IndPRConstraint result = new OneLabelL2IndPRConstraint();
            result.index = this.index;
            result.count = this.count;
            result.labels = new ArrayList<>(this.labels);
            result.paramIndices = new ArrayList<>(this.paramIndices);
            result.targets = new ArrayList<>(this.targets);
            result.weights = new ArrayList<>(this.weights);
            result.labelMap = new HashMap<>(this.labelMap);
            result.expectation = new double[this.index];
            return result;
        }

        /**
         * Constrains {@code label} for this feature.
         *
         * @param label label index to constrain
         * @param target target value compared against the (possibly normalized) expectation
         * @param weight L2 weight of the penalty for this pair
         * @param paramIndex position of this pair in the global parameter vector
         */
        public void add(int label, double target, double weight, int paramIndex) {
            this.targets.add(target);
            this.weights.add(weight);
            this.labels.add(label);
            this.paramIndices.add(paramIndex);
            this.labelMap.put(label, this.index);
            this.index++;
        }

        /** Allocates a fresh all-zero expectation array, one slot per constrained label. */
        public void zeroExpectation() {
            this.expectation = new double[this.labels.size()];
        }

        /** Writes this feature's raw expectations into {@code out} at their global positions. */
        public void getExpectations(double[] out) {
            for (int i = 0; i < this.paramIndices.size(); i++) {
                out[this.paramIndices.get(i)] = this.expectation[i];
            }
        }

        /** Adds the values at this feature's global positions in {@code values} to its expectations. */
        public void addExpectations(double[] values) {
            for (int i = 0; i < this.paramIndices.size(); i++) {
                this.expectation[i] += values[this.paramIndices.get(i)];
            }
        }

        /** Adds {@code amount} to the expectation of {@code label}; ignored if it is not constrained. */
        public void incrementExpectation(int label, double amount) {
            Integer slot = this.labelMap.get(label);
            if (slot != null) {
                this.expectation[slot] += amount;
            }
        }

        /**
         * Returns the parameter for {@code label} (divided by {@link #count} when normalized),
         * or 0 if {@code label} is not constrained for this feature.
         */
        public double getScore(int label, double[] parameters) {
            Integer slot = this.labelMap.get(label);
            if (slot == null) {
                return 0.0d;
            }
            double parameter = parameters[this.paramIndices.get(slot)];
            return OneLabelL2IndPRConstraints.this.normalized ? parameter / this.count : parameter;
        }

        /** Returns the sum over slots of {@code target * p - p^2 / (2 * weight)}. */
        public double getProjectionValueContrib(double[] parameters) {
            double value = 0.0d;
            for (int i = 0; i < this.paramIndices.size(); i++) {
                double parameter = parameters[this.paramIndices.get(i)];
                value += (this.targets.get(i) * parameter)
                        - ((parameter * parameter) / (2.0d * this.weights.get(i)));
            }
            return value;
        }

        /** Returns the sum over slots of {@code weight * (target - expectation)^2 / 2}. */
        public double getCompleteValueContrib() {
            double value = 0.0d;
            for (int i = 0; i < this.paramIndices.size(); i++) {
                double squaredResidual = Math.pow(this.targets.get(i) - effectiveExpectation(i), 2.0d);
                value += (this.weights.get(i) * squaredResidual) / 2.0d;
            }
            return value;
        }

        /** Adds {@code (target - expectation) - p / weight} to {@code gradient} for every slot. */
        public void getGradient(double[] parameters, double[] gradient) {
            for (int i = 0; i < this.paramIndices.size(); i++) {
                int position = this.paramIndices.get(i);
                gradient[position] += (this.targets.get(i) - effectiveExpectation(i))
                        - (parameters[position] / this.weights.get(i));
            }
        }

        /** Returns the number of labels constrained for this feature. */
        public int getNumConstrainedLabels() {
            return this.index;
        }

        /**
         * Returns the expectation of local slot {@code slot}, divided by {@link #count} when
         * the enclosing constraints are normalized.
         *
         * <p>A feature that never occurred in the preprocessed data has {@code count == 0}.
         * Normalizing would compute 0/0 = NaN and poison the whole objective and gradient, so
         * such an expectation is reported as 0 instead.
         */
        private double effectiveExpectation(int slot) {
            double raw = this.expectation[slot];
            if (!OneLabelL2IndPRConstraints.this.normalized) {
                return raw;
            }
            return this.count > 0.0d ? raw / this.count : 0.0d;
        }
    }

    /**
     * Creates an empty constraint set.
     *
     * @param normalized whether expectations/scores are divided by feature occurrence counts
     */
    public OneLabelL2IndPRConstraints(boolean normalized) {
        this.normalized = normalized;
        this.numDimensions = 0;
        this.constraints = new TIntObjectHashMap<>();
        this.map = null;
        this.cache = new TIntArrayList();
    }

    /**
     * Copy constructor used by {@link #copy()}: copies every per-feature constraint and
     * recomputes the dimension count from them. The state-label map is shared.
     */
    protected OneLabelL2IndPRConstraints(TIntObjectHashMap<OneLabelL2IndPRConstraint> source,
            StateLabelMap stateLabelMap, boolean normalized) {
        this.normalized = normalized;
        this.numDimensions = 0;
        this.constraints = new TIntObjectHashMap<>();
        for (int feature : source.keys()) {
            OneLabelL2IndPRConstraint original = source.get(feature);
            this.constraints.put(feature, original.copy());
            this.numDimensions += original.getNumConstrainedLabels();
        }
        this.map = stateLabelMap;
        this.cache = new TIntArrayList();
    }

    @Override
    public PRConstraint copy() {
        return new OneLabelL2IndPRConstraints(this.constraints, this.map, this.normalized);
    }

    /**
     * Registers a constraint on {@code label} for input feature {@code feature}, assigning it
     * the next dimension of the parameter vector.
     *
     * @param feature input feature index
     * @param label label index to constrain
     * @param target target value for the (possibly normalized) expectation
     * @param weight L2 weight of the penalty
     */
    public void addConstraint(int feature, int label, double target, double weight) {
        OneLabelL2IndPRConstraint constraint = this.constraints.get(feature);
        if (constraint == null) {
            constraint = new OneLabelL2IndPRConstraint();
            this.constraints.put(feature, constraint);
        }
        constraint.add(label, target, weight, this.numDimensions);
        this.numDimensions++;
    }

    @Override
    public int numDimensions() {
        return this.numDimensions;
    }

    @Override
    public void setStateLabelMap(StateLabelMap stateLabelMap) {
        this.map = stateLabelMap;
    }

    /** Caches the constrained feature indices present in {@code input} for later score/expectation calls. */
    @Override
    public void preProcess(FeatureVector input) {
        this.cache.resetQuick();
        for (int loc = 0; loc < input.numLocations(); loc++) {
            int feature = input.indexAtLocation(loc);
            if (this.constraints.containsKey(feature)) {
                this.cache.add(feature);
            }
        }
    }

    /**
     * Counts occurrences of every constrained feature in {@code data}, accumulating into each
     * constraint's {@code count}, which is not reset first.
     *
     * @return the set of instance positions containing at least one constrained feature
     */
    @Override
    public BitSet preProcess(InstanceList data) {
        BitSet constrainedInstances = new BitSet(data.size());
        int instanceIndex = 0;
        for (Instance instance : data) {
            FeatureVectorSequence sequence = (FeatureVectorSequence) instance.getData();
            for (int position = 0; position < sequence.size(); position++) {
                FeatureVector fv = sequence.get(position);
                for (int loc = 0; loc < fv.numLocations(); loc++) {
                    OneLabelL2IndPRConstraint constraint = this.constraints.get(fv.indexAtLocation(loc));
                    if (constraint != null) {
                        constraint.count += 1.0d;
                        constrainedInstances.set(instanceIndex);
                    }
                }
            }
            instanceIndex++;
        }
        return constrainedInstances;
    }

    /**
     * Sums, over the features cached by {@link #preProcess(FeatureVector)}, the score of the
     * label that {@code destIndex} maps to via {@link #map}.
     */
    @Override
    public double getScore(FeatureVector input, int inputPosition, int srcIndex, int destIndex,
            double[] parameters) {
        int label = this.map.getLabelIndex(destIndex);
        double score = 0.0d;
        for (int k = 0; k < this.cache.size(); k++) {
            score += this.constraints.get(this.cache.getQuick(k)).getScore(label, parameters);
        }
        return score;
    }

    /**
     * Adds {@code prob} to the expectation of the label that {@code destIndex} maps to, for
     * every feature cached by {@link #preProcess(FeatureVector)}.
     */
    @Override
    public void incrementExpectations(FeatureVector input, int inputPosition, int srcIndex,
            int destIndex, double prob) {
        int label = this.map.getLabelIndex(destIndex);
        for (int k = 0; k < this.cache.size(); k++) {
            this.constraints.get(this.cache.getQuick(k)).incrementExpectation(label, prob);
        }
    }

    /** Writes all raw expectations into {@code out}, whose length must equal {@link #numDimensions()}. */
    @Override
    public void getExpectations(double[] out) {
        assert out.length == numDimensions() : out.length + " " + numDimensions();
        for (int feature : this.constraints.keys()) {
            this.constraints.get(feature).getExpectations(out);
        }
    }

    /** Adds {@code values} (length {@link #numDimensions()}) into the stored expectations. */
    @Override
    public void addExpectations(double[] values) {
        assert values.length == numDimensions();
        for (int feature : this.constraints.keys()) {
            this.constraints.get(feature).addExpectations(values);
        }
    }

    /** Resets every feature's expectations to freshly allocated zero arrays. */
    @Override
    public void zeroExpectations() {
        for (int feature : this.constraints.keys()) {
            this.constraints.get(feature).zeroExpectation();
        }
    }

    /** Returns the sum of every feature's {@code target * p - p^2 / (2 * weight)} terms. */
    @Override
    public double getAuxiliaryValueContribution(double[] parameters) {
        double value = 0.0d;
        for (int feature : this.constraints.keys()) {
            value += this.constraints.get(feature).getProjectionValueContrib(parameters);
        }
        return value;
    }

    /**
     * Returns the total L2 penalty {@code weight * (target - expectation)^2 / 2} over all
     * constraints, computed from the stored expectations; {@code parameters} is not read.
     */
    @Override
    public double getCompleteValueContribution(double[] parameters) {
        double value = 0.0d;
        for (int feature : this.constraints.keys()) {
            value += this.constraints.get(feature).getCompleteValueContrib();
        }
        return value;
    }

    /** Adds every constraint's gradient contribution into {@code gradient}. */
    @Override
    public void getGradient(double[] parameters, double[] gradient) {
        for (int feature : this.constraints.keys()) {
            this.constraints.get(feature).getGradient(parameters, gradient);
        }
    }
}
