package edu.stanford.nlp.stats;

import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.MapFactory;
import edu.stanford.nlp.util.MutableDouble;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.xpath.XPath;

/* loaded from: input_file:edu/stanford/nlp/stats/TwoDimensionalCounter.class */
/**
 * A two-dimensional counter: maps each first key {@code K1} to a {@link ClassicCounter} over
 * second keys {@code K2}, and keeps a running {@link #totalCount()} of the counts added, set,
 * or removed through this class.
 *
 * <p>Inner counters are created on demand by {@link #getCounter(Object)} (and therefore by the
 * mutators that call it). Read-only lookups such as {@link #getCount(Object, Object)} and
 * {@link #totalCount(Object)} do not create inner counters.
 *
 * <p>This class performs no synchronization.
 *
 * @param <K1> type of the first (outer) key
 * @param <K2> type of the second (inner) key
 */
public class TwoDimensionalCounter<K1, K2> implements TwoDimensionalCounterInterface<K1, K2>, Serializable {

    private static final long serialVersionUID = 1L;

    /** Outer map from first key to the inner counter for that key. */
    private final Map<K1, ClassicCounter<K2>> map;

    /** Running sum of all counts; maintained incrementally (see {@link #recomputeTotal()}). */
    private double total;

    /** Factory used to build the outer map. */
    private final MapFactory<K1, ClassicCounter<K2>> outerMF;

    /** Factory handed to each inner {@link ClassicCounter} created by {@link #getCounter(Object)}. */
    private final MapFactory<K2, MutableDouble> innerMF;

    /** Value returned for absent pairs; copied into inner counters at creation time. */
    private double defaultValue;

    /** Creates an empty counter backed by {@code MapFactory.hashMapFactory()} at both levels. */
    public TwoDimensionalCounter() {
        this(MapFactory.<K1, ClassicCounter<K2>>hashMapFactory(), MapFactory.<K2, MutableDouble>hashMapFactory());
    }

    /**
     * Creates an empty counter using the given map factories.
     *
     * @param outerFactory factory for the map from first key to inner counter
     * @param innerFactory factory passed to every inner counter this object creates
     */
    public TwoDimensionalCounter(MapFactory<K1, ClassicCounter<K2>> outerFactory,
                                 MapFactory<K2, MutableDouble> innerFactory) {
        this.outerMF = outerFactory;
        this.innerMF = innerFactory;
        this.map = outerFactory.newMap();
        this.total = 0.0;
        this.defaultValue = 0.0;
    }

    /** Returns an empty counter backed by {@code MapFactory.identityHashMapFactory()} at both levels. */
    public static <K1, K2> TwoDimensionalCounter<K1, K2> identityHashMapCounter() {
        return new TwoDimensionalCounter<K1, K2>(
                MapFactory.<K1, ClassicCounter<K2>>identityHashMapFactory(),
                MapFactory.<K2, MutableDouble>identityHashMapFactory());
    }

    /**
     * Sets the value returned by {@link #getCount(Object, Object)} for absent pairs.
     * Inner counters that already exist keep the default they were created with.
     */
    @Override
    public void defaultReturnValue(double d) {
        this.defaultValue = d;
    }

    /** Returns the value used for absent pairs. */
    @Override
    public double defaultReturnValue() {
        return defaultValue;
    }

    /** Two counters are equal when their outer maps are equal; total and default are not compared. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof TwoDimensionalCounter)) {
            return false;
        }
        return ((TwoDimensionalCounter<?, ?>) obj).map.equals(map);
    }

    @Override
    public int hashCode() {
        return map.hashCode() + 17;
    }

    /**
     * Returns the inner counter for {@code k1}, creating, registering and returning an empty one
     * (with the current default return value) if none exists yet.
     */
    @Override
    public ClassicCounter<K2> getCounter(K1 k1) {
        ClassicCounter<K2> inner = map.get(k1);
        if (inner == null) {
            inner = new ClassicCounter<>(innerMF);
            inner.setDefaultReturnValue(defaultValue);
            map.put(k1, inner);
        }
        return inner;
    }

    /** Returns a live view of the outer map's entries. */
    public Set<Map.Entry<K1, ClassicCounter<K2>>> entrySet() {
        return map.entrySet();
    }

    /** Returns the number of (first key, second key) entries across all inner counters. */
    @Override
    public int size() {
        int n = 0;
        for (ClassicCounter<K2> inner : map.values()) {
            n += inner.size();
        }
        return n;
    }

