package libai.nn.supervised;

import java.io.FileInputStream;
import java.io.ObjectInputStream;
import libai.common.Matrix;
import libai.common.functions.Function;
import libai.nn.NeuralNetwork;

/* loaded from: input_file:libai/nn/supervised/MLP.class */
/**
 * Fully connected feed-forward network (multi-layer perceptron) trained with
 * one of three backpropagation variants: standard, momentum or resilient.
 *
 * <p>Layer {@code 0} is the input layer and only stores the input column
 * vector; layers {@code 1 .. layers-1} each own a weight matrix, a bias
 * vector and an activation function.
 *
 * <p>Instances reuse internal work matrices between calls, and
 * {@link #simulate(Matrix)} returns one of those work matrices directly, so
 * the returned matrix is overwritten by the next simulation or training step.
 */
public class MLP extends NeuralNetwork {
    private static final long serialVersionUID = 3155220303024711102L;

    /** Training type: plain per-pattern backpropagation (the default). */
    public static final int STANDARD_BACKPROPAGATION = 0;
    /** Training type: backpropagation with a momentum term; needs one parameter, beta. */
    public static final int MOMEMTUM_BACKPROPAGATION = 1;
    /** Training type: resilient backpropagation (sign-based, per-weight step sizes). */
    public static final int RESILENT_BACKPROPAGATION = 2;

    /** Step size every weight and bias starts with in resilient backpropagation. */
    private static final double RPROP_INITIAL_STEP = 0.1d;
    /** Factor applied to a step when the gradient kept its sign since the previous epoch. */
    private static final double RPROP_STEP_GROWTH = 1.2d;
    /** Factor applied to a step when the gradient changed sign since the previous epoch. */
    private static final double RPROP_STEP_SHRINK = 0.5d;
    /** Upper bound for a resilient backpropagation step. */
    private static final double RPROP_STEP_MAX = 50.0d;
    /** Lower bound for a resilient backpropagation step. */
    private static final double RPROP_STEP_MIN = 1.0E-6d;

    // NOTE: field names and types are part of the serialized form; do not rename.
    /** W[i]: weights of layer i, nperlayer[i] x nperlayer[i-1]. Index 0 is unused. */
    private final Matrix[] W;
    /** b[i]: bias column vector of layer i. Index 0 is unused. */
    private final Matrix[] b;
    /** Y[i]: output column vector of layer i; Y[0] holds a copy of the input. */
    private final Matrix[] Y;
    /** d[i]: delta (error term) column vector of layer i, written during training. */
    private final Matrix[] d;
    /** u[i]: pre-activation of layer i, i.e. W[i] * Y[i-1] + b[i]. */
    private final Matrix[] u;
    /** Wt[i]: allocated with the transposed shape of W[i]; not read anywhere in this class. */
    private final Matrix[] Wt;
    /** Yt[i]: scratch matrix receiving the transpose of Y[i]. */
    private final Matrix[] Yt;
    /** M[i]: scratch matrix with the shape of W[i], used for weight increments. */
    private final Matrix[] M;
    /** Number of neurons in each layer, input layer included. */
    private final int[] nperlayer;
    /** Number of layers, input layer included. */
    private final int layers;
    /** func[i]: activation function of layer i. Index 0 is never evaluated. */
    private final Function[] func;
    /** Parameters of the selected training type; params[0] is beta for momentum. */
    private double[] params;
    /** One of the {@code *_BACKPROPAGATION} constants; any other value trains as standard. */
    private int trainingType;

    /**
     * Creates a network trained with standard backpropagation.
     *
     * @param iArr        neurons per layer, input layer first; at least two entries
     * @param functionArr activation function per layer, indexed like {@code iArr};
     *                    entry 0 is never evaluated
     * @throws IllegalArgumentException if fewer than two layers are given or
     *                                  {@code functionArr} is shorter than {@code iArr}
     */
    public MLP(int[] iArr, Function[] functionArr) {
        this(iArr, functionArr, 0.0d);
    }

