//! Trie (prefix tree, "trie树"): a generic map from key sequences to values.

  1. use std::collections::HashMap;
  2. use std::hash::Hash;
  3. #[derive(Debug, Default)]
  4. struct Node<Key: Default, Type: Default> {
  5. children: HashMap<Key, Node<Key, Type>>,
  6. value: Option<Type>,
  7. }
  8. #[derive(Debug, Default)]
  9. pub struct Trie<Key, Type>
  10. where
  11. Key: Default + Eq + Hash,
  12. Type: Default,
  13. {
  14. root: Node<Key, Type>,
  15. }
  16. impl<Key, Type> Trie<Key, Type>
  17. where
  18. Key: Default + Eq + Hash,
  19. Type: Default,
  20. {
  21. pub fn new() -> Self {
  22. Self {
  23. root: Node::default(),
  24. }
  25. }
  26. pub fn insert(&mut self, key: impl IntoIterator<Item = Key>, value: Type)
  27. where
  28. Key: Eq + Hash,
  29. {
  30. let mut node = &mut self.root;
  31. for c in key.into_iter() {
  32. node = node.children.entry(c).or_insert_with(Node::default);
  33. }
  34. node.value = Some(value);
  35. }
  36. pub fn get(&self, key: impl IntoIterator<Item = Key>) -> Option<&Type>
  37. where
  38. Key: Eq + Hash,
  39. {
  40. let mut node = &self.root;
  41. for c in key.into_iter() {
  42. if node.children.contains_key(&c) {
  43. node = node.children.get(&c).unwrap()
  44. } else {
  45. return None;
  46. }
  47. }
  48. node.value.as_ref()
  49. }
  50. }
  51. #[cfg(test)]
  52. mod tests {
  53. use super::*;
  54. #[test]
  55. fn test_insertion() {
  56. let mut trie = Trie::new();
  57. assert_eq!(trie.get("".chars()), None);
  58. trie.insert("foo".chars(), 1);
  59. trie.insert("foobar".chars(), 2);
  60. let mut trie = Trie::new();
  61. assert_eq!(trie.get(vec![1, 2, 3]), None);
  62. trie.insert(vec![1, 2, 3], 1);
  63. trie.insert(vec![3, 4, 5], 2);
  64. }
  65. #[test]
  66. fn test_get() {
  67. let mut trie = Trie::new();
  68. trie.insert("foo".chars(), 1);
  69. trie.insert("foobar".chars(), 2);
  70. trie.insert("bar".chars(), 3);
  71. trie.insert("baz".chars(), 4);
  72. assert_eq!(trie.get("foo".chars()), Some(&1));
  73. assert_eq!(trie.get("food".chars()), None);
  74. let mut trie = Trie::new();
  75. trie.insert(vec![1, 2, 3, 4], 1);
  76. trie.insert(vec![42], 2);
  77. trie.insert(vec![42, 6, 1000], 3);
  78. trie.insert(vec![1, 2, 4, 16, 32], 4);
  79. assert_eq!(trie.get(vec![42, 6, 1000]), Some(&3));
  80. assert_eq!(trie.get(vec![43, 44, 45]), None);
  81. }
  82. }