package demos.nn;

import demos.common.SimpleProgressDisplay;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import javax.swing.DefaultComboBoxModel;
import javax.swing.GroupLayout;
import javax.swing.JButton;
import javax.swing.JComboBox;
import javax.swing.JPanel;
import javax.swing.JProgressBar;
import javax.swing.JScrollPane;
import javax.swing.JTextPane;
import javax.swing.LayoutStyle;
import javax.swing.SwingUtilities;
import libai.common.Matrix;
import libai.common.functions.Function;
import libai.common.functions.Identity;
import libai.common.functions.Sigmoid;
import libai.nn.supervised.MLP;

/* loaded from: input_file:demos/nn/MPLPanel.class */
/**
 * Demo panel that trains a 1-4-1 multilayer perceptron ({@link MLP}) to approximate
 * {@code sin(x) + cos(x)} and prints the training/test error and test-set predictions.
 *
 * <p>Training runs on a background thread; every Swing component is read or written only
 * on the Event Dispatch Thread, except that the progress bar is handed to
 * {@link SimpleProgressDisplay}.
 */
public class MPLPanel extends JPanel {
    /** Number of training samples, taken at x = 0, TRAIN_STEP, 2 * TRAIN_STEP, ... */
    private static final int TRAIN_SIZE = 40;
    /** Spacing between consecutive training inputs. */
    private static final double TRAIN_STEP = 0.1d;
    /** Number of test samples, stored right after the training samples. */
    private static final int TEST_SIZE = 12;
    /** Spacing between consecutive test inputs. */
    private static final double TEST_STEP = 0.33d;
    /** Third argument of {@code MLP.train}; presumably the learning rate — confirm against libai. */
    private static final double LEARNING_RATE = 0.2d;
    /** Fourth argument of {@code MLP.train}; presumably the epoch limit — confirm against libai. */
    private static final int EPOCHS = 50000;
    /** Extra parameter for training type 1 ("Momentum Backpropagation"); presumably the momentum. */
    private static final double MOMENTUM = 0.4d;

    private JComboBox<String> algorithmType;
    private JButton jButton1;
    private JProgressBar jProgressBar1;
    private JScrollPane jScrollPane1;
    private JTextPane jTextPane1;

    /** The target function the network learns: {@code sin(d) + cos(d)}. */
    static double f(double d) {
        return Math.sin(d) + Math.cos(d);
    }

    public MPLPanel() {
        initComponents();
    }