    /**
     * Creates a network; a strictly positive {@code d} selects momentum
     * backpropagation with {@code d} as beta, {@code 0} selects the standard
     * algorithm.
     *
     * @param iArr        neurons per layer, input layer first; at least two entries
     * @param functionArr activation function per layer, indexed like {@code iArr};
     *                    entry 0 is never evaluated
     * @param d           momentum coefficient beta, {@code 0 <= d < 1}
     * @throws IllegalArgumentException if beta is out of range (or NaN), fewer
     *                                  than two layers are given, or
     *                                  {@code functionArr} is shorter than {@code iArr}
     */
    public MLP(int[] iArr, Function[] functionArr, double d) {
        requireValidBeta(d);
        if (iArr.length < 2) {
            // init() dereferences u[layers - 1], which only exists from layer 1 on.
            throw new IllegalArgumentException(
                    "an MLP needs at least an input and an output layer, got " + iArr.length + " layer(s)");
        }
        if (functionArr.length < iArr.length) {
            // simulate() evaluates func[i] for every i in 1 .. layers - 1.
            throw new IllegalArgumentException(
                    "expected one function per layer (" + iArr.length + "), got " + functionArr.length);
        }
        // Defensive copies: the matrices below are sized from these arrays once.
        this.nperlayer = iArr.clone();
        this.func = functionArr.clone();
        this.params = new double[]{d};
        this.trainingType = d > 0.0d ? MOMEMTUM_BACKPROPAGATION : STANDARD_BACKPROPAGATION;
        this.layers = this.nperlayer.length;
        this.W = new Matrix[this.layers];
        this.b = new Matrix[this.layers];
        this.Y = new Matrix[this.layers];
        this.d = new Matrix[this.layers];
        this.u = new Matrix[this.layers];
        this.Wt = new Matrix[this.layers];
        this.Yt = new Matrix[this.layers];
        this.M = new Matrix[this.layers];
        init();
    }

    /**
     * Rejects a momentum coefficient outside {@code [0, 1)}. Written as a
     * negated range test so that NaN is rejected as well.
     */
    private static void requireValidBeta(double beta) {
        if (!(beta >= 0.0d && beta < 1.0d)) {
            throw new IllegalArgumentException("beta should be positive and less than 1");
        }
    }

    /**
     * Selects the training algorithm used by {@link #train}.
     *
     * <p>Arguments are validated before any state changes, so a rejected call
     * leaves the network configured as it was. Values of {@code i} other than
     * the three constants are accepted and train as standard backpropagation.
     *
     * @param i    one of {@link #STANDARD_BACKPROPAGATION},
     *             {@link #MOMEMTUM_BACKPROPAGATION}, {@link #RESILENT_BACKPROPAGATION}
     * @param dArr algorithm parameters; momentum requires {@code dArr[0]} = beta
     *             with {@code 0 <= beta < 1}; ignored by the other algorithms
     * @throws IllegalArgumentException if momentum is selected without a valid beta
     */
    public void setTrainingType(int i, double... dArr) {
        if (i == MOMEMTUM_BACKPROPAGATION) {
            if (dArr == null || dArr.length < 1) {
                throw new IllegalArgumentException("Momemtum algorithm requires 1 parameter: beta");
            }
            requireValidBeta(dArr[0]);
            // Copy so later changes to the caller's array cannot alter training.
            this.params = dArr.clone();
        }
        this.trainingType = i;
    }

