package libai.classifiers.bayes;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.xml.parsers.DocumentBuilderFactory;
import libai.classifiers.Attribute;
import libai.classifiers.ContinuousAttribute;
import libai.classifiers.DiscreteAttribute;
import libai.classifiers.dataset.DataSet;
import libai.classifiers.dataset.MetaData;
import libai.common.Pair;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

/* loaded from: input_file:libai/classifiers/bayes/NaiveBayes.class */
/**
 * Naive Bayes classifier over records that mix categorical and continuous attributes.
 *
 * <p>{@link #params} maps every class value to an array with one slot per attribute index:
 * <ul>
 *   <li>the slot at {@link #outputIndex} holds an {@code Integer}: how many training records
 *       belong to the class;</li>
 *   <li>a categorical attribute slot holds a {@code HashMap<String, Integer>} of value counts;</li>
 *   <li>a continuous attribute slot holds a {@code Pair} whose {@code first} is the sample mean and
 *       whose {@code second} is the sample variance (computed with an {@code n - 1} denominator).</li>
 * </ul>
 *
 * <p>Scoring multiplies the class prior {@code (classCount + 1) / (totalCount + numberOfClasses)} by
 * one factor per non-output attribute. A categorical attribute contributes
 * {@code (valueCount + 1) / (classCount + 1)}. A continuous attribute contributes a normal density.
 *
 * <p>Models can be written with {@link #save(File)} and read back with {@link #getInstance(File)}.
 * A loaded model has no {@link MetaData}; evaluation relies only on {@link #params}.
 */
public class NaiveBayes {
    /** Variance substituted when the stored one is zero, negative or NaN (constant or 1-sample data). */
    private static final double MIN_VARIANCE = 1e-9;

    protected int outputIndex;
    protected int totalCount;
    protected MetaData metadata;
    protected HashMap<Attribute, Object[]> params;
    /** Attributes per record (output included); kept so models restored by load can be saved again. */
    private int attributeCount;

    /**
     * Trains the classifier on every record of {@code dataSet}, replacing any previous model.
     *
     * @param dataSet training data; its output index identifies the class attribute
     * @return this instance, for chaining
     */
    public NaiveBayes train(DataSet dataSet) {
        this.outputIndex = dataSet.getOutputIndex();
        this.totalCount = dataSet.getItemsCount();
        this.metadata = dataSet.getMetaData();
        this.attributeCount = this.metadata.getAttributeCount();
        this.params = new HashMap<>();
        initialize();
        accumulate(dataSet);
        finishStatistics();
        return this;
    }

    /** Creates zeroed per-class slots: a count, a value-count map, or a (sum, sumOfSquares) pair. */
    private void initialize() {
        for (Attribute classValue : this.metadata.getClasses()) {
            Object[] slots = new Object[this.attributeCount];
            for (int i = 0; i < this.attributeCount; i++) {
                if (i == this.outputIndex) {
                    slots[i] = Integer.valueOf(0);
                } else if (this.metadata.isCategorical(i)) {
                    slots[i] = new HashMap<String, Integer>();
                } else {
                    slots[i] = new Pair(Double.valueOf(0.0d), Double.valueOf(0.0d));
                }
            }
            this.params.put(classValue, slots);
        }
    }

    /** Accumulates class counts, categorical value counts, and continuous sums / sums of squares. */
    private void accumulate(DataSet dataSet) {
        for (List<Attribute> record : dataSet) {
            Object[] slots = this.params.get(record.get(this.outputIndex));
            int i = 0;
            for (Attribute attribute : record) {
                Object value = attribute.getValue();
                if (i == this.outputIndex) {
                    slots[i] = Integer.valueOf((Integer) slots[i] + 1);
                } else if (this.metadata.isCategorical(i)) {
                    HashMap<String, Integer> counts = valueCounts(slots[i]);
                    Integer previous = counts.get((String) value);
                    counts.put((String) value, Integer.valueOf(previous == null ? 1 : previous + 1));
                } else {
                    double x = (Double) value;
                    Pair sums = (Pair) slots[i];
                    sums.first = Double.valueOf((Double) sums.first + x);
                    sums.second = Double.valueOf((Double) sums.second + x * x);
                }
                i++;
            }
        }
    }

    /**
     * Turns every accumulated (sum, sumOfSquares) pair into (mean, sample variance).
     * A class with 0 records yields a NaN mean, and one with 1 record yields a NaN variance.
     */
    private void finishStatistics() {
        for (Object[] slots : this.params.values()) {
            double n = (Integer) slots[this.outputIndex];
            for (Object slot : slots) {
                if (slot instanceof Pair) {
                    Pair stats = (Pair) slot;
                    double sum = (Double) stats.first;
                    double sumOfSquares = (Double) stats.second;
                    stats.second = Double.valueOf((sumOfSquares - (sum * sum) / n) / (n - 1.0d));
                    stats.first = Double.valueOf(sum / n);
                }
            }
        }
    }