    /** Returns the number of first keys (including ones whose inner counter is empty). */
    public int sizeOuterMap() {
        return map.size();
    }

    @Override
    public boolean containsKey(K1 k1, K2 k2) {
        ClassicCounter<K2> inner = map.get(k1);
        return inner != null && inner.containsKey(k2);
    }

    public boolean containsFirstKey(K1 k1) {
        return map.containsKey(k1);
    }

    @Override
    public void incrementCount(K1 k1, K2 k2) {
        incrementCount(k1, k2, 1.0);
    }

    @Override
    public void incrementCount(K1 k1, K2 k2, double d) {
        getCounter(k1).incrementCount(k2, d);
        total += d;
    }

    @Override
    public void decrementCount(K1 k1, K2 k2) {
        incrementCount(k1, k2, -1.0);
    }

    @Override
    public void decrementCount(K1 k1, K2 k2, double d) {
        incrementCount(k1, k2, -d);
    }

    /** Sets the count of ({@code k1}, {@code k2}) to {@code d}, keeping the total consistent. */
    @Override
    public void setCount(K1 k1, K2 k2, double d) {
        ClassicCounter<K2> inner = getCounter(k1);
        // Only a stored count contributes to the total; an absent key's default value does not.
        if (inner.containsKey(k2)) {
            total -= inner.getCount(k2);
        }
        inner.setCount(k2, d);
        total += d;
    }

    /**
     * Removes ({@code k1}, {@code k2}); drops {@code k1} entirely if its inner counter becomes empty.
     *
     * @return the value {@link #getCount(Object, Object)} reported before removal
     */
    @Override
    public double remove(K1 k1, K2 k2) {
        ClassicCounter<K2> inner = map.get(k1);
        if (inner == null) {
            return defaultReturnValue();
        }
        double count = getCount(k1, k2);
        if (inner.containsKey(k2)) {
            // Subtract only what was actually stored, never a default value.
            total -= count;
        }
        inner.remove(k2);
        if (inner.size() == 0) {
            map.remove(k1);
        }
        return count;
    }

    /**
     * Returns the count of ({@code k1}, {@code k2}) without creating an inner counter.
     * If {@code k1} is absent, or its inner counter has a zero total and lacks {@code k2},
     * returns {@link #defaultReturnValue()}; otherwise delegates to the inner counter.
     */
    @Override
    public double getCount(K1 k1, K2 k2) {
        ClassicCounter<K2> inner = map.get(k1);
        if (inner == null) {
            return defaultReturnValue();
        }
        if (inner.totalCount() == 0.0 && !inner.containsKey(k2)) {
            return defaultReturnValue();
        }
        return inner.getCount(k2);
    }

    @Override
    public double totalCount() {
        return total;
    }

    /** Returns the total of {@code k1}'s inner counter, or 0.0 if absent (no counter is created). */
    @Override
    public double totalCount(K1 k1) {
        ClassicCounter<K2> inner = map.get(k1);
        return inner == null ? 0.0 : inner.totalCount();
    }

    @Override
    public Set<K1> firstKeySet() {
        return map.keySet();
    }

    /**
     * Replaces the inner counter for {@code k1}. A {@link ClassicCounter} argument is stored
     * as-is (not copied); any other {@link Counter} is copied into a new ClassicCounter.
     *
     * @return the previous inner counter, or a fresh empty one if {@code k1} was absent
     */
    public ClassicCounter<K2> setCounter(K1 k1, Counter<K2> counter) {
        ClassicCounter<K2> previous = getCounter(k1);
        total -= previous.totalCount();
        if (counter instanceof ClassicCounter) {
            map.put(k1, (ClassicCounter<K2>) counter);
        } else {
            map.put(k1, new ClassicCounter<>(counter));
        }
        total += counter.totalCount();
        return previous;
    }

