- nearest neighbor algorithm -- greedy
1. 从起点 A 开始(起点不同,最终结果也可能不同)
2. 每次选择 cost 最小的尚未访问的点(例如 D),重复此步骤直到所有点都已访问
3. 最后回到 A,将各段 cost 加总
- knn in scala --intuition
-
/** @author wyq * @version 1.0 * @date Sun Sep 22 18:45:44 EDT 2013 */
package scalation.analytics import util.control.Breaks.{breakable, break} import collection.mutable.Set import scalation.linalgebra.{MatrixD, VectorD} import scalation.linalgebra_gen.VectorN import scalation.linalgebra_gen.Vectors.VectorI import scalation.math.DoubleWithExp._ import scalation.util.Error
//:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
/** The `KNN_Classifier` class classifies a point by majority vote among the
 *  `knn` stored points that are closest to it (k-nearest-neighbors).
 *  @param x    the vectors/points of classified data stored as rows of a matrix
 *  @param y    the classification of each vector in x
 *  @param fn   the names for all features/variables
 *  @param k    the number of classes
 *  @param cn   the names for all classes
 *  @param knn  the number of nearest neighbors to consider (must be positive)
 */
class KNN_Classifier (x: MatrixD, y: VectorI, fn: Array [String], k: Int, cn: Array [String],
                      knn: Int = 3)
      extends ClassifierReal (x, y, fn, k, cn)
{
    require (knn > 0, "knn must be at least 1")

    private val DEBUG      = true                        // debug flag
    private val MAX_DOUBLE = Double.PositiveInfinity     // infinity
    private val NO_POINT   = -1                          // row index marking an unfilled top-knn slot

    // top-knn nearest points as (row index, distance), kept in reverse order
    // (farthest at index 0, nearest at index knn-1); pre-filled so that it
    // never contains null entries, even before `train` is called
    private val topK: Array [(Int, Double)] = Array.fill (knn) ((NO_POINT, MAX_DOUBLE))

    private val count = new VectorI (k)                  // how many nearest neighbors fall in each class

    //:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Mark every top-knn slot as unfilled (no point, infinite distance).
     */
    private def resetNeighbors (): Unit =
    {
        for (i <- 0 until knn) topK(i) = (NO_POINT, MAX_DOUBLE)
    } // resetNeighbors

    //:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Zero the per-class vote counters.
     */
    private def resetCounts (): Unit =
    {
        for (j <- 0 until k) count(j) = 0
    } // resetCounts

    //:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Compute a distance metric between vectors/points u and v.  The value is
     *  `normSq` of the difference (presumably the squared Euclidean norm); it
     *  is only used to rank neighbors, so no square root is needed.
     *  @param u  the first vector/point
     *  @param v  the second vector/point
     */
    def distance (u: VectorD, v: VectorD): Double = (u - v).normSq

    //:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Find the knn nearest neighbors (top-knn) to vector `z` and leave them in
     *  `topK`.  Any neighbors left over from a previous call are discarded
     *  first, so the result depends on `z` only.
     *  @param z  the vector to be classified
     */
    def kNearest (z: VectorD): Unit =
    {
        resetNeighbors ()
        var dk = MAX_DOUBLE                              // distance of the current farthest kept neighbor
        for (i <- 0 until x.dim1) {
            val di = distance (z, x(i))                  // compute distance to z
            if (di < dk) dk = replace (i, di)            // if closer, adjust top-knn
        } // for
        if (DEBUG) println ("topK = " + topK.mkString ("Array(", ", ", ")"))
    } // kNearest

    //:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Remove the most distant neighbor and add new neighbor `i`.  Maintain the
     *  `topK` nearest neighbors in sorted order, farthest to nearest.
     *  Intended to be called only when `di` is smaller than the current
     *  farthest distance `topK(0)._2`.
     *  @param i   the row index of the new neighbor
     *  @param di  the distance of the new neighbor
     *  @return the distance of the new farthest neighbor
     */
    def replace (i: Int, di: Double): Double =
    {
        var j = 0
        // Slot 0 (the farthest) is dropped.  Slide every remaining neighbor that
        // is still farther than the new one a slot to the left; the comparison
        // must look at the element about to be moved (j+1), not the evicted one.
        while (j < knn-1 && di < topK(j+1)._2) { topK(j) = topK(j+1); j += 1 }
        topK(j) = (i, di)
        topK(0)._2                                       // the distance of the new farthest neighbor
    } // replace

    //:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Training involves resetting the data structures.  KNN uses lazy
     *  training, so the real work is done during classification.  `classify`
     *  and `kNearest` reset the state they use themselves, so calling `train`
     *  before each classification is allowed but not required.
     */
    def train (): Unit =
    {
        resetNeighbors ()                                // initialize top-knn
        resetCounts ()                                   // initialize counters
    } // train

    //:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
    /** Given a new point/vector `z`, determine which class it belongs to (i.e.,
     *  the class getting the most votes from its `knn` nearest neighbors).
     *  @param z  the vector to classify
     *  @return the best class and its name
     */
    def classify (z: VectorD): Tuple2 [Int, String] =
    {
        resetCounts ()                                   // drop votes from any earlier call
        kNearest (z)                                     // set top-knn to knn nearest
        for (i <- 0 until knn) {
            val row = topK(i)._1
            // slots stay unfilled when x has fewer than knn rows; they cast no vote
            if (row != NO_POINT) count(y(row)) += 1      // tally per class
        } // for
        if (DEBUG) println ("count = " + count)
        val best = count.argmax ()                       // class with maximal count
        (best, cn(best))                                 // return the best class and its name
    } // classify

} // KNN_Classifier class


//:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
/** The `KNN_ClassifierTest` object is used to test the `KNN_Classifier` class.
 *  It builds a six-point, two-class data set and classifies two new points.
 */
object KNN_ClassifierTest extends App
{
    val x = new MatrixD ((6, 2), 1.0, 2.0,               // data/feature matrix
                                 2.0, 1.0,
                                 5.0, 4.0,
                                 4.0, 5.0,
                                 9.0, 8.0,
                                 8.0, 9.0)
    val y  = VectorN (0, 0, 0, 1, 1, 1)                  // classification for each vector in x
    val fn = Array ("x1", "x2")                          // feature/variable names
    val cn = Array ("No", "Yes")                         // class names

    println ("----------------------------------------------------")
    println ("x = " + x)
    println ("y = " + y)

    val cl = new KNN_Classifier (x, y, fn, 2, cn)

    cl.train ()
    val z1 = VectorD (10.0, 10.0)
    println ("----------------------------------------------------")
    println ("z1 = " + z1)
    println ("class = " + cl.classify (z1))

    cl.train ()
    val z2 = VectorD ( 3.0, 3.0)
    println ("----------------------------------------------------")
    println ("z2 = " + z2)
    println ("class = " + cl.classify (z2))

} // KNN_ClassifierTest object
时间: 2024-11-06 09:31:43