    /**
     * Returns the class whose score is highest for {@code list}.
     *
     * <p>Attributes are matched to the model by position. The attribute at the output index, if
     * present, is ignored.
     *
     * @param list the record to classify
     * @return the best class, or {@code null} if no class scores above {@code -Double.MAX_VALUE}
     *     (for example when every score is NaN)
     * @throws IllegalStateException if the model was neither trained nor loaded
     */
    public Attribute eval(List<Attribute> list) {
        if (this.params == null) {
            throw new IllegalStateException("NaiveBayes model has not been trained or loaded");
        }
        Attribute best = null;
        double bestScore = -Double.MAX_VALUE;
        for (Attribute classValue : this.params.keySet()) {
            double score = score(classValue, list);
            if (score > bestScore) {
                bestScore = score;
                best = classValue;
            }
        }
        return best;
    }

    /** Unnormalized posterior: P(record | class) * P(class). */
    private double score(Attribute classValue, List<Attribute> record) {
        return likelihood(record, classValue) * prior(classValue);
    }

    /** Product of per-attribute likelihoods, skipping the class label itself. */
    private double likelihood(List<Attribute> record, Attribute classValue) {
        Object[] slots = this.params.get(classValue);
        double classCount = (Integer) slots[this.outputIndex];
        double product = 1.0d;
        for (int i = 0; i < record.size(); i++) {
            if (i == this.outputIndex) {
                continue; // the label is what we predict, not evidence; its slot is an Integer count
            }
            // Dispatch on the stored slot type rather than metadata, which is null for loaded models.
            if (slots[i] instanceof HashMap) {
                int seen = count((DiscreteAttribute) record.get(i), i, classValue);
                product *= (seen + 1) / (classCount + 1.0d);
            } else {
                product *= gaussian((ContinuousAttribute) record.get(i), i, classValue);
            }
        }
        return product;
    }

    /** Smoothed class prior: (classCount + 1) / (totalCount + numberOfClasses). */
    private double prior(Attribute classValue) {
        int classCount = (Integer) this.params.get(classValue)[this.outputIndex];
        return (classCount + 1) / (double) (this.totalCount + this.params.size());
    }

    /** Training count of the given categorical value for the class; 0 if it was never seen. */
    private int count(DiscreteAttribute value, int i, Attribute classValue) {
        Integer seen = valueCounts(this.params.get(classValue)[i]).get(value.getValue());
        return seen == null ? 0 : seen;
    }

    /** Normal density of the value under the class's (mean, variance) for attribute {@code i}. */
    private double gaussian(ContinuousAttribute value, int i, Attribute classValue) {
        Pair stats = (Pair) this.params.get(classValue)[i];
        double mean = (Double) stats.first;
        double variance = (Double) stats.second;
        if (!(variance > 0.0d)) {
            // Zero (constant data) or NaN (single sample) would make the density NaN for every input.
            variance = MIN_VARIANCE;
        }
        double diff = value.getValue().doubleValue() - mean;
        return Math.exp(-(diff * diff) / (2.0d * variance)) / Math.sqrt(2.0d * Math.PI * variance);
    }

    /** Unchecked view of a categorical slot; only value-count maps are stored in such slots. */
    @SuppressWarnings("unchecked")
    private static HashMap<String, Integer> valueCounts(Object slot) {
        return (HashMap<String, Integer>) slot;
    }