    /**
     * Allocates every per-layer matrix and calls {@code fill()} on the weights
     * and biases (presumably random initialisation; confirm against Matrix).
     */
    private void init() {
        this.Yt[0] = new Matrix(1, this.nperlayer[0]);
        this.Y[0] = new Matrix(this.nperlayer[0], 1);
        for (int layer = 1; layer < this.layers; layer++) {
            final int neurons = this.nperlayer[layer];
            final int inputs = this.nperlayer[layer - 1];
            this.W[layer] = new Matrix(neurons, inputs);
            this.Wt[layer] = new Matrix(inputs, neurons);
            this.b[layer] = new Matrix(neurons, 1);
            this.W[layer].fill();
            this.b[layer].fill();
            this.u[layer] = new Matrix(this.W[layer].getRows(), this.Y[layer - 1].getColumns());
            this.Y[layer] = new Matrix(this.u[layer].getRows(), this.u[layer].getColumns());
            this.Yt[layer] = new Matrix(this.u[layer].getColumns(), this.u[layer].getRows());
            this.M[layer] = new Matrix(this.u[layer].getRows(), this.Y[layer - 1].getRows());
        }
        // One delta column vector per non-input layer.
        for (int layer = this.layers - 1; layer > 0; layer--) {
            this.d[layer] = new Matrix(this.u[layer].getRows(), 1);
        }
    }

    /**
     * Trains the network on {@code i3} patterns starting at index {@code i2},
     * using the algorithm chosen at construction or via {@link #setTrainingType}.
     * Training stops when the mean squared error is no longer above
     * {@code d2} or after {@code i} epochs, whichever comes first.
     *
     * <p>If a progress bar is attached its range is set to {@code [-i, 0]} and
     * it counts up as epochs are consumed.
     *
     * @param matrixArr  input patterns
     * @param matrixArr2 expected outputs, indexed like {@code matrixArr}
     * @param d          learning rate alpha (not used by resilient backpropagation)
     * @param i          maximum number of epochs
     * @param i2         index of the first pattern to use
     * @param i3         number of patterns to use
     * @param d2         error threshold below which training stops
     */
    @Override // libai.nn.NeuralNetwork
    public void train(Matrix[] matrixArr, Matrix[] matrixArr2, double d, int i, int i2, int i3, double d2) {
        if (this.progress != null) {
            this.progress.setMaximum(0);
            this.progress.setMinimum(-i);
            this.progress.setValue(-i);
        }
        if (this.trainingType == MOMEMTUM_BACKPROPAGATION) {
            momemtumBP(matrixArr, matrixArr2, d, i, i2, i3, d2);
        } else if (this.trainingType == RESILENT_BACKPROPAGATION) {
            resilentBP(matrixArr, matrixArr2, d, i, i2, i3, d2);
        } else {
            standardBP(matrixArr, matrixArr2, d, i, i2, i3, d2);
        }
        if (this.progress != null) {
            // Kept from the original: the final value is one past the maximum of 0.
            this.progress.setValue(1);
        }
    }

    /** Returns {@code {0, 1, ..., length - 1}}, the visiting order before any shuffle. */
    private static int[] identityOrder(int length) {
        final int[] order = new int[length];
        for (int k = 0; k < length; k++) {
            order[k] = k;
        }
        return order;
    }

    /**
     * Adds the squared entries of the output error column to {@code total}.
     * The running total is threaded through (rather than summed locally) so
     * the floating-point accumulation order stays exactly per element.
     */
    private double addSquaredError(double total, Matrix outputError) {
        final int outputs = this.nperlayer[this.layers - 1];
        for (int row = 0; row < outputs; row++) {
            total += outputError.position(row, 0) * outputError.position(row, 0);
        }
        return total;
    }

    /**
     * Fills {@code d[1 .. layers-1]} from the current {@code u}, {@code W} and
     * the output error (expected minus actual) of the pattern just simulated.
     *
     * <p>NOTE(review): {@code alpha} is multiplied into the output delta and
     * again into every hidden delta, so earlier layers effectively see a higher
     * power of alpha. Preserved as-is because changing it would alter training
     * results for existing callers; confirm whether this is intended.
     *
     * @param alpha       learning rate folded into the deltas; pass {@code 1.0}
     *                    for raw gradients (multiplying by 1.0 is exact)
     * @param outputError expected output minus network output
     */
    private void computeDeltas(double alpha, Matrix outputError) {
        final int last = this.layers - 1;
        for (int row = 0; row < this.u[last].getRows(); row++) {
            this.d[last].position(row, 0, (-2.0d) * alpha * this.func[last].getDerivate().eval(this.u[last].position(row, 0)) * outputError.position(row, 0));
        }
        for (int layer = last - 1; layer > 0; layer--) {
            for (int row = 0; row < this.u[layer].getRows(); row++) {
                // Weighted sum of the next layer's deltas reaching this neuron.
                double backflow = 0.0d;
                for (int next = 0; next < this.W[layer + 1].getRows(); next++) {
                    backflow += this.W[layer + 1].position(next, row) * this.d[layer + 1].position(next, 0);
                }
                this.d[layer].position(row, 0, alpha * backflow * this.func[layer].getDerivate().eval(this.u[layer].position(row, 0)));
            }
        }
    }

