堆(Heap)

  // Heap data structure
  // Takes a comparator function (`fn(&T, &T) -> bool`; non-capturing closures coerce to it)
  // to allow for min-heap, max-heap, and ordering by custom key functions.
  3. use std::cmp::Ord;
  4. use std::default::Default;
  /// A binary heap whose ordering is decided by a comparator function.
  ///
  /// `comparator(a, b)` returns `true` when `a` should be yielded before `b`,
  /// so the same type serves as a min-heap (`a < b`), a max-heap (`a > b`),
  /// or a heap ordered by some key of `T`.
  ///
  /// No trait bound is placed on the struct itself; the impls that need
  /// `T: Default` (for the placeholder slot) declare it themselves.
  pub struct Heap<T> {
      /// Number of elements currently stored in the heap.
      count: usize,
      /// Implicit binary tree, used 1-indexed: `items[0]` is an unused
      /// placeholder and the elements live in `items[1..=count]`.
      items: Vec<T>,
      /// Returns `true` when the first argument must come out before the second.
      comparator: fn(&T, &T) -> bool,
  }
  13. impl<T> Heap<T>
  14. where
  15. T: Default,
  16. {
  17. pub fn new(comparator: fn(&T, &T) -> bool) -> Self {
  18. Self {
  19. count: 0,
  20. // Add a default in the first spot to offset indexes
  21. // for the parent/child math to work out.
  22. // Vecs have to have all the same type so using Default
  23. // is a way to add an unused item.
  24. items: vec![T::default()],
  25. comparator,
  26. }
  27. }
  28. pub fn len(&self) -> usize {
  29. self.count
  30. }
  31. pub fn is_empty(&self) -> bool {
  32. self.len() == 0
  33. }
  34. pub fn add(&mut self, value: T) {
  35. self.count += 1;
  36. self.items.push(value);
  37. // Heapify Up
  38. let mut idx = self.count;
  39. while self.parent_idx(idx) > 0 {
  40. let pdx = self.parent_idx(idx);
  41. if (self.comparator)(&self.items[idx], &self.items[pdx]) {
  42. self.items.swap(idx, pdx);
  43. }
  44. idx = pdx;
  45. }
  46. }
  47. fn parent_idx(&self, idx: usize) -> usize {
  48. idx / 2
  49. }
  50. fn children_present(&self, idx: usize) -> bool {
  51. self.left_child_idx(idx) <= self.count
  52. }
  53. fn left_child_idx(&self, idx: usize) -> usize {
  54. idx * 2
  55. }
  56. fn right_child_idx(&self, idx: usize) -> usize {
  57. self.left_child_idx(idx) + 1
  58. }
  59. fn smallest_child_idx(&self, idx: usize) -> usize {
  60. if self.right_child_idx(idx) > self.count {
  61. self.left_child_idx(idx)
  62. } else {
  63. let ldx = self.left_child_idx(idx);
  64. let rdx = self.right_child_idx(idx);
  65. if (self.comparator)(&self.items[ldx], &self.items[rdx]) {
  66. ldx
  67. } else {
  68. rdx
  69. }
  70. }
  71. }
  72. }
  73. impl<T> Heap<T>
  74. where
  75. T: Default + Ord,
  76. {
  77. /// Create a new MinHeap
  78. pub fn new_min() -> Self {
  79. Self::new(|a, b| a < b)
  80. }
  81. /// Create a new MaxHeap
  82. pub fn new_max() -> Self {
  83. Self::new(|a, b| a > b)
  84. }
  85. }
  86. impl<T> Iterator for Heap<T>
  87. where
  88. T: Default,
  89. {
  90. type Item = T;
  91. fn next(&mut self) -> Option<T> {
  92. if self.count == 0 {
  93. return None;
  94. }
  95. // This feels like a function built for heap impl :)
  96. // Removes an item at an index and fills in with the last item
  97. // of the Vec
  98. let next = Some(self.items.swap_remove(1));
  99. self.count -= 1;
  100. if self.count > 0 {
  101. // Heapify Down
  102. let mut idx = 1;
  103. while self.children_present(idx) {
  104. let cdx = self.smallest_child_idx(idx);
  105. if !(self.comparator)(&self.items[idx], &self.items[cdx]) {
  106. self.items.swap(idx, cdx);
  107. }
  108. idx = cdx;
  109. }
  110. }
  111. next
  112. }
  113. }
  114. pub struct MinHeap;
  115. impl MinHeap {
  116. #[allow(clippy::new_ret_no_self)]
  117. pub fn new<T>() -> Heap<T>
  118. where
  119. T: Default + Ord,
  120. {
  121. Heap::new(|a, b| a < b)
  122. }
  123. }
  124. pub struct MaxHeap;
  125. impl MaxHeap {
  126. #[allow(clippy::new_ret_no_self)]
  127. pub fn new<T>() -> Heap<T>
  128. where
  129. T: Default + Ord,
  130. {
  131. Heap::new(|a, b| a > b)
  132. }
  133. }
  #[cfg(test)]
  mod tests {
      use super::*;

      #[test]
      fn test_empty_heap() {
          let mut heap = MaxHeap::new::<i32>();
          assert_eq!(heap.next(), None);
      }

      #[test]
      fn test_min_heap() {
          let mut heap = MinHeap::new();
          heap.add(4);
          heap.add(2);
          heap.add(9);
          heap.add(11);
          assert_eq!(heap.len(), 4);
          assert_eq!(heap.next(), Some(2));
          assert_eq!(heap.next(), Some(4));
          assert_eq!(heap.next(), Some(9));
          heap.add(1);
          assert_eq!(heap.next(), Some(1));
      }

      #[test]
      fn test_max_heap() {
          let mut heap = MaxHeap::new();
          heap.add(4);
          heap.add(2);
          heap.add(9);
          heap.add(11);
          assert_eq!(heap.len(), 4);
          assert_eq!(heap.next(), Some(11));
          assert_eq!(heap.next(), Some(9));
          assert_eq!(heap.next(), Some(4));
          heap.add(1);
          assert_eq!(heap.next(), Some(2));
      }

      struct Point(/* x */ i32, /* y */ i32);
      impl Default for Point {
          fn default() -> Self {
              Self(0, 0)
          }
      }

      #[test]
      fn test_key_heap() {
          let mut heap: Heap<Point> = Heap::new(|a, b| a.0 < b.0);
          heap.add(Point(1, 5));
          heap.add(Point(3, 10));
          heap.add(Point(-2, 4));
          assert_eq!(heap.len(), 3);
          assert_eq!(heap.next().unwrap().0, -2);
          assert_eq!(heap.next().unwrap().0, 1);
          heap.add(Point(50, 34));
          assert_eq!(heap.next().unwrap().0, 3);
      }

      /// `len`/`is_empty` must track every add and removal, including the
      /// transition back to empty.
      #[test]
      fn test_len_and_is_empty() {
          let mut heap = MinHeap::new();
          assert!(heap.is_empty());
          assert_eq!(heap.len(), 0);
          heap.add(3);
          assert!(!heap.is_empty());
          assert_eq!(heap.len(), 1);
          assert_eq!(heap.next(), Some(3));
          assert!(heap.is_empty());
          assert_eq!(heap.next(), None);
          assert_eq!(heap.len(), 0);
      }

      /// Draining a heap built from unsorted input with duplicates must yield
      /// exactly the sorted input (heap sort), in both directions.
      #[test]
      fn test_drains_in_sorted_order_with_duplicates() {
          let input = [5, 1, 8, 3, 3, 9, 1, 0, 7, 5, 2, 6];
          let mut ascending = input.to_vec();
          ascending.sort_unstable();

          let mut min_heap = MinHeap::new();
          for &x in &input {
              min_heap.add(x);
          }
          assert_eq!(min_heap.len(), input.len());
          assert_eq!(min_heap.collect::<Vec<_>>(), ascending);

          let mut max_heap = MaxHeap::new();
          for &x in &input {
              max_heap.add(x);
          }
          let descending: Vec<_> = ascending.iter().rev().copied().collect();
          assert_eq!(max_heap.collect::<Vec<_>>(), descending);
      }
  }