KNN 的全称是K Nearest Neighbors, 意思是 K 个最近的邻居,
从这个名字我们就能看出一些 KNN 算法的蛛丝马迹了. K 个最近邻居, 毫无疑问,
K的取值肯定是至关重要的.那么最近的邻居又是怎么回事? 其实, KNN
的原理就是当预测一个新的值 x 的时候, 根据它距离最近的 K
个点是什么类别来判断 x 属于哪个类别.
1. 距离计算
常用的距离计算方式有两种: 一种是欧氏距离, 另一种是曼哈顿距离.
本文只使用了欧氏距离, 因此在这里对它作简要阐述.
这个式子其实就是之前求二维或者三维坐标系下两点距离的公式,
我们把这个公式推广到 n 维: $d(x, y) = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2}$.
/**
 * The distance measure to use; defaults to EUCLIDEAN.
 * NOTE(review): EUCLIDEAN (and MANHATTAN, used in distance()) are declared
 * elsewhere in this class — not visible in this chunk.
 */
public int distanceMeasure = EUCLIDEAN;

/** A shared random number generator, used for data shuffling. */
public static final Random random = new Random();

/** The number of neighbors (the k of KNN). */
int numNeighbors = 7;

/** The whole dataset, loaded from an ARFF file in the constructor. */
Instances dataset;

/** The training set, represented by indices into {@link #dataset}. */
int[] trainingSet;

/** The testing set, represented by indices into {@link #dataset}. */
int[] testingSet;