    /** Publishes the epoch's error to the plotter and progress bar, when attached. */
    private void reportEpoch(int epochsLeft, double error) {
        if (this.plotter != null) {
            this.plotter.setError(epochsLeft, error);
        }
        if (this.progress != null) {
            this.progress.setValue(-epochsLeft);
        }
    }

    /**
     * Standard backpropagation: weights and biases are updated after every
     * pattern, with the visiting order passed to the inherited
     * {@code shuffle} each epoch.
     */
    private void standardBP(Matrix[] patterns, Matrix[] answers, double alpha, int epochs, int offset, int length, double minError) {
        final int[] order = identityOrder(length);
        double error = error(patterns, answers, offset, length);
        final Matrix outputError = new Matrix(answers[0].getRows(), answers[0].getColumns());
        while (error > minError && epochs-- > 0) {
            shuffle(order);
            double squaredError = 0.0d;
            for (int k = 0; k < length; k++) {
                final int index = order[k] + offset;
                simulate(patterns[index]);
                answers[index].subtract(this.Y[this.layers - 1], outputError);
                squaredError = addSquaredError(squaredError, outputError);
                computeDeltas(alpha, outputError);
                for (int layer = 1; layer < this.layers; layer++) {
                    // W -= d * Y(prev)^T and b -= d; alpha is already inside d.
                    this.Y[layer - 1].transpose(this.Yt[layer - 1]);
                    this.d[layer].multiply(this.Yt[layer - 1], this.M[layer]);
                    this.W[layer].subtract(this.M[layer], this.W[layer]);
                    this.b[layer].subtract(this.d[layer], this.b[layer]);
                }
            }
            error = squaredError / length;
            reportEpoch(epochs, error);
        }
    }

    /**
     * Backpropagation with momentum: like {@link #standardBP} but each update
     * also adds {@code beta} times the previous change of the weights/biases,
     * while the gradient term is scaled by {@code 1 - beta}.
     */
    private void momemtumBP(Matrix[] patterns, Matrix[] answers, double alpha, int epochs, int offset, int length, double minError) {
        final int[] order = identityOrder(length);
        double error = error(patterns, answers, offset, length);
        final Matrix outputError = new Matrix(answers[0].getRows(), answers[0].getColumns());
        final double beta = this.params[0];
        // Snapshots of the weights/biases as they were before the latest update.
        final Matrix[] previousW = new Matrix[this.layers];
        final Matrix[] previousB = new Matrix[this.layers];
        for (int layer = 1; layer < this.layers; layer++) {
            previousW[layer] = new Matrix(this.nperlayer[layer], this.nperlayer[layer - 1]);
            previousB[layer] = new Matrix(this.nperlayer[layer], 1);
            this.W[layer].copy(previousW[layer]);
            this.b[layer].copy(previousB[layer]);
        }
        while (error > minError && epochs-- > 0) {
            shuffle(order);
            double squaredError = 0.0d;
            for (int k = 0; k < length; k++) {
                final int index = order[k] + offset;
                simulate(patterns[index]);
                answers[index].subtract(this.Y[this.layers - 1], outputError);
                squaredError = addSquaredError(squaredError, outputError);
                computeDeltas(alpha, outputError);
                for (int layer = 1; layer < this.layers; layer++) {
                    this.Y[layer - 1].transpose(this.Yt[layer - 1]);
                    final Matrix gradientStep = new Matrix(this.d[layer].getRows(), this.Y[layer - 1].getRows());
                    this.d[layer].multiply(1.0d - beta, this.d[layer]);
                    this.d[layer].multiply(this.Yt[layer - 1], gradientStep);
                    // Presumably: M = W - previousW, then previousW = W
                    // (confirm against Matrix.subtractAndCopy).
                    this.W[layer].subtractAndCopy(previousW[layer], this.M[layer], previousW[layer]);
                    // Presumably: W = beta * M + W (confirm against Matrix.multiplyAndAdd).
                    this.M[layer].multiplyAndAdd(beta, this.W[layer], this.W[layer]);
                    this.W[layer].subtract(gradientStep, this.W[layer]);
                    // Same three steps for the bias vector.
                    final Matrix biasChange = new Matrix(this.b[layer].getRows(), this.b[layer].getColumns());
                    this.b[layer].subtractAndCopy(previousB[layer], biasChange, previousB[layer]);
                    biasChange.multiplyAndAdd(beta, this.b[layer], this.b[layer]);
                    this.b[layer].subtract(this.d[layer], this.b[layer]);
                }
            }
            error = squaredError / length;
            reportEpoch(epochs, error);
        }
    }

