package com.google.research.reflection.predictor;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/* loaded from: classes.dex */
/**
 * Multiclass passive-aggressive (PA) learner that updates the weights of every
 * margin-violating class in a single step.
 *
 * <p>Each class {@code c} has its own feature vector {@code fArr[c]} and weight
 * vector {@code parameters().get(c)}. For a training example with correct class
 * {@code i}, every other class {@code j} whose hinge loss
 * {@code MARGIN + Score(fArr[j], w_j) - Score(fArr[i], w_i)} is positive receives a
 * non-negative step size ("lambda"). The target class is then moved towards its
 * features and each violating class away from its own.
 */
public class ReflectionPA extends MulticlassPA {
    // NOTE(review): not referenced anywhere in this class.
    private static final int MAX_NUM_UNIQUE_TARGETS = 200;

    /** Required margin between the target-class score and any other class score. */
    private static final float MARGIN = 1.0f;

    /** Number of full sweeps the lambda solver makes over all violating classes. */
    private static final int LAMBDA_SOLVER_ITERATIONS = 1000;

    /** Step size used by the lambda solver. */
    private static final float LAMBDA_SOLVER_STEP = 0.01f;

    ReflectionPA() {
    }

    /** Creates a learner; all three arguments are forwarded unchanged to {@code MulticlassPA}. */
    public ReflectionPA(int i, int i2, float f) {
        super(i, i2, f);
    }

    /**
     * Computes a step size ("lambda") for every class that violates the margin
     * against target class {@code i}.
     *
     * <p>Lambdas start from {@link Math#random()} values. They are then refined by
     * {@link #LAMBDA_SOLVER_ITERATIONS} sweeps of updates against the matrix
     * {@code K[a][b] = |x_i|^2 + (a == b ? |x_a|^2 : 0)}, clipping each lambda at zero.
     * Because of the random start, results are not deterministic.
     *
     * @param i index of the target (correct) class
     * @param fArr per-class feature vectors, indexed by class
     * @param hashMap output map. It is cleared, then filled with class index to
     *     lambda for every violating class. It stays empty when no class violates
     *     the margin.
     */
    void PrepareLambdas(int i, float[][] fArr, HashMap<Integer, Float> hashMap) {
        hashMap.clear();

        // Loss of every margin-violating class, keyed by class index.
        HashMap<Integer, Float> losses = new HashMap<>();
        float targetScore = Score(fArr[i], parameters().get(i));
        for (int c = 0; c < num_classes(); c++) {
            if (c == i) {
                continue;
            }
            float classLoss = (MARGIN + Score(fArr[c], parameters().get(c))) - targetScore;
            if (classLoss > 0.0f) {
                losses.put(c, classLoss);
                hashMap.put(c, (float) Math.random());
            }
        }
        if (losses.isEmpty()) {
            return;
        }

        // Copy into parallel arrays so the O(iterations * n^2) solver loop does no
        // map lookups or boxing. All loops below use the iteration order of `losses`.
        int size = losses.size();
        int[] classes = new int[size];
        float[] loss = new float[size];
        float[] lambda = new float[size];
        int pos = 0;
        for (Map.Entry<Integer, Float> entry : losses.entrySet()) {
            classes[pos] = entry.getKey();
            loss[pos] = entry.getValue();
            lambda[pos] = hashMap.get(entry.getKey());
            pos++;
        }

        // K[a][b] = |x_i|^2 off the diagonal; the diagonal also adds |x_a|^2.
        float targetNorm = L2NormSquare(fArr[i]);
        float[][] gram = new float[size][size];
        for (int a = 0; a < size; a++) {
            for (int b = 0; b < size; b++) {
                gram[a][b] = targetNorm;
            }
            gram[a][a] = L2NormSquare(fArr[classes[a]]) + targetNorm;
        }

        for (int iter = 0; iter < LAMBDA_SOLVER_ITERATIONS; iter++) {
            for (int row = 0; row < size; row++) {
                float[] k = gram[row];
                float f = 0.0f;
                for (int b = 0; b < size; b++) {
                    f += k[b] * lambda[b];
                }
                for (int b = 0; b < size; b++) {
                    // NOTE(review): the residual pairs loss[b] with the f of `row`.
                    // A row-wise least-squares update would use loss[row]. Kept
                    // unchanged to preserve trained-model behavior; confirm intent.
                    float updated = lambda[b] + (k[b] * LAMBDA_SOLVER_STEP * (loss[b] - f));
                    if (updated < 0.0f) {
                        updated = 0.0f;
                    }
                    lambda[b] = updated;
                }
            }
        }

        for (int p = 0; p < size; p++) {
            hashMap.put(classes[p], lambda[p]);
        }
    }

    /**
     * Performs one update for an example whose correct class is {@code i}.
     *
     * <p>The lambdas from {@link #PrepareLambdas} are scaled down uniformly so their
     * sum does not exceed {@code aggressiveness()}. Each violating class {@code j}
     * has {@code lambda_j * fArr[j]} subtracted from its weights. The target class
     * gets the sum of all scaled lambdas times {@code fArr[i]} added to its weights.
     *
     * @param fArr per-class feature vectors, indexed by class
     * @param i index of the correct class; checked to lie in {@code [0, num_classes())}
     * @return always {@code 0.0f}. NOTE(review): no loss value is reported; confirm
     *     callers do not expect one.
     */
    @Override // com.google.research.reflection.predictor.MulticlassPA
    public float TrainOneExample(float[][] fArr, int i) {
        CHECK_GE(i, 0);
        CHECK_LT(i, num_classes());
        HashMap<Integer, Float> lambdas = new HashMap<>();
        PrepareLambdas(i, fArr, lambdas);
        if (lambdas.isEmpty()) {
            return 0.0f;
        }

        float lambdaSum = 0.0f;
        for (Float lambda : lambdas.values()) {
            lambdaSum += lambda.floatValue();
        }
        // Cap the total step at aggressiveness() by scaling all lambdas uniformly.
        float scale = lambdaSum > aggressiveness() ? aggressiveness() / lambdaSum : 1.0f;

        // NOTE(review): every update loop uses the target's feature length, which
        // assumes all classes' feature vectors have the same length.
        int dim = fArr[i].length;
        float targetStep = 0.0f;
        for (Map.Entry<Integer, Float> entry : lambdas.entrySet()) {
            int c = entry.getKey().intValue();
            float step = entry.getValue().floatValue() * scale;
            targetStep += step;
            for (int d = 0; d < dim; d++) {
                parameters().get(c).set(d, Float.valueOf(parameters().get(c).get(d).floatValue() - (fArr[c][d] * step)));
            }
        }
        for (int d = 0; d < dim; d++) {
            parameters().get(i).set(d, Float.valueOf(parameters().get(i).get(d).floatValue() + (fArr[i][d] * targetStep)));
        }
        return 0.0f;
    }
}