    /**
     * Loads a model previously written by {@link #save(File)}.
     *
     * @param file XML file to read
     * @return the restored model, or {@code null} if reading or parsing failed (the stack trace is
     *     printed)
     */
    public static NaiveBayes getInstance(File file) {
        try (FileInputStream in = new FileInputStream(file)) {
            Document document = DocumentBuilderFactory.newInstance().newDocumentBuilder().parse(in);
            Node root = document.getElementsByTagName("NaiveBayes").item(0);
            if (root == null) {
                // save(File) names the root after the runtime class, which may be a subclass.
                root = document.getDocumentElement();
            }
            return new NaiveBayes().load(root);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Creates a classifier and trains it on {@code dataSet}. */
    public static NaiveBayes getInstance(DataSet dataSet) {
        return new NaiveBayes().train(dataSet);
    }

    /**
     * Writes the model to {@code file} as UTF-8 XML.
     *
     * @param file destination; it is created or overwritten
     * @return {@code true} if everything was written, {@code false} on any error (the stack trace
     *     is printed for exceptions)
     */
    public boolean save(File file) {
        try (FileOutputStream stream = new FileOutputStream(file);
                PrintStream out = new PrintStream(stream, false, "UTF-8")) {
            String tag = getClass().getSimpleName();
            int attributes = this.metadata != null ? this.metadata.getAttributeCount() : this.attributeCount;
            out.println("<?xml version=\"1.0\" encoding=\"utf-8\"?>");
            out.println("<" + tag + " outputIndex=\"" + this.outputIndex + "\" totalCount=\"" + this.totalCount + "\" attributes=\"" + attributes + "\">");
            save(out, "\t");
            out.println("</" + tag + ">");
            // PrintStream never throws on write failure; it only records it (checkError also flushes).
            return !out.checkError();
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /** Writes one {@code <class>} element per class, each with one {@code <attribute>} per slot. */
    private void save(PrintStream out, String indent) {
        for (Map.Entry<Attribute, Object[]> entry : this.params.entrySet()) {
            Attribute classValue = entry.getKey();
            out.println(indent + "<class>");
            out.println(indent + "\t<params type=\"" + classValue.getClass().getName() + "\" name=\"" + escapeAttribute(String.valueOf(classValue.getName())) + "\" ><![CDATA[" + cdata(String.valueOf(classValue.getValue())) + "]]></params>");
            Object[] slots = entry.getValue();
            for (int i = 0; i < slots.length; i++) {
                Object slot = slots[i];
                out.println(indent + "\t<attribute index=\"" + i + "\">");
                if (slot instanceof Integer) {
                    out.println(indent + "\t\t<count>" + slot + "</count>");
                } else if (slot instanceof Pair) {
                    Pair stats = (Pair) slot;
                    // "sd" actually holds the variance; the name is kept for file-format compatibility.
                    out.println(indent + "\t\t<stats mean=\"" + stats.first + "\" sd=\"" + stats.second + "\"/>");
                } else {
                    for (Map.Entry<String, Integer> item : valueCounts(slot).entrySet()) {
                        out.println(indent + "\t\t<item count=\"" + item.getValue() + "\"><![CDATA[" + cdata(item.getKey()) + "]]></item>");
                    }
                }
                out.println(indent + "\t</attribute>");
            }
            out.println(indent + "</class>");
        }
    }

    /** Escapes text for use inside a double-quoted XML attribute value. */
    private static String escapeAttribute(String text) {
        return text.replace("&", "&amp;").replace("<", "&lt;").replace("\"", "&quot;");
    }

    /** Splits any "]]>" so the text can sit inside a CDATA section without terminating it. */
    private static String cdata(String text) {
        return text.replace("]]>", "]]]]><![CDATA[>");
    }

    /** Restores this instance's state from the root element written by {@link #save(File)}. */
    private NaiveBayes load(Node node) {
        this.outputIndex = Integer.parseInt(attributeText(node, "outputIndex"));
        this.totalCount = Integer.parseInt(attributeText(node, "totalCount"));
        this.attributeCount = Integer.parseInt(attributeText(node, "attributes"));
        this.params = new HashMap<>();
        NodeList classNodes = node.getChildNodes();
        for (int i = 0; i < classNodes.getLength(); i++) {
            Node classNode = classNodes.item(i);
            if (!"class".equals(classNode.getNodeName())) {
                continue; // whitespace text nodes and anything else
            }
            Attribute classValue = null;
            NodeList children = classNode.getChildNodes();
            for (int j = 0; j < children.getLength(); j++) {
                Node child = children.item(j);
                if ("params".equals(child.getNodeName())) {
                    classValue = Attribute.load(child);
                } else if ("attribute".equals(child.getNodeName())) {
                    int index = Integer.parseInt(attributeText(child, "index"));
                    Object[] slots = this.params.get(classValue);
                    if (slots == null) {
                        slots = new Object[this.attributeCount];
                        this.params.put(classValue, slots);
                    }
                    slots[index] = getParams(child);
                }
            }
        }
        return this;
    }

    /** Decodes one {@code <attribute>} element into an Integer count, a Pair, or a value-count map. */
    private Object getParams(Node node) {
        HashMap<String, Integer> counts = new HashMap<>();
        NodeList children = node.getChildNodes();
        for (int i = 0; i < children.getLength(); i++) {
            Node child = children.item(i);
            String name = child.getNodeName();
            if ("count".equals(name)) {
                return Integer.valueOf(child.getTextContent().trim());
            }
            if ("stats".equals(name)) {
                return new Pair(Double.valueOf(attributeText(child, "mean")), Double.valueOf(attributeText(child, "sd")));
            }
            if ("item".equals(name)) {
                counts.put(child.getTextContent(), Integer.valueOf(attributeText(child, "count")));
            }
        }
        return counts;
    }

    /** Returns the text of a required XML attribute, failing with a descriptive message if absent. */
    private static String attributeText(Node node, String name) {
        Node attribute = node.getAttributes() == null ? null : node.getAttributes().getNamedItem(name);
        if (attribute == null) {
            throw new IllegalArgumentException("missing XML attribute '" + name + "' on <" + node.getNodeName() + ">");
        }
        return attribute.getTextContent();
    }
}