    /** Returns a new counter with the roles of the first and second keys swapped. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static <K1, K2> TwoDimensionalCounter<K2, K1> reverseIndexOrder(TwoDimensionalCounter<K1, K2> cc) {
        // Reuses the source factories with swapped type arguments via raw casts. This relies on the
        // factories not depending on the key type. NOTE(review): confirm for custom factories.
        TwoDimensionalCounter<K2, K1> result =
                new TwoDimensionalCounter<K2, K1>((MapFactory) cc.outerMF, (MapFactory) cc.innerMF);
        for (Map.Entry<K1, ClassicCounter<K2>> entry : cc.map.entrySet()) {
            K1 k1 = entry.getKey();
            ClassicCounter<K2> inner = entry.getValue();
            for (K2 k2 : inner.keySet()) {
                result.setCount(k2, k1, inner.getCount(k2));
            }
        }
        return result;
    }

    /** One line per stored pair: first key, delimiter, second key, delimiter, count. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<K1, ClassicCounter<K2>> entry : map.entrySet()) {
            K1 k1 = entry.getKey();
            ClassicCounter<K2> inner = entry.getValue();
            for (K2 k2 : inner.keySet()) {
                sb.append(k1).append(LinearClassifier.TEXT_SERIALIZATION_DELIMITER)
                  .append(k2).append(LinearClassifier.TEXT_SERIALIZATION_DELIMITER)
                  .append(inner.getCount(k2)).append("\n");
            }
        }
        return sb.toString();
    }

    /** Renders the counts as a matrix with naturally sorted row and column labels. */
    @Override
    public String toMatrixString(int cellSize) {
        List<K1> rows = sortedKeys(firstKeySet());
        List<K2> cols = sortedKeys(secondKeySet());
        NumberFormat nf = new DecimalFormat();
        return ArrayMath.toString(toMatrix(rows, cols), cellSize, rows.toArray(), cols.toArray(), nf, true);
    }

    /** Returns counts[i][j] = getCount(rows[i], cols[j]). */
    @Override
    public double[][] toMatrix(List<K1> rows, List<K2> cols) {
        double[][] counts = new double[rows.size()][cols.size()];
        for (int r = 0; r < rows.size(); r++) {
            for (int c = 0; c < cols.size(); c++) {
                counts[r][c] = getCount(rows.get(r), cols.get(c));
            }
        }
        return counts;
    }

    /** Renders the counts as CSV: a header row of second keys, then one row per first key. */
    @Override
    public String toCSVString(NumberFormat nf) {
        List<K1> rows = sortedKeys(firstKeySet());
        List<K2> cols = sortedKeys(secondKeySet());
        StringBuilder sb = new StringBuilder();
        String[] header = new String[cols.size() + 1];
        header[0] = "";
        for (int c = 0; c < cols.size(); c++) {
            header[c + 1] = cols.get(c).toString();
        }
        sb.append(StringUtils.toCSVString(header)).append("\n");
        for (K1 row : rows) {
            String[] cells = new String[cols.size() + 1];
            cells[0] = row.toString();
            for (int c = 0; c < cols.size(); c++) {
                cells[c + 1] = nf.format(getCount(row, cols.get(c)));
            }
            sb.append(StringUtils.toCSVString(cells)).append("\n");
        }
        return sb.toString();
    }

    /**
     * Copies {@code keys} into a new list sorted by natural order.
     * Collections.sort throws ClassCastException if the keys are not mutually comparable.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <T> List<T> sortedKeys(Set<T> keys) {
        List<T> sorted = new ArrayList<>(keys);
        Collections.sort((List) sorted);
        return sorted;
    }

    /** Returns a new set of every second key appearing in any inner counter. */
    @Override
    public Set<K2> secondKeySet() {
        Set<K2> keys = Generics.newHashSet();
        for (ClassicCounter<K2> inner : map.values()) {
            keys.addAll(inner.keySet());
        }
        return keys;
    }

    @Override
    public boolean isEmpty() {
        return map.isEmpty();
    }

    /** Returns a one-dimensional counter keyed by (first key, second key) pairs. */
    public ClassicCounter<Pair<K1, K2>> flatten() {
        ClassicCounter<Pair<K1, K2>> flat = new ClassicCounter<>();
        flat.setDefaultReturnValue(defaultValue);
        for (Map.Entry<K1, ClassicCounter<K2>> entry : map.entrySet()) {
            K1 k1 = entry.getKey();
            ClassicCounter<K2> inner = entry.getValue();
            for (K2 k2 : inner.keySet()) {
                flat.setCount(new Pair<>(k1, k2), inner.getCount(k2));
            }
        }
        return flat;
    }

    /** Adds every count of {@code other} into this counter. */
    public void addAll(TwoDimensionalCounterInterface<K1, K2> other) {
        for (K1 k1 : other.firstKeySet()) {
            Counter<K2> counter = other.getCounter(k1);
            // Read the delta before the in-place add, in case counter aliases our inner counter.
            double delta = counter.totalCount();
            Counters.addInPlace(getCounter(k1), counter);
            total += delta;
        }
    }

