跳转至

字典树⚓︎

字典树(\text{Trie}),又称前缀树或单词查找树,是一种用于高效存储和检索字符串集合的数据结构。它的主要特点是通过公共前缀来节省存储空间,并支持快速的字符串查找操作。

Trie
C++
struct Trie {
  Trie() = default;

  struct TrieNode {
    int pass = 0;  // 经过该节点的字符串数量
    int end  = 0;  // 以该节点结尾的字符串数量
    std::unordered_map<char, std::unique_ptr<TrieNode>> children;
  };

  std::unique_ptr<TrieNode> root = std::make_unique<TrieNode>();

  void Insert(const string &word) {
    TrieNode *node = root.get();
    node->pass++;
    for (char ch : word) {
      if (!node->children.contains(ch)) { node->children[ch] = std::make_unique<TrieNode>(); }
      node = node->children[ch].get();
      node->pass++;
    }
    node->end++;
  }

  int CountWordsEqualTo(const string &word) {
    TrieNode *node = root.get();
    for (char ch : word) {
      if (!node->children.contains(ch)) { return 0; }
      node = node->children[ch].get();
    }
    return node->end;
  }

  int CountWordsStartingWith(const string &prefix) {
    TrieNode *node = root.get();
    for (char ch : prefix) {
      if (!node->children.contains(ch)) { return 0; }
      node = node->children[ch].get();
    }
    return node->pass;
  }

  void Erase(const string &word) {
    if (CountWordsEqualTo(word) == 0) { return; }  // 单词不存在,无法删除
    TrieNode *node = root.get();
    node->pass--;
    for (char ch : word) {
      TrieNode *next_node = node->children[ch].get();
      next_node->pass--;
      // 如果经过该节点的字符串数量为0,说明该节点不再需要,删除该节点及其子树
      if (next_node->pass == 0) {
        node->children.erase(ch);
        return;
      }
      node = next_node;
    }
    node->end--;
  }
};
C++
struct Trie {
  const static int max_nodes = 100'000;  // 最大节点数
  // tree[i][j]表示节点i的第j个子节点
  inline static vector<array<int, 26>> tree = vector<array<int, 26>>(max_nodes);
  // 经过该节点的字符串数量
  inline static vector<int> pass = vector<int>(max_nodes);
  // 以该节点结尾的字符串数量
  inline static vector<int> end = vector<int>(max_nodes);

  int count                     = 1;  // 当前节点总数,根节点为1

public:
  Trie() = default;

  void Insert(const string &word) {
    int current = 1;
    pass[current]++;
    for (char ch : word) {
      int index = ch - 'a';
      if (tree[current][index] == 0) { tree[current][index] = ++count; }
      current = tree[current][index];
      pass[current]++;
    }
    end[current]++;
  }

  int CountWordsEqualTo(const string &word) {
    int current = 1;
    for (char ch : word) {
      int index = ch - 'a';
      if (tree[current][index] == 0) { return 0; }
      current = tree[current][index];
    }
    return end[current];
  }

  int CountWordsStartingWith(const string &prefix) {
    int current = 1;
    for (char ch : prefix) {
      int index = ch - 'a';
      if (tree[current][index] == 0) { return 0; }
      current = tree[current][index];
    }
    return pass[current];
  }

  void Erase(const string &word) {
    if (CountWordsEqualTo(word) == 0) { return; }  // 单词不存在,无法删除
    int current = 1;
    pass[current]--;
    for (char ch : word) {
      int index     = ch - 'a';
      int next_node = tree[current][index];
      pass[next_node]--;
      // 如果经过该节点的字符串数量为0,说明该节点不再需要,删除该节点及其子树
      if (pass[next_node] == 0) {
        tree[current][index] = 0;
        return;
      }
      current = next_node;
    }
    end[current]--;
  }

  // 重置Trie树, 每次调用后相当于新建一个Trie树
  void Clear() {
    count = 1;
    std::fill(pass.begin(), pass.end(), 0);
    std::fill(end.begin(), end.end(), 0);
    for (auto &child : tree) { std::fill(child.begin(), child.end(), 0); }
  }
};

处理数字的技巧

将数字转换成字符,然后每个数字结尾加一个特殊字符,例如 '\#' 表示一整个数字的结束。例如数字 -123,转换成字符串 '-123\#',这样就不用增加 tree 第二维的大小,使用 12 个字符('0'-'9', '-', '\#')就能表示所有数字。

数组中两个数的最大异或值

\text{0-1}字典树

C++
    #include <vector>
    using namespace std;

    class Solution {
      struct Trie {
        Trie() = default;

        struct TrieNode {
          TrieNode *left  = nullptr;  // 0
          TrieNode *right = nullptr;  // 1
        };

        const int L    = 30;  // 31位整数,最高位符号位不考虑
        TrieNode *root = new TrieNode();

        void Insert(int num) {
          TrieNode *node = root;
          for (int i = L; i >= 0; i--) {
            int bit = (num >> i) & 1;
            if (bit == 0) {
              if (!node->left) { node->left = new TrieNode(); }
              node = node->left;
            } else {
              if (!node->right) { node->right = new TrieNode(); }
              node = node->right;
            }
          }
        }

        int GetMaxXor(int num) {
          TrieNode *node = root;
          int maxXor     = 0;
          for (int i = L; i >= 0; i--) {
            int bit = (num >> i) & 1;
            if (bit == 0) {
              if (node->right) {  // 有1
                node    = node->right;
                maxXor |= (1 << i);
              } else {  // 没有1
                node = node->left;
              }
            } else {             // bit == 1
              if (node->left) {  // 有0
                node    = node->left;
                maxXor |= (1 << i);
              } else {  // 没有0
                node = node->right;
              }
            }
          }
          return maxXor;
        }
      };

     public:
      int findMaximumXOR(vector<int> &nums) {
        Trie trie;
        for (int num : nums) { trie.Insert(num); }

        int maxXor = 0;
        for (int num : nums) { maxXor = max(maxXor, trie.GetMaxXor(num)); }
        return maxXor;
      }
    };

评论