K-Means算法 (K-Means Algorithm)

  1. // Macro to implement kmeans for both f64 and f32 without writing everything
  2. // twice or importing the `num` crate
  3. macro_rules! impl_kmeans {
  4. ($kind: ty, $modname: ident) => {
  5. // Since we can't overload methods in rust, we have to use namespacing
  6. pub mod $modname {
  7. use std::$modname::INFINITY;
  8. /// computes sum of squared deviation between two identically sized vectors
  9. /// `x`, and `y`.
  10. fn distance(x: &[$kind], y: &[$kind]) -> $kind {
  11. x.iter()
  12. .zip(y.iter())
  13. .fold(0.0, |dist, (&xi, &yi)| dist + (xi - yi).powi(2))
  14. }
  15. /// Returns a vector containing the indices z<sub>i</sub> in {0, ..., K-1} of
  16. /// the centroid nearest to each datum.
  17. fn nearest_centroids(xs: &[Vec<$kind>], centroids: &[Vec<$kind>]) -> Vec<usize> {
  18. xs.iter()
  19. .map(|xi| {
  20. // Find the argmin by folding using a tuple containing the argmin
  21. // and the minimum distance.
  22. let (argmin, _) = centroids.iter().enumerate().fold(
  23. (0_usize, INFINITY),
  24. |(min_ix, min_dist), (ix, ci)| {
  25. let dist = distance(xi, ci);
  26. if dist < min_dist {
  27. (ix, dist)
  28. } else {
  29. (min_ix, min_dist)
  30. }
  31. },
  32. );
  33. argmin
  34. })
  35. .collect()
  36. }
  37. /// Recompute the centroids given the current clustering
  38. fn recompute_centroids(
  39. xs: &[Vec<$kind>],
  40. clustering: &[usize],
  41. k: usize,
  42. ) -> Vec<Vec<$kind>> {
  43. let ndims = xs[0].len();
  44. // NOTE: Kind of inefficient because we sweep all the data from each of the
  45. // k centroids.
  46. (0..k)
  47. .map(|cluster_ix| {
  48. let mut centroid: Vec<$kind> = vec![0.0; ndims];
  49. let mut n_cluster: $kind = 0.0;
  50. xs.iter().zip(clustering.iter()).for_each(|(xi, &zi)| {
  51. if zi == cluster_ix {
  52. n_cluster += 1.0;
  53. xi.iter().enumerate().for_each(|(j, &x_ij)| {
  54. centroid[j] += x_ij;
  55. });
  56. }
  57. });
  58. centroid.iter().map(|&c_j| c_j / n_cluster).collect()
  59. })
  60. .collect()
  61. }
  62. /// Assign the N D-dimensional data, `xs`, to `k` clusters using K-Means clustering
  63. pub fn kmeans(xs: Vec<Vec<$kind>>, k: usize) -> Vec<usize> {
  64. assert!(xs.len() >= k);
  65. // Rather than pulling in a dependency to radomly select the staring
  66. // points for the centroids, we're going to deterministally choose them by
  67. // slecting evenly spaced points in `xs`
  68. let n_per_cluster: usize = xs.len() / k;
  69. let centroids: Vec<Vec<$kind>> =
  70. (0..k).map(|j| xs[j * n_per_cluster].clone()).collect();
  71. let mut clustering = nearest_centroids(&xs, &centroids);
  72. loop {
  73. let centroids = recompute_centroids(&xs, &clustering, k);
  74. let new_clustering = nearest_centroids(&xs, &centroids);
  75. // loop until the clustering doesn't change after the new centroids are computed
  76. if new_clustering
  77. .iter()
  78. .zip(clustering.iter())
  79. .all(|(&za, &zb)| za == zb)
  80. {
  81. // We need to use `return` to break out of the `loop`
  82. return clustering;
  83. } else {
  84. clustering = new_clustering;
  85. }
  86. }
  87. }
  88. }
  89. };
  90. }
  91. // generate code for kmeans for f32 and f64 data
  92. impl_kmeans!(f64, f64);
  93. impl_kmeans!(f32, f32);
  94. #[cfg(test)]
  95. mod test {
  96. use self::super::f64::kmeans;
  97. #[test]
  98. fn easy_univariate_clustering() {
  99. let xs: Vec<Vec<f64>> = vec![
  100. vec![-1.1],
  101. vec![-1.2],
  102. vec![-1.3],
  103. vec![-1.4],
  104. vec![1.1],
  105. vec![1.2],
  106. vec![1.3],
  107. vec![1.4],
  108. ];
  109. let clustering = kmeans(xs, 2);
  110. assert_eq!(clustering, vec![0, 0, 0, 0, 1, 1, 1, 1]);
  111. }
  112. #[test]
  113. fn easy_univariate_clustering_odd_number_of_data() {
  114. let xs: Vec<Vec<f64>> = vec![
  115. vec![-1.1],
  116. vec![-1.2],
  117. vec![-1.3],
  118. vec![-1.4],
  119. vec![1.1],
  120. vec![1.2],
  121. vec![1.3],
  122. vec![1.4],
  123. vec![1.5],
  124. ];
  125. let clustering = kmeans(xs, 2);
  126. assert_eq!(clustering, vec![0, 0, 0, 0, 1, 1, 1, 1, 1]);
  127. }
  128. #[test]
  129. fn easy_bivariate_clustering() {
  130. let xs: Vec<Vec<f64>> = vec![
  131. vec![-1.1, 0.2],
  132. vec![-1.2, 0.3],
  133. vec![-1.3, 0.1],
  134. vec![-1.4, 0.4],
  135. vec![1.1, -1.1],
  136. vec![1.2, -1.0],
  137. vec![1.3, -1.2],
  138. vec![1.4, -1.3],
  139. ];
  140. let clustering = kmeans(xs, 2);
  141. assert_eq!(clustering, vec![0, 0, 0, 0, 1, 1, 1, 1]);
  142. }
  143. #[test]
  144. fn high_dims() {
  145. let xs: Vec<Vec<f64>> = vec![
  146. vec![-2.7825343, -1.7604825, -5.5550113, -2.9752946, -2.7874138],
  147. vec![-2.9847919, -3.8209332, -2.1531757, -2.2710119, -2.3582877],
  148. vec![-3.0109320, -2.2366132, -2.8048492, -1.2632331, -4.5755581],
  149. vec![-2.8432186, -1.0383805, -2.2022826, -2.7435962, -2.0013399],
  150. vec![-2.6638082, -3.5520086, -1.3684702, -2.1562444, -1.3186447],
  151. vec![1.7409171, 1.9687576, 4.7162628, 4.5743537, 3.7905611],
  152. vec![3.2932369, 2.8508700, 2.5580937, 2.0437325, 4.2192562],
  153. vec![2.5843321, 2.8329818, 2.1329531, 3.2562319, 2.4878733],
  154. vec![2.1859638, 3.2880048, 3.7018615, 2.3641232, 1.6281994],
  155. vec![2.6201773, 0.9006588, 2.6774097, 1.8188620, 1.6076493],
  156. ];
  157. let clustering = kmeans(xs, 2);
  158. assert_eq!(clustering, vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]);
  159. }
  160. }