    /** Adds {@code counter} into the inner counter for {@code k1}. */
    public void addAll(K1 k1, Counter<K2> counter) {
        double delta = counter.totalCount();
        Counters.addInPlace(getCounter(k1), counter);
        total += delta;
    }

    /** Subtracts {@code counter} from the inner counter for {@code k1}. */
    public void subtractAll(K1 k1, Counter<K2> counter) {
        double delta = counter.totalCount();
        Counters.subtractInPlace(getCounter(k1), counter);
        total -= delta;
    }

    /**
     * Subtracts every count of {@code other} from this counter.
     *
     * @param removeZeroes if true, drop inner entries whose count becomes zero
     */
    public void subtractAll(TwoDimensionalCounterInterface<K1, K2> other, boolean removeZeroes) {
        for (K1 k1 : other.firstKeySet()) {
            Counter<K2> counter = other.getCounter(k1);
            double delta = counter.totalCount();
            ClassicCounter<K2> inner = getCounter(k1);
            Counters.subtractInPlace(inner, counter);
            if (removeZeroes) {
                Counters.retainNonZeros(inner);
            }
            total -= delta;
        }
    }

    /** Returns a counter mapping each first key to the total of its inner counter. */
    public Counter<K1> sumInnerCounter() {
        ClassicCounter<K1> sums = new ClassicCounter<>();
        for (Map.Entry<K1, ClassicCounter<K2>> entry : map.entrySet()) {
            sums.incrementCount(entry.getKey(), entry.getValue().totalCount());
        }
        return sums;
    }

    /** Removes exactly-zero counts, and first keys whose inner counter ends up empty. */
    public void removeZeroCounts() {
        // Snapshot into a list (not a hash set) so identity-keyed outer maps lose no keys.
        for (K1 k1 : new ArrayList<>(map.keySet())) {
            ClassicCounter<K2> inner = map.get(k1);
            Counters.retainNonZeros(inner);
            if (inner.size() == 0) {
                map.remove(k1);
            }
        }
    }

    /** Removes {@code k1} and its whole inner counter, adjusting the total. */
    @Override
    public void remove(K1 k1) {
        ClassicCounter<K2> removed = map.remove(k1);
        if (removed != null) {
            total -= removed.totalCount();
        }
    }

    /** Removes everything and resets both the total and the default return value to 0. */
    public void clear() {
        map.clear();
        total = 0.0;
        defaultValue = 0.0;
    }

    /**
     * Removes counts that {@code SloppyMath.isCloseTo} deems equal to 0, and first keys whose
     * inner counter ends up empty. The total is reduced by the removed values.
     */
    public void clean() {
        // Snapshot into lists (not hash sets) so identity-keyed maps lose no keys.
        for (K1 k1 : new ArrayList<>(map.keySet())) {
            ClassicCounter<K2> inner = map.get(k1);
            for (K2 k2 : new ArrayList<>(inner.keySet())) {
                double count = inner.getCount(k2);
                if (SloppyMath.isCloseTo(0.0, count)) {
                    inner.remove(k2);
                    total -= count;
                }
            }
            if (inner.keySet().isEmpty()) {
                map.remove(k1);
            }
        }
    }

    public MapFactory<K1, ClassicCounter<K2>> getOuterMapFactory() {
        return outerMF;
    }

    public MapFactory<K2, MutableDouble> getInnerMapFactory() {
        return innerMF;
    }

    /** Recomputes the total from scratch as the sum of all inner counter totals. */
    public void recomputeTotal() {
        double sum = 0.0;
        for (ClassicCounter<K2> inner : map.values()) {
            sum += inner.totalCount();
        }
        total = sum;
    }

    /** Small demonstration: builds a counter, prints it, increments, prints it and its reverse. */
    public static void main(String[] args) {
        TwoDimensionalCounter<String, String> cc = new TwoDimensionalCounter<>();
        cc.setCount("a", "c", 1.0);
        cc.setCount("b", "c", 1.0);
        cc.setCount("a", "d", 1.0);
        cc.setCount("a", "d", -1.0);
        cc.setCount("b", "d", 1.0);
        System.out.println(cc);
        cc.incrementCount("b", "d", 1.0);
        System.out.println(cc);
        System.out.println(reverseIndexOrder(cc));
    }
}