    /**
     * Resilient backpropagation: gradients are summed over all patterns of an
     * epoch, then every weight and bias moves by its own adaptive step in the
     * direction opposite to the sign of its summed gradient.
     *
     * <p>Patterns are visited in index order without shuffling; weights only
     * change once per epoch. The {@code alpha} argument is not used.
     */
    private void resilentBP(Matrix[] patterns, Matrix[] answers, double alpha, int epochs, int offset, int length, double minError) {
        final int[] order = identityOrder(length);
        double error = error(patterns, answers, offset, length);
        final Matrix outputError = new Matrix(answers[0].getRows(), answers[0].getColumns());
        final Matrix[] gradW = new Matrix[this.layers];
        final Matrix[] previousGradW = new Matrix[this.layers];
        final Matrix[] stepW = new Matrix[this.layers];
        final Matrix[] gradB = new Matrix[this.layers];
        final Matrix[] previousGradB = new Matrix[this.layers];
        final Matrix[] stepB = new Matrix[this.layers];
        for (int layer = 1; layer < this.layers; layer++) {
            final int rows = this.u[layer].getRows();
            final int columns = this.Y[layer - 1].getRows();
            gradW[layer] = new Matrix(rows, columns);
            previousGradW[layer] = new Matrix(rows, columns);
            stepW[layer] = new Matrix(rows, columns);
            stepW[layer].setValue(RPROP_INITIAL_STEP);
            gradB[layer] = new Matrix(this.nperlayer[layer], 1);
            previousGradB[layer] = new Matrix(this.nperlayer[layer], 1);
            stepB[layer] = new Matrix(this.nperlayer[layer], 1);
            stepB[layer].setValue(RPROP_INITIAL_STEP);
        }
        while (error > minError && epochs-- > 0) {
            double squaredError = 0.0d;
            for (int k = 0; k < length; k++) {
                final int index = order[k] + offset;
                simulate(patterns[index]);
                answers[index].subtract(this.Y[this.layers - 1], outputError);
                squaredError = addSquaredError(squaredError, outputError);
                // alpha = 1.0 yields the raw gradients; only their sign is used.
                computeDeltas(1.0d, outputError);
                for (int layer = 1; layer < this.layers; layer++) {
                    this.Y[layer - 1].transpose(this.Yt[layer - 1]);
                    this.d[layer].multiply(this.Yt[layer - 1], this.M[layer]);
                    gradW[layer].add(this.M[layer], gradW[layer]);
                    gradB[layer].add(this.d[layer], gradB[layer]);
                }
            }
            for (int layer = 1; layer < this.layers; layer++) {
                for (int row = 0; row < this.W[layer].getRows(); row++) {
                    for (int column = 0; column < this.W[layer].getColumns(); column++) {
                        resilentStep(this.W[layer], gradW[layer], previousGradW[layer], stepW[layer], row, column);
                    }
                    for (int column = 0; column < this.b[layer].getColumns(); column++) {
                        resilentStep(this.b[layer], gradB[layer], previousGradB[layer], stepB[layer], row, column);
                    }
                }
            }
            error = squaredError / length;
            reportEpoch(epochs, error);
        }
    }