/** Predicted class labels, parallel to {@link #testingSet}. */
int[] predictions;
/** * ******************** * The first constructor. * * @param paraFilename The arff filename. * ******************** */ publicKnnClassification(String paraFilename) { try { FileReaderfileReader=newFileReader(paraFilename); dataset = newInstances(fileReader); // The last attribute is the decision class. dataset.setClassIndex(dataset.numAttributes() - 1); fileReader.close(); } catch (Exception ee) { System.out.println("Error occurred while trying to read ' " + paraFilename + " ' in KnnClassification constructor.\r\n" + ee); System.exit(0); } // Of try }// Of the first constructor
/**
 * Get random indices for data randomization.
 * Fix: the original swapped two random positions paraLength times, which is a
 * known biased shuffle; this uses the Fisher-Yates algorithm, which produces
 * every permutation with equal probability.
 *
 * @param paraLength The length of the sequence.
 * @return An array of indices, e.g., {4, 3, 1, 5, 0, 2} with length 6.
 */
public static int[] getRandomIndices(int paraLength) {
    int[] resultIndices = new int[paraLength];

    // Step 1. Initialize with the identity permutation.
    for (int i = 0; i < paraLength; i++) {
        resultIndices[i] = i;
    } // Of for i

    // Step 2. Fisher-Yates: swap each position with a random earlier one.
    for (int i = paraLength - 1; i > 0; i--) {
        int tempOther = random.nextInt(i + 1);
        int tempValue = resultIndices[i];
        resultIndices[i] = resultIndices[tempOther];
        resultIndices[tempOther] = tempValue;
    } // Of for i

    return resultIndices;
}// Of getRandomIndices
/** * ******************** * Split the data into training and testing parts. * * @param paraTrainingFraction The fraction of the training set. * ******************** */ publicvoidsplitTrainingTesting(double paraTrainingFraction) { inttempSize= dataset.numInstances(); int[] tempIndices = getRandomIndices(tempSize); inttempTrainingSize= (int) (tempSize * paraTrainingFraction);
if (tempSize - tempTrainingSize >= 0) { System.arraycopy(tempIndices, tempTrainingSize, testingSet, 0, tempSize - tempTrainingSize); } // Of for if
}// Of splitTrainingTesting
/** * ******************** * Predict for the whole testing set. The results are stored in predictions. * #see predictions. * ******************** */ publicvoidpredict() { predictions = newint[testingSet.length]; for (inti=0; i < predictions.length; i++) { predictions[i] = predict(testingSet[i]); } // Of for i }// Of predict
/** * ******************** * Predict for given instance. * * @return The prediction. * ******************** */ publicintpredict(int paraIndex) { int[] tempNeighbors = computeNearests(paraIndex); int resultPrediction; resultPrediction = simpleVoting(tempNeighbors);
return resultPrediction; }// Of predict
/** * ******************** * The distance between two instances. * * @param paraI The index of the first instance. * @param paraJ The index of the second instance. * @return The distance. * ******************** */ publicdoubledistance(int paraI, int paraJ) { doubleresultDistance=0; double tempDifference; switch (distanceMeasure) { case MANHATTAN: for (inti=0; i < dataset.numAttributes() - 1; i++) { tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i); if (tempDifference < 0) { resultDistance -= tempDifference; } else { resultDistance += tempDifference; } // Of if } // Of for i break;
case EUCLIDEAN: for (inti=0; i < dataset.numAttributes() - 1; i++) { tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i); resultDistance += tempDifference * tempDifference; } // Of for i break; default: System.out.println("Unsupported distance measure: " + distanceMeasure); }// Of switch
return resultDistance; }// Of distance
/** * ******************** * Get the accuracy of the classifier. * * @return The accuracy. * ******************** */ publicdoublegetAccuracy() { // A double divides an int gets another double. doubletempCorrect=0; for (inti=0; i < predictions.length; i++) { if (predictions[i] == dataset.instance(testingSet[i]).classValue()) { tempCorrect++; } // Of if } // Of for i
return tempCorrect / testingSet.length; }// Of getAccuracy
/**
 * Compute the nearest k neighbors. Selects one neighbor in each scan; a
 * single-scan implementation (e.g., with a bounded heap) is also possible.
 * Fixes: (1) removed the unconditional debug println, which spammed stdout
 * on every prediction; (2) if numNeighbors exceeds the training-set size,
 * the original reused a stale minimal index — we now stop early instead.
 *
 * @param paraCurrent Current instance index. We are comparing it with all others.
 * @return The indices of the nearest instances.
 */
public int[] computeNearests(int paraCurrent) {
    int[] resultNearests = new int[numNeighbors];
    boolean[] tempSelected = new boolean[trainingSet.length];

    // Compute all distances once to avoid redundant computation.
    double[] tempDistances = new double[trainingSet.length];
    for (int i = 0; i < trainingSet.length; i++) {
        tempDistances[i] = distance(paraCurrent, trainingSet[i]);
    }// Of for i

    // Select the nearest numNeighbors indices, one per scan.
    for (int i = 0; i < numNeighbors; i++) {
        double tempMinimalDistance = Double.MAX_VALUE;
        int tempMinimalIndex = -1;

        for (int j = 0; j < trainingSet.length; j++) {
            if (tempSelected[j]) {
                continue;
            } // Of if

            if (tempDistances[j] < tempMinimalDistance) {
                tempMinimalDistance = tempDistances[j];
                tempMinimalIndex = j;
            } // Of if
        } // Of for j

        if (tempMinimalIndex < 0) {
            // Fewer training instances than requested neighbors.
            break;
        } // Of if

        resultNearests[i] = trainingSet[tempMinimalIndex];
        tempSelected[tempMinimalIndex] = true;
    } // Of for i

    return resultNearests;
}// Of computeNearests
/** * *********************************** * Voting using the instances. * * @param paraNeighbors The indices of the neighbors. * @return The predicted label. * *********************************** */ publicintsimpleVoting(int[] paraNeighbors) { int[] tempVotes = newint[dataset.numClasses()]; for (int paraNeighbor : paraNeighbors) { tempVotes[(int) dataset.instance(paraNeighbor).classValue()]++; } // Of for i
inttempMaximalVotingIndex=0; inttempMaximalVoting=0; for (inti=0; i < dataset.numClasses(); i++) { if (tempVotes[i] > tempMaximalVoting) { tempMaximalVoting = tempVotes[i]; tempMaximalVotingIndex = i; } // Of if } // Of for i
return tempMaximalVotingIndex; }// Of simpleVoting
/** * ******************** * The entrance of the program. * * @param args Not used now. * ******************** */ publicstaticvoidmain(String[] args) { KnnClassificationtempClassifier=newKnnClassification("D:/Work/sampledata/iris.arff"); tempClassifier.splitTrainingTesting(0.8); tempClassifier.predict(); System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy()); }// Of main