knn in scala

  • nearest neighbor algorithm -- greedy


2. 选择 cost 最小的点 D，重复此步骤 (Step 2: choose the point D with the smallest cost, then repeat.)


  • knn in scala --intuition
  • /** @author  wyq
     *  @version 1.0
     *  @date    Sun Sep 22 18:45:44 EDT 2013
     */
    import util.control.Breaks.{breakable, break}
    import collection.mutable.Set
    import scalation.linalgebra.{MatrixD, VectorD}
    import scalation.linalgebra_gen.VectorN
    import scalation.linalgebra_gen.Vectors.VectorI
    import scalation.math.DoubleWithExp._
    import scalation.util.Error
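    /** The `KNN_Classifier` class classifies a new vector `z` into one of `k`
     *  classes by finding its `knn` nearest neighbors among the rows of `x` and
     *  letting them vote; the class receiving the most votes wins.
     *  @param x    the vectors/points of classified data stored as rows of a matrix (also can be in List[Array[Double]])
     *  @param y    the classification of each vector in x
     *  @param fn   the names for all features/variables
     *  @param k    the number of classes
     *  @param cn   the names for all classes
     *  @param knn  the number of nearest neighbors to consider
     */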
    class KNN_Classifier (x: MatrixD, y: VectorI, fn: Array [String], k: Int, cn: Array [String],
                          knn: Int = 3)
          extends ClassifierReal (x, y, fn, k, cn)
    {
        require (knn > 0, "KNN_Classifier: knn must be positive")

        private val DEBUG      = true                                        // debug flag
        private val MAX_DOUBLE = Double.PositiveInfinity                     // distance used to mark an empty slot
        private val topK       = Array.ofDim [Tuple2 [Int, Double]] (knn)    // top-knn nearest points, sorted farthest (slot 0) to nearest
        private val count      = new VectorI (k)                             // how many nearest neighbors in each class

        /** Compute a distance metric between vectors/points u and v.
         *  The squared norm of the difference is used: only the relative order of
         *  distances matters when picking neighbors, so no square root is taken.
         *  @param u  the first vector/point
         *  @param v  the second vector/point
         */
        def distance (u: VectorD, v: VectorD): Double = (u - v).normSq

        /** Find the knn nearest neighbors (top-knn) to vector `z` and store them in
         *  `topK`, sorted farthest to nearest.  `topK` is cleared first, so the
         *  result never depends on an earlier call.  When `x` has fewer than knn
         *  rows, the unused slots keep index -1.
         *  @param z  the vector to be classified
         */
        def kNearest (z: VectorD): Unit =
        {
            clearTopK ()
            var dk = MAX_DOUBLE                                // distance of the current farthest kept neighbor
            for (i <- 0 until x.dim1) {
                val di = distance (z, x(i))                    // compute distance to z
                if (di < dk) dk = replace (i, di)              // if closer, adjust top-knn
            } // for
            if (DEBUG) println ("topK = " + topK.deep)
        } // kNearest

        /** Remove the most distant neighbor and add new neighbor `i`, keeping the
         *  `topK` nearest neighbors in sorted order farthest to nearest.
         *  Precondition: `di < topK(0)._2` (the caller only inserts closer points).
         *  @param i   the row index of the new neighbor in x
         *  @param di  the distance of that neighbor to the vector being classified
         *  @return the distance of the new farthest neighbor (the new cutoff)
         */
        def replace (i: Int, di: Double): Double =
        {
            var j = 0
            // Slide every kept neighbor that is still farther than di one slot toward
            // the front (the old farthest, in slot 0, is overwritten).  The test must
            // look at the NEXT slot (j+1); testing slot j lets a nearer point end up
            // in front of a farther one and corrupts the returned cutoff.
            while (j < knn-1 && di < topK(j+1)._2) { topK(j) = topK(j+1); j += 1 }
            topK(j) = (i, di)
            topK(0)._2                                         // the distance of the new farthest neighbor
        } // replace

        /** Training involves resetting the data structures before each classification.
         *  KNN uses lazy training, so most of the work is done during classification.
         *  `classify` also resets its own state, so calling `train` first is optional.
         */
        def train (): Unit =
        {
            clearTopK ()                                       // initialize top-knn
            clearCount ()                                      // initialize counters
        } // train

        /** Given a new point/vector `z`, determine which class it belongs to (i.e.,
         *  the class getting the most votes from its `knn` nearest neighbors).
         *  @param z  the vector to classify
         *  @return the index of the best class and its name
         */
        def classify (z: VectorD): Tuple2 [Int, String] =
        {
            clearCount ()                                      // votes must not carry over from a previous call
            kNearest (z)                                       // set top-knn to knn nearest
            for (j <- 0 until knn) {
                val i = topK(j)._1
                if (i >= 0) count(y(i)) += 1                   // tally per class; skip slots never filled
            } // for
            if (DEBUG) println ("count = " + count)
            val best = count.argmax ()                         // class with maximal count
            (best, cn(best))                                   // return the best class and its name
        } // classify

        /** Mark every top-knn slot as empty: no row index (-1) at infinite distance. */
        private def clearTopK (): Unit = for (j <- 0 until knn) topK(j) = (-1, MAX_DOUBLE)

        /** Zero the per-class vote counters. */
        private def clearCount (): Unit = for (c <- 0 until k) count(c) = 0

    } // KNN_Classifier class
    /** The `KNN_ClassifierTest` object is used to test the `KNN_Classifier` class.
     *  Six labeled 2-D points (three per class) are used to classify two query
     *  points with the default knn = 3.
     */
    object KNN_ClassifierTest extends App
    {
        val x = new MatrixD ((6, 2), 1.0, 2.0,      // data/feature matrix
                                     2.0, 1.0,
                                     5.0, 4.0,
                                     4.0, 5.0,
                                     9.0, 8.0,
                                     8.0, 9.0)
        val y  = VectorN (0, 0, 0, 1, 1, 1)         // classification for each vector in x
        val fn = Array ("x1", "x2")                 // feature/variable names
        val cn = Array ("No", "Yes")                // class names

        println ("----------------------------------------------------")
        println ("x = " + x)
        println ("y = " + y)

        val cl = new KNN_Classifier (x, y, fn, 2, cn)

        cl.train ()
        val z1 = VectorD (10.0, 10.0)               // nearest rows: 4, 5 (class 1) and 2 (class 0)
        println ("----------------------------------------------------")
        println ("z1 = " + z1)
        println ("class = " + cl.classify (z1))     // expected: (1,Yes)

        cl.train ()
        val z2 = VectorD ( 3.0,  3.0)               // rows 0-3 tie at squared distance 5; the first three kept are class 0
        println ("----------------------------------------------------")
        println ("z2 = " + z2)
        println ("class = " + cl.classify (z2))     // expected: (0,No)

    } // KNN_ClassifierTest object
时间: 2024-11-06 09:31:43

