@Override public String classify(User user, Sample sample) { Instances trainingSet = new TrainingSetBuilder() .setAttributes(user.getBssids()) .setClassAttribute( "Location", user.getLocations().stream().map(Location::getName).collect(Collectors.toList())) .build("TrainingSet", 1); // Create instance Map<String, Integer> BSSIDLevelMap = getBSSIDLevelMap(sample); Instance instance = new Instance(trainingSet.numAttributes()); for (Enumeration e = trainingSet.enumerateAttributes(); e.hasMoreElements(); ) { Attribute attribute = (Attribute) e.nextElement(); String bssid = attribute.name(); int level = (BSSIDLevelMap.containsKey(bssid)) ? BSSIDLevelMap.get(bssid) : 0; instance.setValue(attribute, level); } if (sample.getLocation() != null) instance.setValue(trainingSet.classAttribute(), sample.getLocation()); instance.setDataset(trainingSet); trainingSet.add(instance); int predictedClass = classify(fromBase64(user.getClassifiers()), instance); return trainingSet.classAttribute().value(predictedClass); }
@Override public List<Classifier> buildClassifiers(User user, List<Sample> validSamples) { Instances trainingSet = new TrainingSetBuilder() .setAttributes(user.getBssids()) .setClassAttribute( "Location", user.getLocations().stream().map(Location::getName).collect(Collectors.toList())) .build("TrainingSet", validSamples.size()); // Create instances validSamples.forEach( sample -> { Map<String, Integer> BSSIDLevelMap = getBSSIDLevelMap(sample); Instance instance = new Instance(trainingSet.numAttributes()); for (Enumeration e = trainingSet.enumerateAttributes(); e.hasMoreElements(); ) { Attribute attribute = (Attribute) e.nextElement(); String bssid = attribute.name(); int level = (BSSIDLevelMap.containsKey(bssid)) ? BSSIDLevelMap.get(bssid) : 0; instance.setValue(attribute, level); } instance.setValue(trainingSet.classAttribute(), sample.getLocation()); instance.setDataset(trainingSet); trainingSet.add(instance); }); // Build classifiers List<Classifier> classifiers = buildClassifiers(trainingSet); return classifiers; }
/**
 * Counts every condition that could appear in a rule over the given data.
 *
 * <p>Each enumerated attribute contributes to the total: a nominal attribute adds its number of
 * values, and any other attribute adds twice its number of distinct values in the data, i.e.
 * {@code <=} and {@code >=} are counted as different possible conditions.
 *
 * @param data the given data
 * @return number of all conditions of the data
 */
public static double numAllConditions(Instances data) {
  double conditionCount = 0;
  for (Enumeration attributes = data.enumerateAttributes(); attributes.hasMoreElements(); ) {
    Attribute attribute = (Attribute) attributes.nextElement();
    conditionCount +=
        attribute.isNominal() ? attribute.numValues() : 2.0 * data.numDistinctValues(attribute);
  }
  return conditionCount;
}
/**
 * Returns the list of available attributes.
 *
 * @return the names of the attributes enumerated from the underlying data, in enumeration order
 */
public List<String> getAttributeNames() {
  List<String> names = new ArrayList<String>();
  for (Enumeration attributes = data.enumerateAttributes(); attributes.hasMoreElements(); ) {
    Attribute attribute = (Attribute) attributes.nextElement();
    names.add(attribute.name());
  }
  return names;
}
/**
 * Returns a description of the classifier.
 *
 * <p>For every class the output lists the stored prior, followed by one section per enumerated
 * attribute: the values and their stored counts for nominal attributes, or the stored mean and
 * standard deviation otherwise.
 *
 * @return a description of the classifier as a string; a fixed placeholder message if no model
 *     has been built yet or if printing fails
 */
@Override
public String toString() {
  if (m_Instances == null) {
    return "Naive Bayes (simple): No model built yet.";
  }

  try {
    StringBuilder out = new StringBuilder("Naive Bayes (simple)");
    for (int cls = 0; cls < m_Instances.numClasses(); cls++) {
      out.append("\n\nClass ")
          .append(m_Instances.classAttribute().value(cls))
          .append(": P(C) = ")
          .append(Utils.doubleToString(m_Priors[cls], 10, 8))
          .append("\n\n");

      // attIndex follows the enumeration order, which is how the per-attribute arrays are indexed.
      Enumeration<Attribute> attributes = m_Instances.enumerateAttributes();
      for (int attIndex = 0; attributes.hasMoreElements(); attIndex++) {
        Attribute attribute = attributes.nextElement();
        out.append("Attribute ").append(attribute.name()).append("\n");

        if (attribute.isNominal()) {
          // Header row of value labels, then a row with the matching stored counts.
          for (int v = 0; v < attribute.numValues(); v++) {
            out.append(attribute.value(v)).append("\t");
          }
          out.append("\n");
          for (int v = 0; v < attribute.numValues(); v++) {
            out.append(Utils.doubleToString(m_Counts[cls][attIndex][v], 10, 8)).append("\t");
          }
        } else {
          out.append("Mean: ")
              .append(Utils.doubleToString(m_Means[cls][attIndex], 10, 8))
              .append("\t");
          out.append("Standard Deviation: ")
              .append(Utils.doubleToString(m_Devs[cls][attIndex], 10, 8));
        }
        out.append("\n\n");
      }
    }
    return out.toString();
  } catch (Exception e) {
    // Best effort: any failure while formatting yields a fixed message instead of propagating.
    return "Can't print Naive Bayes classifier!";
  }
}
/** * Method for building an Id3 tree. * * @param data the training data * @exception Exception if decision tree can't be built successfully */ private void makeTree(Instances data) throws Exception { // Check if no instances have reached this node. if (data.numInstances() == 0) { m_Attribute = null; m_ClassValue = Utils.missingValue(); m_Distribution = new double[data.numClasses()]; return; } // Compute attribute with maximum information gain. double[] infoGains = new double[data.numAttributes()]; Enumeration attEnum = data.enumerateAttributes(); while (attEnum.hasMoreElements()) { Attribute att = (Attribute) attEnum.nextElement(); infoGains[att.index()] = computeInfoGain(data, att); } m_Attribute = data.attribute(Utils.maxIndex(infoGains)); // Make leaf if information gain is zero. // Otherwise create successors. if (Utils.eq(infoGains[m_Attribute.index()], 0)) { m_Attribute = null; m_Distribution = new double[data.numClasses()]; Enumeration instEnum = data.enumerateInstances(); while (instEnum.hasMoreElements()) { Instance inst = (Instance) instEnum.nextElement(); m_Distribution[(int) inst.classValue()]++; } Utils.normalize(m_Distribution); m_ClassValue = Utils.maxIndex(m_Distribution); m_ClassAttribute = data.classAttribute(); } else { Instances[] splitData = splitData(data, m_Attribute); m_Successors = new Id3[m_Attribute.numValues()]; for (int j = 0; j < m_Attribute.numValues(); j++) { m_Successors[j] = new Id3(); m_Successors[j].makeTree(splitData[j]); } } }
/** * Generates the classifier. * * @param instances set of instances serving as training data * @exception Exception if the classifier has not been generated successfully */ @Override public void buildClassifier(Instances instances) throws Exception { int attIndex = 0; double sum; // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class instances = new Instances(instances); instances.deleteWithMissingClass(); m_Instances = new Instances(instances, 0); // Reserve space m_Counts = new double[instances.numClasses()][instances.numAttributes() - 1][0]; m_Means = new double[instances.numClasses()][instances.numAttributes() - 1]; m_Devs = new double[instances.numClasses()][instances.numAttributes() - 1]; m_Priors = new double[instances.numClasses()]; Enumeration<Attribute> enu = instances.enumerateAttributes(); while (enu.hasMoreElements()) { Attribute attribute = enu.nextElement(); if (attribute.isNominal()) { for (int j = 0; j < instances.numClasses(); j++) { m_Counts[j][attIndex] = new double[attribute.numValues()]; } } else { for (int j = 0; j < instances.numClasses(); j++) { m_Counts[j][attIndex] = new double[1]; } } attIndex++; } // Compute counts and sums Enumeration<Instance> enumInsts = instances.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = enumInsts.nextElement(); if (!instance.classIsMissing()) { Enumeration<Attribute> enumAtts = instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (!instance.isMissing(attribute)) { if (attribute.isNominal()) { m_Counts[(int) instance.classValue()][attIndex][(int) instance.value(attribute)]++; } else { m_Means[(int) instance.classValue()][attIndex] += instance.value(attribute); m_Counts[(int) instance.classValue()][attIndex][0]++; } } attIndex++; } m_Priors[(int) instance.classValue()]++; } } // Compute means Enumeration<Attribute> enumAtts = 
instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (attribute.isNumeric()) { for (int j = 0; j < instances.numClasses(); j++) { if (m_Counts[j][attIndex][0] < 2) { throw new Exception( "attribute " + attribute.name() + ": less than two values for class " + instances.classAttribute().value(j)); } m_Means[j][attIndex] /= m_Counts[j][attIndex][0]; } } attIndex++; } // Compute standard deviations enumInsts = instances.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = enumInsts.nextElement(); if (!instance.classIsMissing()) { enumAtts = instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (!instance.isMissing(attribute)) { if (attribute.isNumeric()) { m_Devs[(int) instance.classValue()][attIndex] += (m_Means[(int) instance.classValue()][attIndex] - instance.value(attribute)) * (m_Means[(int) instance.classValue()][attIndex] - instance.value(attribute)); } } attIndex++; } } } enumAtts = instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (attribute.isNumeric()) { for (int j = 0; j < instances.numClasses(); j++) { if (m_Devs[j][attIndex] <= 0) { throw new Exception( "attribute " + attribute.name() + ": standard deviation is 0 for class " + instances.classAttribute().value(j)); } else { m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1; m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]); } } } attIndex++; } // Normalize counts enumAtts = instances.enumerateAttributes(); attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = enumAtts.nextElement(); if (attribute.isNominal()) { for (int j = 0; j < instances.numClasses(); j++) { sum = Utils.sum(m_Counts[j][attIndex]); for (int i = 0; i < attribute.numValues(); i++) { m_Counts[j][attIndex][i] = (m_Counts[j][attIndex][i] + 1) / (sum + 
attribute.numValues()); } } } attIndex++; } // Normalize priors sum = Utils.sum(m_Priors); for (int j = 0; j < instances.numClasses(); j++) { m_Priors[j] = (m_Priors[j] + 1) / (sum + instances.numClasses()); } }