    /** Creates the widgets and lays them out with a {@link GroupLayout}. */
    private void initComponents() {
        this.jProgressBar1 = new JProgressBar();
        this.jButton1 = new JButton();
        this.jScrollPane1 = new JScrollPane();
        this.jTextPane1 = new JTextPane();
        this.algorithmType = new JComboBox<String>();

        this.jProgressBar1.setString("training");
        this.jProgressBar1.setStringPainted(true);

        this.jButton1.setText("Train");
        this.jButton1.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent actionEvent) {
                MPLPanel.this.jButton1ActionPerformed(actionEvent);
            }
        });

        this.jTextPane1.setText("Train a MLP network to learn the equation: sin(x) + cos(x) for x in [0, 4) "
                + "using a spacing of 0.1 for training (40 samples) and 0.33 for test (12 samples). "
                + "The network has 3 layers of 1, 4 and 1 neurons and functions identity, sigmoid "
                + "and identity respectively.");
        this.jScrollPane1.setViewportView(this.jTextPane1);

        this.algorithmType.setModel(new DefaultComboBoxModel<String>(new String[]{
            "Standard Backpropagation", "Momentum Backpropagation", "Resilient Backpropagation"}));

        GroupLayout layout = new GroupLayout(this);
        setLayout(layout);
        layout.setHorizontalGroup(
            layout.createParallelGroup(GroupLayout.Alignment.LEADING)
                .addGroup(layout.createSequentialGroup()
                    .addContainerGap()
                    .addGroup(layout.createParallelGroup(GroupLayout.Alignment.LEADING)
                        .addComponent(this.jScrollPane1, GroupLayout.DEFAULT_SIZE, 380, Short.MAX_VALUE)
                        .addGroup(GroupLayout.Alignment.TRAILING, layout.createSequentialGroup()
                            .addComponent(this.jProgressBar1, GroupLayout.DEFAULT_SIZE, 298, Short.MAX_VALUE)
                            .addPreferredGap(LayoutStyle.ComponentPlacement.UNRELATED)
                            .addComponent(this.jButton1))
                        .addComponent(this.algorithmType, 0, 380, Short.MAX_VALUE))
                    .addContainerGap()));
        layout.setVerticalGroup(
            layout.createParallelGroup(GroupLayout.Alignment.LEADING)
                .addGroup(GroupLayout.Alignment.TRAILING, layout.createSequentialGroup()
                    .addContainerGap()
                    .addComponent(this.jScrollPane1, GroupLayout.DEFAULT_SIZE, 210, Short.MAX_VALUE)
                    .addPreferredGap(LayoutStyle.ComponentPlacement.RELATED)
                    .addComponent(this.algorithmType, GroupLayout.PREFERRED_SIZE,
                        GroupLayout.DEFAULT_SIZE, GroupLayout.PREFERRED_SIZE)
                    .addGap(11, 11, 11)
                    .addGroup(layout.createParallelGroup(GroupLayout.Alignment.BASELINE)
                        .addComponent(this.jProgressBar1, GroupLayout.PREFERRED_SIZE,
                            GroupLayout.DEFAULT_SIZE, GroupLayout.PREFERRED_SIZE)
                        .addComponent(this.jButton1))
                    .addContainerGap()));
    }

    /**
     * Handles the Train button: captures the selected algorithm, clears the output, disables
     * the button and starts training on a background thread.
     *
     * <p>The decompiler widened this from {@code private}; it is kept {@code public} only for
     * compatibility.
     *
     * @param actionEvent the button event (unused)
     */
    public void jButton1ActionPerformed(ActionEvent actionEvent) {
        // Read the selection here on the EDT: the worker must not query Swing components, and
        // the report must name the algorithm that was actually used, even if the user changes
        // the combo box while training runs.
        final int algorithm = this.algorithmType.getSelectedIndex();
        final Object algorithmName = this.algorithmType.getSelectedItem();

        this.jTextPane1.setText("");
        // Prevent overlapping runs that would share the progress bar and text pane.
        this.jButton1.setEnabled(false);

        Thread worker = new Thread(new Runnable() {
            @Override
            public void run() {
                String report = null;
                try {
                    report = runExperiment(algorithm, algorithmName);
                } finally {
                    // Always re-enable the button, even if training threw.
                    finishTraining(report);
                }
            }
        }, "MLP training");
        worker.start();
    }

    /**
     * Builds the data set, trains the network and formats the report. Intended to run off the EDT.
     *
     * @param algorithm     selected combo index; 0 and 1 select training types 0 and 1, and any
     *                      other value (including -1) selects type 2
     * @param algorithmName the selected combo item, printed as the report's first line
     * @return the report text
     */
    private String runExperiment(int algorithm, Object algorithmName) {
        Matrix[] inputs = new Matrix[TRAIN_SIZE + TEST_SIZE];
        Matrix[] targets = new Matrix[TRAIN_SIZE + TEST_SIZE];
        fillSamples(inputs, targets, 0, TRAIN_SIZE, TRAIN_STEP);
        fillSamples(inputs, targets, TRAIN_SIZE, TEST_SIZE, TEST_STEP);

        MLP mlp = new MLP(new int[]{1, 4, 1},
                new Function[]{new Identity(), new Sigmoid(), new Identity()});
        switch (algorithm) {
            case 0:
                mlp.setTrainingType(0, new double[0]);
                break;
            case 1:
                mlp.setTrainingType(1, MOMENTUM);
                break;
            default:
                mlp.setTrainingType(2, new double[0]);
                break;
        }
        mlp.setProgressBar(new SimpleProgressDisplay(this.jProgressBar1));

        long start = System.nanoTime();
        mlp.train(inputs, targets, LEARNING_RATE, EPOCHS, 0, TRAIN_SIZE);
        long elapsedSeconds = (System.nanoTime() - start) / 1000000000L;

        StringBuilder report = new StringBuilder();
        report.append(algorithmName).append('\n');
        report.append("Time taken: ").append(elapsedSeconds).append(" sec.\n");
        report.append("Error for training set: ").append(mlp.error(inputs, targets, 0, TRAIN_SIZE));
        report.append("\nError for test set: ").append(mlp.error(inputs, targets, TRAIN_SIZE, TEST_SIZE));
        report.append("\n\nValues for the test set:");
        for (int i = TRAIN_SIZE; i < inputs.length; i++) {
            report.append("\nexp: ").append(targets[i].position(0, 0))
                    .append(" vs ").append(mlp.simulate(inputs[i]).position(0, 0));
        }
        return report.toString();
    }

    /**
     * Fills {@code count} consecutive slots starting at {@code offset} with the 1x1 inputs
     * x = k * step (k = 0 .. count-1) and the matching targets f(x).
     * Computing x from the index fills every slot; stopping on an x limit could leave nulls.
     */
    private static void fillSamples(Matrix[] inputs, Matrix[] targets, int offset, int count, double step) {
        for (int k = 0; k < count; k++) {
            double x = k * step;
            inputs[offset + k] = new Matrix(1, 1);
            inputs[offset + k].position(0, 0, x);
            targets[offset + k] = new Matrix(1, 1);
            targets[offset + k].position(0, 0, f(x));
        }
    }

    /**
     * On the EDT, shows {@code report} in the text pane (if non-null) and re-enables the Train
     * button.
     *
     * @param report the text to display, or {@code null} if training failed
     */
    private void finishTraining(final String report) {
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                if (report != null) {
                    jTextPane1.setText(report);
                }
                jButton1.setEnabled(true);
            }
        });
    }
}
