/** * The ID3 decision tree inductive algorithm. * * @author Shi-Huai Wen Email: shihuaiwen@outlook.com. */ publicclassID3 { /** * The data. */ Instances dataset;
/** * Is this dataset pure (only one label)? */ boolean pure;
/** * The number of classes. For binary classification it is 2. */ int numClasses;
/** * Available instances. Other instances do not belong this branch. */ int[] availableInstances;
/** * Available attributes. Other attributes have been selected in the path * from the root. */ int[] availableAttributes;
/** * The selected attribute. */ int splitAttribute;
/** * The children nodes. */ ID3[] children;
/** * My label. Inner nodes also have a label. For example, <outlook = sunny, * humidity = high> never appear in the training data, but <humidity = high> * is valid in other cases. */ int label;
/** * Small block cannot be split further. */ staticintsmallBlockThreshold=3;
/** * ******************* * The constructor. * * @param paraFilename The given file. * ******************* */ publicID3(String paraFilename) { dataset = null; try { FileReaderfileReader=newFileReader(paraFilename); dataset = newInstances(fileReader); fileReader.close(); } catch (Exception ee) { System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee); System.exit(0); } // Of try
availableInstances = newint[dataset.numInstances()]; for (inti=0; i < availableInstances.length; i++) { availableInstances[i] = i; } // Of for i availableAttributes = newint[dataset.numAttributes() - 1]; for (inti=0; i < availableAttributes.length; i++) { availableAttributes[i] = i; } // Of for i
// Initialize. children = null; // Determine the label by simple voting. label = getMajorityClass(availableInstances); // Determine whether or not it is pure. pure = pureJudge(availableInstances); }// Of the first constructor
/** * ******************* * The constructor. * * @param paraDataset The given dataset. * ******************* */ publicID3(Instances paraDataset, int[] paraAvailableInstances, int[] paraAvailableAttributes) { // Copy its reference instead of clone the availableInstances. dataset = paraDataset; availableInstances = paraAvailableInstances; availableAttributes = paraAvailableAttributes;
// Initialize. children = null; // Determine the label by simple voting. label = getMajorityClass(availableInstances); // Determine whether or not it is pure. pure = pureJudge(availableInstances); }// Of the second constructor
/** * ********************************* * Is the given block pure? * * @param paraBlock The block. * @return True if pure. * ********************************* */ publicbooleanpureJudge(int[] paraBlock) { pure = true;
// Just compare with 0 for (inti=1; i < paraBlock.length; i++) { if (dataset.instance(paraBlock[i]).classValue() != dataset.instance(paraBlock[0]).classValue()) { pure = false; break; } // Of if } // Of for i
return pure; }// Of pureJudge
/** * ********************************* * Compute the majority class of the given block for voting. * * @param paraBlock The block. * @return The majority class. * ********************************* */ publicintgetMajorityClass(int[] paraBlock) { int[] tempClassCounts = newint[dataset.numClasses()]; for (int i : paraBlock) { tempClassCounts[(int) dataset.instance(i).classValue()]++; } // Of foreach
intresultMajorityClass= -1; inttempMaxCount= -1;
for (inti=0; i < tempClassCounts.length; i++) { if (tempMaxCount < tempClassCounts[i]) { resultMajorityClass = i; tempMaxCount = tempClassCounts[i]; } // Of if } // Of for i
return resultMajorityClass; }// Of getMajorityClass
/** * ********************************* * Select the best attribute. * * @return The best attribute index. * ********************************* */ publicintselectBestAttribute() { splitAttribute = -1; doubletempMinimalEntropy=10000; double tempEntropy; for (int availableAttribute : availableAttributes) { tempEntropy = conditionalEntropy(availableAttribute); if (tempMinimalEntropy > tempEntropy) { tempMinimalEntropy = tempEntropy; splitAttribute = availableAttribute; } // Of if } // Of foreach return splitAttribute; }// Of selectBestAttribute
/** * ********************************* * Compute the conditional entropy of an attribute. * * @param paraAttribute The given attribute. * @return The entropy. * ********************************* */ publicdoubleconditionalEntropy(int paraAttribute) { // Step 1. Statistics. inttempNumClasses= dataset.numClasses(); inttempNumValues= dataset.attribute(paraAttribute).numValues(); inttempNumInstances= availableInstances.length; double[] tempValueCounts = newdouble[tempNumValues]; double[][] tempCountMatrix = newdouble[tempNumValues][tempNumClasses];
int tempClass, tempValue; for (int availableInstance : availableInstances) { tempClass = (int) dataset.instance(availableInstance).classValue(); tempValue = (int) dataset.instance(availableInstance).value(paraAttribute); tempValueCounts[tempValue]++; tempCountMatrix[tempValue][tempClass]++; } // Of for i
// Step 2. doubleresultEntropy=0; double tempEntropy, tempFraction; for (inti=0; i < tempNumValues; i++) { if (tempValueCounts[i] == 0) { continue; } // Of if tempEntropy = 0; for (intj=0; j < tempNumClasses; j++) { tempFraction = tempCountMatrix[i][j] / tempValueCounts[i]; if (tempFraction == 0) { continue; } // Of if
// 信息熵越小, 信息的纯度越高, 信息量就越少 // H(X) = -p(x) * log p(x) tempEntropy += -tempFraction * Math.log(tempFraction); } // Of for j\ // 最小化条件信息熵 resultEntropy += tempValueCounts[i] / tempNumInstances * tempEntropy; } // Of for i
return resultEntropy; }// Of conditionalEntropy
/** * ********************************* * Split the data according to the given attribute. * * @return The blocks. * ********************************* */ publicint[][] splitData(int paraAttribute) { inttempNumValues= dataset.attribute(paraAttribute).numValues();
// First scan to count the size of each block. int tempValue; for (int availableInstance : availableInstances) { tempValue = (int) dataset.instance(availableInstance).value(paraAttribute); tempSizes[tempValue]++; } // Of for i
// Allocate space. for (inti=0; i < tempNumValues; i++) { resultBlocks[i] = newint[tempSizes[i]]; } // Of for i
// Second scan to fill. Arrays.fill(tempSizes, 0); for (int availableInstance : availableInstances) { tempValue = (int) dataset.instance(availableInstance).value(paraAttribute); // Copy data. resultBlocks[tempValue][tempSizes[tempValue]] = availableInstance; tempSizes[tempValue]++; } // Of for i
return resultBlocks; }// Of splitData
/** * ********************************* * Build the tree recursively. * ********************************* */ publicvoidbuildTree() { // Is pure return. if (pureJudge(availableInstances)) { return; } // Of if
// Less than or equal to small block just return if (availableInstances.length <= smallBlockThreshold) { return; } // Of if
selectBestAttribute(); int[][] tempSubBlocks = splitData(splitAttribute); children = newID3[tempSubBlocks.length];
// Construct the remaining attribute set. int[] tempRemainingAttributes = newint[availableAttributes.length - 1]; for (inti=0; i < availableAttributes.length; i++) { if (availableAttributes[i] < splitAttribute) { tempRemainingAttributes[i] = availableAttributes[i]; } elseif (availableAttributes[i] > splitAttribute) { tempRemainingAttributes[i - 1] = availableAttributes[i]; } // Of if } // Of for i
// Construct children. for (inti=0; i < children.length; i++) { if ((tempSubBlocks[i] == null) || (tempSubBlocks[i].length == 0)) { children[i] = null; } else { // System.out.println("Building children #" + i + " with // instances " + Arrays.toString(tempSubBlocks[i])); children[i] = newID3(dataset, tempSubBlocks[i], tempRemainingAttributes);
// Important code: do this recursively children[i].buildTree(); } // Of if } // Of for i }// Of buildTree
/** * ********************************* * Classify an instance. * * @param paraInstance The given instance. * @return The prediction. * ********************************* */ publicintclassify(Instance paraInstance) { if (children == null) { return label; } // Of if
ID3tempChild= children[(int) paraInstance.value(splitAttribute)]; if (tempChild == null) { return label; } // Of if
return tempChild.classify(paraInstance); }// Of classify
/** * ********************************* * Test on a testing set. * * @param paraDataset The given testing data. * @return The accuracy. * ********************************* */ publicdoubletest(Instances paraDataset) { doubletempCorrect=0; for (inti=0; i < paraDataset.numInstances(); i++) { if (classify(paraDataset.instance(i)) == (int) paraDataset.instance(i).classValue()) { tempCorrect++; } // Of i } // Of for i
return tempCorrect / paraDataset.numInstances(); }// Of test
/** * ********************************* * Test on the training set. * * @return The accuracy. * ********************************* */ publicdoubleselfTest() { return test(dataset); }// Of selfTest
/** * ****************** * Overrides the method claimed in Object. * * @return The tree structure. * ****************** */ public String toString() { StringBuilderresultString=newStringBuilder(); StringtempAttributeName= dataset.attribute(splitAttribute).name(); if (children == null) { resultString.append("class = ").append(label); } else { for (inti=0; i < children.length; i++) { if (children[i] == null) { resultString.append(tempAttributeName) .append(" = ") .append(dataset.attribute(splitAttribute).value(i)) .append(":") .append("class = ") .append(label) .append("\r\n"); } else { resultString.append(tempAttributeName) .append(" = ") .append(dataset.attribute(splitAttribute).value(i)) .append(":") .append(children[i]) .append("\r\n"); } // Of if } // Of for i } // Of if
return resultString.toString(); }// Of toString
/** * ************************ * Test this class. * ************************ */ publicstaticvoidid3Test() { ID3tempID3=newID3("D:/Work/sampledata/weather.arff");
ID3.smallBlockThreshold = 3; tempID3.buildTree();
System.out.println("The tree is: \r\n" + tempID3);
doubletempAccuracy= tempID3.selfTest(); System.out.println("The accuracy is: " + tempAccuracy); }// Of id3Test
/** * ************************ * Test this class. * * @param args Not used now. * ************************ */ publicstaticvoidmain(String[] args) { id3Test(); }// Of main } // Of class ID3