import weka.core.Instances;

import java.io.FileReader;
import java.util.Arrays;

/**
 * Weighted instances.
 *
 * @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
 */
public class WeightedInstances extends Instances {
    /**
     * Just the requirement of some classes, any number is ok.
     */
    private static final long serialVersionUID = 11087456L;
    /**
     * Weights.
     */
    private final double[] weights;
    /**
     * *****************
     * The first constructor.
     *
     * @param paraFileReader The given reader to read data from file.
     * *****************
     */
    public WeightedInstances(FileReader paraFileReader) throws Exception {
        super(paraFileReader);
        setClassIndex(numAttributes() - 1);

        // Initialize all weights to the same value 1 / numInstances().
        weights = new double[numInstances()];
        double tempAverage = 1.0 / numInstances();
        Arrays.fill(weights, tempAverage);
        System.out.println("Instances weights are: " + Arrays.toString(weights));
    } // Of the first constructor

    /**
     * *****************
     * The second constructor.
     *
     * @param paraInstances The given instances to copy.
     * *****************
     */
    public WeightedInstances(Instances paraInstances) {
        super(paraInstances);
        setClassIndex(numAttributes() - 1);

        // Initialize all weights to the same value 1 / numInstances().
        weights = new double[numInstances()];
        double tempAverage = 1.0 / numInstances();
        Arrays.fill(weights, tempAverage);
        System.out.println("Instances weights are: " + Arrays.toString(weights));
    } // Of the second constructor
    /**
     * *****************
     * Getter.
     *
     * @param paraIndex The given index.
     * @return The weight of the given index.
     * *****************
     */
    public double getWeight(int paraIndex) {
        return weights[paraIndex];
    } // Of getWeight
    /**
     * *****************
     * Adjust the weights.
     *
     * @param paraCorrectArray Indicate which instances have been correctly classified.
     * @param paraAlpha        The weight of the last classifier.
     * *****************
     */
    public void adjustWeights(boolean[] paraCorrectArray, double paraAlpha) {
        // Step 1. Calculate the multiplier e^alpha.
        double tempIncrease = Math.exp(paraAlpha);
        // Step 2. Adjust.
        double tempWeightsSum = 0; // For normalization.
        for (int i = 0; i < weights.length; i++) {
            if (paraCorrectArray[i]) {
                weights[i] /= tempIncrease;
            } else {
                weights[i] *= tempIncrease;
            } // Of if

            tempWeightsSum += weights[i];
        } // Of for i
        // Step 3. Normalize so that the weights sum to 1.
        for (int i = 0; i < weights.length; i++) {
            weights[i] /= tempWeightsSum;
        } // Of for i
    } // Of adjustWeights
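    // A small worked example (hypothetical numbers): with three instances of
    // weight 1/3 each, paraAlpha = ln 2 (so e^alpha = 2) and only the first
    // instance misclassified, Step 2 yields raw weights {2/3, 1/6, 1/6}; their
    // sum is already 1, so Step 3 leaves them unchanged. In general the
    // misclassified instances gain weight and the correct ones lose it.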
    /**
     * *****************
     * Test the method.
     * *****************
     */
    public void adjustWeightsTest() {
        boolean[] tempCorrectArray = new boolean[numInstances()];
        for (int i = 0; i < tempCorrectArray.length / 2; i++) {
            tempCorrectArray[i] = true;
        } // Of for i

        // Adjust with an arbitrary classifier weight, then display the result.
        adjustWeights(tempCorrectArray, 0.3);

        System.out.println("After adjusting:");
        System.out.println(toString());
    } // Of adjustWeightsTest
    /**
     * *****************
     * For display.
     * *****************
     */
    public String toString() {
        return "I am a weighted Instances object.\r\n"
                + "I have " + numInstances() + " instances and "
                + (numAttributes() - 1) + " conditional attributes.\r\n"
                + "My weights are: " + Arrays.toString(weights) + "\r\n"
                + "My data are: \r\n" + super.toString();
    } // Of toString
    /**
     * *****************
     * For unit test.
     *
     * @param args Not provided.
     * *****************
     */
    public static void main(String[] args) {
        WeightedInstances tempWeightedInstances = null;
        String tempFilename = "D:/Work/sampledata/iris.arff";
        try {
            FileReader tempFileReader = new FileReader(tempFilename);
            tempWeightedInstances = new WeightedInstances(tempFileReader);
            tempFileReader.close();
        } catch (Exception exception1) {
            System.out.println("Cannot read the file: " + tempFilename + "\r\n" + exception1);
            System.exit(0);
        } // Of try

        // Display the data and exercise the weight adjustment.
        System.out.println(tempWeightedInstances);
        tempWeightedInstances.adjustWeightsTest();
    } // Of main
} // Of class WeightedInstances
import weka.core.Instance;

import java.util.Random;

/**
 * The super class of any simple classifier.
 *
 * @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
 */
public abstract class SimpleClassifier {
    /**
     * The index of the current attribute.
     */
    int selectedAttribute;

    /**
     * The weighted training data.
     */
    WeightedInstances weightedInstances;

    /**
     * The numbers of classes, instances and conditional attributes.
     */
    int numClasses, numInstances, numConditions;

    /**
     * For random number generation.
     */
    Random random = new Random();

    /**
     * *****************
     * The only constructor.
     *
     * @param paraWeightedInstances The given instances.
     * *****************
     */
    public SimpleClassifier(WeightedInstances paraWeightedInstances) {
        weightedInstances = paraWeightedInstances;

        numConditions = weightedInstances.numAttributes() - 1;
        numInstances = weightedInstances.numInstances();
        numClasses = weightedInstances.classAttribute().numValues();
    } // Of the constructor

    /**
     * *****************
     * Train the classifier.
     * *****************
     */
    public abstract void train();
    /**
     * *****************
     * Classify an instance.
     *
     * @param paraInstance The given instance.
     * @return Predicted label.
     * *****************
     */
    public abstract int classify(Instance paraInstance);
    /**
     * *****************
     * Which instances in the training set are correctly classified.
     *
     * @return The correctness array.
     * *****************
     */
    public boolean[] computeCorrectnessArray() {
        boolean[] resultCorrectnessArray = new boolean[weightedInstances.numInstances()];
        for (int i = 0; i < resultCorrectnessArray.length; i++) {
            Instance tempInstance = weightedInstances.instance(i);
            if ((int) (tempInstance.classValue()) == classify(tempInstance)) {
                resultCorrectnessArray[i] = true;
            } // Of if

            // System.out.print("\t" + classify(tempInstance));
        } // Of for i
        // System.out.println();

        return resultCorrectnessArray;
    } // Of computeCorrectnessArray
    /**
     * *****************
     * Compute the accuracy on the training set.
     *
     * @return The training accuracy.
     * *****************
     */
    public double computeTrainingAccuracy() {
        double tempCorrect = 0;
        boolean[] tempCorrectnessArray = computeCorrectnessArray();
        for (boolean b : tempCorrectnessArray) {
            if (b) {
                tempCorrect++;
            } // Of if
        } // Of for b

        return tempCorrect / tempCorrectnessArray.length;
    } // Of computeTrainingAccuracy
    /**
     * *****************
     * Compute the weighted error on the training set. It is at least 1e-6 to
     * avoid NaN.
     *
     * @return The weighted error.
     * *****************
     */
    public double computeWeightedError() {
        double resultError = 0;
        boolean[] tempCorrectnessArray = computeCorrectnessArray();
        for (int i = 0; i < tempCorrectnessArray.length; i++) {
            if (!tempCorrectnessArray[i]) {
                resultError += weightedInstances.getWeight(i);
            } // Of if
        } // Of for i

        if (resultError < 1e-6) {
            resultError = 1e-6;
        } // Of if
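        // Note: the booster turns this error e into a classifier weight of the
        // form 0.5 * ln((1 - e) / e). Flooring e at 1e-6 keeps that logarithm
        // finite when a classifier happens to fit the weighted training set
        // perfectly.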
        return resultError;
    } // Of computeWeightedError
} // Of class SimpleClassifier
import weka.core.Instance;

import java.io.FileReader;
import java.util.Arrays;

/**
 * The stump classifier.
 *
 * @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
 */
public class StumpClassifier extends SimpleClassifier {
    /**
     * The best cut for the current attribute on weightedInstances.
     */
    double bestCut;
    /**
     * The class label for attribute value less than bestCut.
     */
    int leftLeafLabel;

    /**
     * The class label for attribute value no less than bestCut.
     */
    int rightLeafLabel;
    /**
     * *****************
     * The only constructor.
     *
     * @param paraWeightedInstances The given instances.
     * *****************
     */
    public StumpClassifier(WeightedInstances paraWeightedInstances) {
        super(paraWeightedInstances);
    } // Of the only constructor
    /**
     * *****************
     * Train the classifier.
     * *****************
     */
    public void train() {
        // Step 1. Randomly choose an attribute (a simple strategy for a stump).
        selectedAttribute = random.nextInt(numConditions);

        // Step 2. Find all attribute values and sort.
        double[] tempValuesArray = new double[numInstances];
        for (int i = 0; i < tempValuesArray.length; i++) {
            tempValuesArray[i] = weightedInstances.instance(i).value(selectedAttribute);
        } // Of for i
        Arrays.sort(tempValuesArray);
        // Step 3. Initialize: with the default cut, classify all instances to
        // the same (majority) label.
        int tempNumLabels = numClasses;
        double[] tempLabelCountArray = new double[tempNumLabels];
        int tempCurrentLabel;

        // Step 3.1 Scan all labels to obtain their counts.
        for (int i = 0; i < numInstances; i++) {
            // The label of the ith instance.
            tempCurrentLabel = (int) weightedInstances.instance(i).classValue();
            tempLabelCountArray[tempCurrentLabel] += weightedInstances.getWeight(i);
        } // Of for i
        // Step 3.2 Find the label with the maximal count.
        double tempMaxCorrect = 0;
        int tempBestLabel = -1;
        for (int i = 0; i < tempLabelCountArray.length; i++) {
            if (tempMaxCorrect < tempLabelCountArray[i]) {
                tempMaxCorrect = tempLabelCountArray[i];
                tempBestLabel = i;
            } // Of if
        } // Of for i

        // Step 3.3 The cut is a little smaller than the minimal value.
        bestCut = tempValuesArray[0] - 0.1;
        leftLeafLabel = tempBestLabel;
        rightLeafLabel = tempBestLabel;
        // Step 4. Check candidate cuts one by one.
        // Step 4.1 Prepare weighted label counts for the left and right leaves
        // (handles multi-class data).
        double tempCut;
        double[][] tempLabelCountMatrix = new double[2][tempNumLabels];

        for (int i = 0; i < tempValuesArray.length - 1; i++) {
            // Some attribute values are identical, ignore them.
            if (tempValuesArray[i] == tempValuesArray[i + 1]) {
                continue;
            } // Of if
            tempCut = (tempValuesArray[i] + tempValuesArray[i + 1]) / 2;

            // Step 4.2 Scan all labels to obtain their counts wrt. the cut.
            // Initialize again since it is used many times.
            for (int j = 0; j < 2; j++) {
                for (int k = 0; k < tempNumLabels; k++) {
                    tempLabelCountMatrix[j][k] = 0;
                } // Of for k
            } // Of for j
            for (int j = 0; j < numInstances; j++) {
                // The label of the jth instance.
                tempCurrentLabel = (int) weightedInstances.instance(j).classValue();
                if (weightedInstances.instance(j).value(selectedAttribute) < tempCut) {
                    tempLabelCountMatrix[0][tempCurrentLabel] += weightedInstances.getWeight(j);
                } else {
                    tempLabelCountMatrix[1][tempCurrentLabel] += weightedInstances.getWeight(j);
                } // Of if
            } // Of for j

            // Step 4.3 Left leaf.
            double tempLeftMaxCorrect = 0;
            int tempLeftBestLabel = 0;
            for (int j = 0; j < tempLabelCountMatrix[0].length; j++) {
                if (tempLeftMaxCorrect < tempLabelCountMatrix[0][j]) {
                    tempLeftMaxCorrect = tempLabelCountMatrix[0][j];
                    tempLeftBestLabel = j;
                } // Of if
            } // Of for j

            // Step 4.4 Right leaf.
            double tempRightMaxCorrect = 0;
            int tempRightBestLabel = 0;
            for (int j = 0; j < tempLabelCountMatrix[1].length; j++) {
                if (tempRightMaxCorrect < tempLabelCountMatrix[1][j]) {
                    tempRightMaxCorrect = tempLabelCountMatrix[1][j];
                    tempRightBestLabel = j;
                } // Of if
            } // Of for j
            // Step 4.5 Compare with the current best.
            if (tempMaxCorrect < tempLeftMaxCorrect + tempRightMaxCorrect) {
                tempMaxCorrect = tempLeftMaxCorrect + tempRightMaxCorrect;
                bestCut = tempCut;
                leftLeafLabel = tempLeftBestLabel;
                rightLeafLabel = tempRightBestLabel;
            } // Of if
        } // Of for i
    } // Of train
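    // A small worked example (hypothetical values): if the chosen attribute has
    // sorted values {4.3, 5.0, 5.8, 7.9}, the candidate cuts are 4.65, 5.4 and
    // 6.85. For each cut, Step 4 totals the weights of the majority label on
    // each side; the cut whose two leaves cover the largest total weight
    // replaces the default stump that assigns the global majority label to
    // every instance.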
    /**
     * *****************
     * Classify an instance.
     *
     * @param paraInstance The given instance.
     * @return Predicted label.
     * *****************
     */
    public int classify(Instance paraInstance) {
        int resultLabel = -1;
        if (paraInstance.value(selectedAttribute) < bestCut) {
            resultLabel = leftLeafLabel;
        } else {
            resultLabel = rightLeafLabel;
        } // Of if
        return resultLabel;
    } // Of classify
    /**
     * *****************
     * For display.
     * *****************
     */
    public String toString() {
        return "I am a stump classifier.\r\n"
                + "I choose attribute #" + selectedAttribute + " with cut value " + bestCut + ".\r\n"
                + "The left and right leaf labels are " + leftLeafLabel + " and " + rightLeafLabel
                + ", respectively.\r\n"
                + "My weighted error is: " + computeWeightedError() + ".\r\n"
                + "My training accuracy is: " + computeTrainingAccuracy() + ".";
    } // Of toString
    /**
     * *****************
     * For unit test.
     *
     * @param args Not provided.
     * *****************
     */
    public static void main(String[] args) {
        WeightedInstances tempWeightedInstances = null;
        String tempFilename = "D:/Work/sampledata/iris.arff";
        try {
            FileReader tempFileReader = new FileReader(tempFilename);
            tempWeightedInstances = new WeightedInstances(tempFileReader);
            tempFileReader.close();
        } catch (Exception ee) {
            System.out.println("Cannot read the file: " + tempFilename + "\r\n" + ee);
            System.exit(0);
        } // Of try

        // Train a stump on the data and display it together with its correctness array.
        StumpClassifier tempClassifier = new StumpClassifier(tempWeightedInstances);
        tempClassifier.train();
        System.out.println(tempClassifier);
        System.out.println(Arrays.toString(tempClassifier.computeCorrectnessArray()));
    } // Of main
} // Of class StumpClassifier
import weka.core.Instance;
import weka.core.Instances;

import java.io.FileReader;

/**
 * The booster which ensembles base classifiers.
 *
 * @author Shi-Huai Wen Email: shihuaiwen@outlook.com.
 */
public class Booster {
    /**
     * Classifiers.
     */
    SimpleClassifier[] classifiers;
    /**
     * Number of classifiers.
     */
    int numClassifiers;

    /**
     * Whether or not to stop after the training error reaches 0.
     */
    boolean stopAfterConverge = false;

    /**
     * The weights of classifiers.
     */
    double[] classifierWeights;

    /**
     * The training data.
     */
    Instances trainingData;

    /**
     * The testing data.
     */
    Instances testingData;
    /**
     * *****************
     * The first constructor. The testing set is the same as the training set.
     *
     * @param paraTrainingFilename The data filename.
     * *****************
     */
    public Booster(String paraTrainingFilename) {
        // Step 1. Read training set.
        try {
            FileReader tempFileReader = new FileReader(paraTrainingFilename);
            trainingData = new Instances(tempFileReader);
            tempFileReader.close();
        } catch (Exception ee) {
            System.out.println("Cannot read the file: " + paraTrainingFilename + "\r\n" + ee);
            System.exit(0);
        } // Of try
        // Step 2. Set the last attribute as the class index.
        trainingData.setClassIndex(trainingData.numAttributes() - 1);

        // Step 3. The testing data is the same as the training data.
        testingData = trainingData;

        stopAfterConverge = true;

        System.out.println("****************Data**********\r\n" + trainingData);
    } // Of the first constructor
    /**
     * *****************
     * Set the number of base classifiers, and allocate space for them.
     *
     * @param paraNumBaseClassifiers The number of base classifiers.
     * *****************
     */
    public void setNumBaseClassifiers(int paraNumBaseClassifiers) {
        numClassifiers = paraNumBaseClassifiers;

        // Step 1. Allocate space (only references) for the classifiers.
        classifiers = new SimpleClassifier[numClassifiers];

        // Step 2. Allocate space for the classifier weights.
        classifierWeights = new double[numClassifiers];
    } // Of setNumBaseClassifiers
    /**
     * *****************
     * Train the booster.
     * *****************
     */
    public void train() {
        // Step 1. Initialize.
        WeightedInstances tempWeightedInstances = null;
        double tempError;
        numClassifiers = 0;

        // Step 2. Build the classifiers one by one.
        for (int i = 0; i < classifiers.length; i++) {
            // Step 2.1 Key code: Construct or adjust the weightedInstances.
            if (i == 0) {
                tempWeightedInstances = new WeightedInstances(trainingData);
            } else {
                // Adjust the weights of the data.
                tempWeightedInstances.adjustWeights(classifiers[i - 1].computeCorrectnessArray(),
                        classifierWeights[i - 1]);
            } // Of if
            // Step 2.2 Train the next classifier.
            classifiers[i] = new StumpClassifier(tempWeightedInstances);
            classifiers[i].train();

            // Step 2.3 Set the classifier weight with the standard AdaBoost
            // formula alpha = 0.5 * ln((1 - e) / e), where e is the weighted error.
            tempError = classifiers[i].computeWeightedError();
            classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
            if (classifierWeights[i] < 1e-6) {
                classifierWeights[i] = 0;
            } // Of if

            System.out.println("Classifier #" + i + ", weighted error = " + tempError
                    + ", weight = " + classifierWeights[i] + "\r\n");

            numClassifiers++;
            // Step 2.4 Stop early if the training accuracy is high enough.
            if (stopAfterConverge) {
                double tempTrainingAccuracy = computeTrainingAccuray();
                System.out.println("The accuracy of the booster is: " + tempTrainingAccuracy + "\r\n");
                if (tempTrainingAccuracy > 0.999999) {
                    System.out.println("Stop at round " + i + " due to convergence.\r\n");
                    break;
                } // Of if
            } // Of if
        } // Of for i
    } // Of train
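    // A small worked example (hypothetical errors): a stump with weighted error
    // 0.1 receives weight 0.5 * ln(9) ≈ 1.10, one with error 0.4 receives
    // 0.5 * ln(1.5) ≈ 0.20, and one that is no better than chance (error 0.5)
    // receives weight 0, so it contributes nothing to the vote.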
    /**
     * *****************
     * Classify an instance.
     *
     * @param paraInstance The given instance.
     * @return The predicted label.
     * *****************
     */
    public int classify(Instance paraInstance) {
        double[] tempLabelsCountArray = new double[trainingData.classAttribute().numValues()];
        for (int i = 0; i < numClassifiers; i++) {
            int tempLabel = classifiers[i].classify(paraInstance);
            tempLabelsCountArray[tempLabel] += classifierWeights[i];
        } // Of for i

        int resultLabel = -1;
        double tempMax = -1;
        for (int i = 0; i < tempLabelsCountArray.length; i++) {
            if (tempMax < tempLabelsCountArray[i]) {
                tempMax = tempLabelsCountArray[i];
                resultLabel = i;
            } // Of if
        } // Of for i
        return resultLabel;
    } // Of classify
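    // A small worked example (hypothetical weights): with three base classifiers
    // of weights 1.2, 0.8 and 0.3 predicting labels 0, 1 and 1 respectively,
    // the vote totals are 1.2 for label 0 and 1.1 for label 1, so label 0 wins
    // even though a simple majority of the classifiers predicted label 1.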
    /**
     * *****************
     * Test the booster on the training data.
     *
     * @return The classification accuracy.
     * *****************
     */
    public double test() {
        System.out.println("Testing on " + testingData.numInstances() + " instances.\r\n");

        return test(testingData);
    } // Of test
    /**
     * *****************
     * Test the booster.
     *
     * @param paraInstances The testing set.
     * @return The classification accuracy.
     * *****************
     */
    public double test(Instances paraInstances) {
        double tempCorrect = 0;
        paraInstances.setClassIndex(paraInstances.numAttributes() - 1);

        for (int i = 0; i < paraInstances.numInstances(); i++) {
            Instance tempInstance = paraInstances.instance(i);
            if (classify(tempInstance) == (int) tempInstance.classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        double resultAccuracy = tempCorrect / paraInstances.numInstances();
        System.out.println("The accuracy is: " + resultAccuracy);

        return resultAccuracy;
    } // Of test
    /**
     * *****************
     * Compute the training accuracy of the booster. It is not weighted.
     *
     * @return The training accuracy.
     * *****************
     */
    public double computeTrainingAccuray() {
        double tempCorrect = 0;

        for (int i = 0; i < trainingData.numInstances(); i++) {
            if (classify(trainingData.instance(i)) == (int) trainingData.instance(i).classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        return tempCorrect / trainingData.numInstances();
    } // Of computeTrainingAccuray