    /**
     * Applies one resilient-backpropagation update to a single entry and
     * resets that entry's gradient accumulator to zero.
     *
     * <ul>
     * <li>gradient changed sign: the step shrinks, the entry is left alone and
     *     the remembered gradient is cleared;</li>
     * <li>gradient kept its sign: the step grows, then the entry moves;</li>
     * <li>otherwise (product is zero or NaN): the entry moves with the
     *     current step.</li>
     * </ul>
     *
     * <p>NOTE(review): a gradient of exactly zero is treated as negative, so
     * the entry still moves by {@code +step}; the usual sign function would
     * leave it untouched. Preserved from the original; confirm intent.
     *
     * @param target           matrix being trained (weights or biases)
     * @param gradient         gradient summed over the epoch
     * @param previousGradient gradient remembered from the previous epoch
     * @param step             per-entry step sizes
     */
    private static void resilentStep(Matrix target, Matrix gradient, Matrix previousGradient, Matrix step, int row, int column) {
        final double current = gradient.position(row, column);
        final double agreement = current * previousGradient.position(row, column);
        final double sign = current > 0.0d ? 1.0d : -1.0d;
        if (agreement < 0.0d) {
            step.position(row, column, Math.max(step.position(row, column) * RPROP_STEP_SHRINK, RPROP_STEP_MIN));
            previousGradient.position(row, column, 0.0d);
        } else {
            if (agreement > 0.0d) {
                step.position(row, column, Math.min(step.position(row, column) * RPROP_STEP_GROWTH, RPROP_STEP_MAX));
            }
            target.position(row, column, target.position(row, column) + ((-sign) * step.position(row, column)));
            previousGradient.position(row, column, current);
        }
        gradient.position(row, column, 0.0d);
    }

    /**
     * Runs the network on {@code matrix} and returns the output layer.
     *
     * @param matrix input column vector
     * @return the internal output matrix of the last layer (not a copy); it is
     *         overwritten by the next simulation or training step
     */
    @Override // libai.nn.NeuralNetwork
    public Matrix simulate(Matrix matrix) {
        simulate(matrix, null);
        return this.Y[this.layers - 1];
    }

    /**
     * Runs the network on {@code matrix}: for each layer computes
     * {@code u = W * Y(prev) + b} and {@code Y = func(u)}.
     *
     * @param matrix  input column vector; copied into the input layer
     * @param matrix2 receives a copy of the output layer; may be {@code null}
     *                to skip the copy
     */
    @Override // libai.nn.NeuralNetwork
    public void simulate(Matrix matrix, Matrix matrix2) {
        matrix.copy(this.Y[0]);
        for (int layer = 1; layer < this.layers; layer++) {
            this.W[layer].multiply(this.Y[layer - 1], this.u[layer]);
            this.u[layer].add(this.b[layer], this.u[layer]);
            this.u[layer].apply(this.func[layer], this.Y[layer]);
        }
        if (matrix2 != null) {
            this.Y[this.layers - 1].copy(matrix2);
        }
    }

    /**
     * Loads a network previously written with Java serialization.
     *
     * <p>SECURITY: this uses native Java deserialization, which can execute
     * code embedded in the stream. Only open files from a trusted source.
     *
     * @param str path of the file to read
     * @return the deserialized network, or {@code null} if the file could not
     *         be read or did not contain an {@code MLP} (the stack trace is
     *         printed to standard error)
     */
    public static MLP open(String str) {
        try (FileInputStream fileIn = new FileInputStream(str);
                ObjectInputStream objectIn = new ObjectInputStream(fileIn)) {
            return (MLP) objectIn.readObject();
        } catch (Exception e) {
            // Best-effort load kept from the original: report and signal failure with null.
            e.printStackTrace();
            return null;
        }
    }
}
