// BST, AVL, RBTree, SegmentTree, FenwickTree

// Tree, tree, tree: a collection of tree-based data structures, one class template each.

#include <algorithm>
#include <climits>
#include <iostream>
#include <utility>
#include <vector>

using namespace std;

constexpr int INF = INT_MAX; // "Infinity" sentinel: successor() yields INF and predecessor() yields -INF when no such key exists

// ===================== Binary Search Tree (BST) =====================
template <typename T>
class BST {
private:
    // One tree node. Each node is owned by its parent (the root node by `root`).
    struct Node {
        T key;
        Node* left;
        Node* right;
        Node(T k) : key(k), left(nullptr), right(nullptr) {}
    };

    Node* root;

    // Inserts `key` into the subtree rooted at `node` and returns the (possibly new)
    // subtree root. Duplicate keys are ignored.
    Node* insert(Node* node, T key) {
        if (!node) return new Node(key);
        if (key < node->key) {
            node->left = insert(node->left, key);
        } else if (key > node->key) {
            node->right = insert(node->right, key);
        }
        return node;
    }

    // Removes `key` from the subtree rooted at `node` (no-op if absent) and returns the
    // new subtree root. A node with two children takes its in-order successor's key,
    // and that successor is then removed from the right subtree.
    Node* deleteNode(Node* node, T key) {
        if (!node) return nullptr;
        if (key < node->key) {
            node->left = deleteNode(node->left, key);
        } else if (key > node->key) {
            node->right = deleteNode(node->right, key);
        } else {
            if (!node->left) {
                Node* temp = node->right;
                delete node;
                return temp;
            } else if (!node->right) {
                Node* temp = node->left;
                delete node;
                return temp;
            }
            Node* temp = minValueNode(node->right);
            node->key = temp->key;
            node->right = deleteNode(node->right, temp->key);
        }
        return node;
    }

    // Leftmost (smallest-key) node of the subtree; nullptr for an empty subtree.
    Node* minValueNode(Node* node) {
        Node* current = node;
        while (current && current->left) {
            current = current->left;
        }
        return current;
    }

    // Largest key strictly less than `key`, or -INF if there is none.
    // Tracking the best candidate explicitly (instead of comparing a recursive result
    // against -INF) keeps a stored key equal to -INF from being mistaken for "not found".
    T predecessor(Node* node, T key) const {
        T best = -INF;
        while (node) {
            if (node->key < key) {
                best = node->key;   // candidate; any larger one lies to the right
                node = node->right;
            } else {
                node = node->left;
            }
        }
        return best;
    }

    // Smallest key strictly greater than `key`, or INF if there is none.
    T successor(Node* node, T key) const {
        T best = INF;
        while (node) {
            if (key < node->key) {
                best = node->key;   // candidate; any smaller one lies to the left
                node = node->left;
            } else {
                node = node->right;
            }
        }
        return best;
    }

    // Number of stored keys strictly less than `key` (keys are unique in this tree).
    int rank(Node* node, T key) const {
        if (!node) return 0;
        if (key < node->key) {
            return rank(node->left, key);
        } else if (key > node->key) {
            return 1 + size(node->left) + rank(node->right, key);
        } else {
            return size(node->left);
        }
    }

    // Key at 0-based in-order position `rank`, or -1 if `rank` is out of range.
    T value(Node* node, int rank) const {
        if (!node) return -1;
        int leftSize = size(node->left);
        if (rank < leftSize) {
            return value(node->left, rank);
        } else if (rank > leftSize) {
            return value(node->right, rank - leftSize - 1);
        } else {
            return node->key;
        }
    }

    // Node count of the subtree. NOTE: recomputed by a full traversal on every call,
    // so rank()/value() cost O(n) rather than O(height).
    int size(Node* node) const {
        if (!node) return 0;
        return 1 + size(node->left) + size(node->right);
    }

    // Frees every node of the subtree rooted at `node`.
    static void destroy(Node* node) {
        if (!node) return;
        destroy(node->left);
        destroy(node->right);
        delete node;
    }

    // Returns a freshly allocated deep copy of the subtree rooted at `node`.
    static Node* clone(const Node* node) {
        if (!node) return nullptr;
        Node* copy = new Node(node->key);
        copy->left = clone(node->left);
        copy->right = clone(node->right);
        return copy;
    }

public:
    BST() : root(nullptr) {}

    // Deep copy: the new tree owns its own nodes.
    BST(const BST& other) : root(clone(other.root)) {}

    // Takes over other's nodes; other is left empty.
    BST(BST&& other) noexcept : root(other.root) {
        other.root = nullptr;
    }

    // Copy-and-swap: covers copy and move assignment, and is self-assignment safe.
    BST& operator=(BST other) {
        std::swap(root, other.root);
        return *this;
    }

    // Releases all nodes (the original class leaked them).
    ~BST() {
        destroy(root);
    }

    void insert(T key) {
        root = insert(root, key);
    }

    void deleteNode(T key) {
        root = deleteNode(root, key);
    }

    // Largest stored key < key, or -INF.
    T predecessor(T key) const {
        return predecessor(root, key);
    }

    // Smallest stored key > key, or INF.
    T successor(T key) const {
        return successor(root, key);
    }

    // Count of stored keys < key.
    int rank(T key) const {
        return rank(root, key);
    }

    // Key at 0-based sorted position `rank`, or -1 when out of range.
    T value(int rank) const {
        return value(root, rank);
    }
};

// ===================== AVL Tree =====================
template <typename T>
class AVL {
private:
    // One tree node; `height` counts nodes on the longest downward path (a leaf has 1).
    struct Node {
        T key;
        Node* left;
        Node* right;
        int height;
        Node(T k) : key(k), left(nullptr), right(nullptr), height(1) {}
    };

    Node* root;

    // Height of a possibly empty subtree (0 for nullptr).
    int height(Node* node) {
        return node ? node->height : 0;
    }

    // height(left) - height(right); insert/deleteNode rebalance when it leaves [-1, 1].
    int balanceFactor(Node* node) {
        return node ? height(node->left) - height(node->right) : 0;
    }

    // Recomputes node->height from its children's stored heights.
    void updateHeight(Node* node) {
        if (node) {
            node->height = 1 + max(height(node->left), height(node->right));
        }
    }

    // Rotates the subtree rooted at y to the right; returns the new root (y's old left child).
    Node* rightRotate(Node* y) {
        Node* x = y->left;
        Node* T2 = x->right;
        x->right = y;
        y->left = T2;
        updateHeight(y);   // y is now below x, so update it first
        updateHeight(x);
        return x;
    }

    // Rotates the subtree rooted at x to the left; returns the new root (x's old right child).
    Node* leftRotate(Node* x) {
        Node* y = x->right;
        Node* T2 = y->left;
        y->left = x;
        x->right = T2;
        updateHeight(x);   // x is now below y, so update it first
        updateHeight(y);
        return y;
    }

    // Inserts `key` below `node`, rebalancing on the way back up; returns the new subtree root.
    Node* insert(Node* node, T key) {
        if (!node) return new Node(key);
        if (key < node->key) {
            node->left = insert(node->left, key);
        } else if (key > node->key) {
            node->right = insert(node->right, key);
        } else {
            return node; // Duplicate keys not allowed
        }

        updateHeight(node);
        int balance = balanceFactor(node);

        // Left Left Case
        if (balance > 1 && key < node->left->key) {
            return rightRotate(node);
        }
        // Right Right Case
        if (balance < -1 && key > node->right->key) {
            return leftRotate(node);
        }
        // Left Right Case
        if (balance > 1 && key > node->left->key) {
            node->left = leftRotate(node->left);
            return rightRotate(node);
        }
        // Right Left Case
        if (balance < -1 && key < node->right->key) {
            node->right = rightRotate(node->right);
            return leftRotate(node);
        }

        return node;
    }

    // Leftmost (smallest-key) node of the subtree; nullptr for an empty subtree.
    Node* minValueNode(Node* node) {
        Node* current = node;
        while (current && current->left) {
            current = current->left;
        }
        return current;
    }

    // Removes `key` below `node` (no-op if absent), rebalancing on the way back up;
    // returns the new subtree root.
    Node* deleteNode(Node* node, T key) {
        if (!node) return nullptr;
        if (key < node->key) {
            node->left = deleteNode(node->left, key);
        } else if (key > node->key) {
            node->right = deleteNode(node->right, key);
        } else {
            if (!node->left || !node->right) {
                Node* temp = node->left ? node->left : node->right;
                if (!temp) {
                    // Leaf: drop it.
                    temp = node;
                    node = nullptr;
                } else {
                    // One child: copy the child into this node, then free the child.
                    *node = *temp;
                }
                delete temp;
            } else {
                // Two children: take the in-order successor's key, then delete the successor.
                Node* temp = minValueNode(node->right);
                node->key = temp->key;
                node->right = deleteNode(node->right, temp->key);
            }
        }

        if (!node) return nullptr;

        updateHeight(node);
        int balance = balanceFactor(node);

        // Left Left Case
        if (balance > 1 && balanceFactor(node->left) >= 0) {
            return rightRotate(node);
        }
        // Left Right Case
        if (balance > 1 && balanceFactor(node->left) < 0) {
            node->left = leftRotate(node->left);
            return rightRotate(node);
        }
        // Right Right Case
        if (balance < -1 && balanceFactor(node->right) <= 0) {
            return leftRotate(node);
        }
        // Right Left Case
        if (balance < -1 && balanceFactor(node->right) > 0) {
            node->right = rightRotate(node->right);
            return leftRotate(node);
        }

        return node;
    }

    // Largest key strictly less than `key`, or -INF if there is none. The explicit
    // best-candidate tracking keeps a stored key equal to -INF from reading as "not found".
    T predecessor(Node* node, T key) const {
        T best = -INF;
        while (node) {
            if (node->key < key) {
                best = node->key;   // candidate; any larger one lies to the right
                node = node->right;
            } else {
                node = node->left;
            }
        }
        return best;
    }

    // Smallest key strictly greater than `key`, or INF if there is none.
    T successor(Node* node, T key) const {
        T best = INF;
        while (node) {
            if (key < node->key) {
                best = node->key;   // candidate; any smaller one lies to the left
                node = node->left;
            } else {
                node = node->right;
            }
        }
        return best;
    }

    // Number of stored keys strictly less than `key` (keys are unique in this tree).
    int rank(Node* node, T key) const {
        if (!node) return 0;
        if (key < node->key) {
            return rank(node->left, key);
        } else if (key > node->key) {
            return 1 + size(node->left) + rank(node->right, key);
        } else {
            return size(node->left);
        }
    }

    // Key at 0-based in-order position `rank`, or -1 if `rank` is out of range.
    T value(Node* node, int rank) const {
        if (!node) return -1;
        int leftSize = size(node->left);
        if (rank < leftSize) {
            return value(node->left, rank);
        } else if (rank > leftSize) {
            return value(node->right, rank - leftSize - 1);
        } else {
            return node->key;
        }
    }

    // Node count of the subtree. NOTE: recomputed by a full traversal on every call,
    // so rank()/value() cost O(n) rather than O(log n).
    int size(Node* node) const {
        if (!node) return 0;
        return 1 + size(node->left) + size(node->right);
    }

    // Frees every node of the subtree rooted at `node`.
    static void destroy(Node* node) {
        if (!node) return;
        destroy(node->left);
        destroy(node->right);
        delete node;
    }

    // Returns a freshly allocated deep copy (keys and heights) of the subtree.
    static Node* clone(const Node* node) {
        if (!node) return nullptr;
        Node* copy = new Node(node->key);
        copy->height = node->height;
        copy->left = clone(node->left);
        copy->right = clone(node->right);
        return copy;
    }

public:
    AVL() : root(nullptr) {}

    // Deep copy: the new tree owns its own nodes.
    AVL(const AVL& other) : root(clone(other.root)) {}

    // Takes over other's nodes; other is left empty.
    AVL(AVL&& other) noexcept : root(other.root) {
        other.root = nullptr;
    }

    // Copy-and-swap: covers copy and move assignment, and is self-assignment safe.
    AVL& operator=(AVL other) {
        std::swap(root, other.root);
        return *this;
    }

    // Releases all nodes (the original class leaked them).
    ~AVL() {
        destroy(root);
    }

    void insert(T key) {
        root = insert(root, key);
    }

    void deleteNode(T key) {
        root = deleteNode(root, key);
    }

    // Largest stored key < key, or -INF.
    T predecessor(T key) const {
        return predecessor(root, key);
    }

    // Smallest stored key > key, or INF.
    T successor(T key) const {
        return successor(root, key);
    }

    // Count of stored keys < key.
    int rank(T key) const {
        return rank(root, key);
    }

    // Key at 0-based sorted position `rank`, or -1 when out of range.
    T value(int rank) const {
        return value(root, rank);
    }
};

// ===================== Red-Black Tree =====================
template <typename T>
class RBTree {
private:
    enum Color { RED, BLACK };

    struct Node {
        T key;
        Node* left;
        Node* right;
        Node* parent;
        Color color;
        // Initializers are listed in declaration order (the previous order triggered -Wreorder).
        Node(T k, Color c = RED, Node* p = nullptr, Node* l = nullptr, Node* r = nullptr)
            : key(k), left(l), right(r), parent(p), color(c) {}
    };

    Node* root;
    Node* nil; // Sentinel node representing every empty child and the root's parent; created BLACK

    // Standard left rotation around x (x's right child y takes x's place).
    void leftRotate(Node* x) {
        Node* y = x->right;
        x->right = y->left;
        if (y->left != nil) {
            y->left->parent = x;
        }
        y->parent = x->parent;
        if (x->parent == nil) {
            root = y;
        } else if (x == x->parent->left) {
            x->parent->left = y;
        } else {
            x->parent->right = y;
        }
        y->left = x;
        x->parent = y;
    }

    // Standard right rotation around y (y's left child x takes y's place).
    void rightRotate(Node* y) {
        Node* x = y->left;
        y->left = x->right;
        if (x->right != nil) {
            x->right->parent = y;
        }
        x->parent = y->parent;
        if (y->parent == nil) {
            root = x;
        } else if (y == y->parent->right) {
            y->parent->right = x;
        } else {
            y->parent->left = x;
        }
        x->right = y;
        y->parent = x;
    }

    // Restores the red-black properties after inserting red node z.
    void insertFixup(Node* z) {
        while (z->parent->color == RED) {
            if (z->parent == z->parent->parent->left) {
                Node* y = z->parent->parent->right; // uncle
                if (y->color == RED) {
                    z->parent->color = BLACK;
                    y->color = BLACK;
                    z->parent->parent->color = RED;
                    z = z->parent->parent;
                } else {
                    if (z == z->parent->right) {
                        z = z->parent;
                        leftRotate(z);
                    }
                    z->parent->color = BLACK;
                    z->parent->parent->color = RED;
                    rightRotate(z->parent->parent);
                }
            } else {
                Node* y = z->parent->parent->left; // uncle
                if (y->color == RED) {
                    z->parent->color = BLACK;
                    y->color = BLACK;
                    z->parent->parent->color = RED;
                    z = z->parent->parent;
                } else {
                    if (z == z->parent->left) {
                        z = z->parent;
                        rightRotate(z);
                    }
                    z->parent->color = BLACK;
                    z->parent->parent->color = RED;
                    leftRotate(z->parent->parent);
                }
            }
        }
        root->color = BLACK;
    }

    // Replaces the subtree rooted at u with the subtree rooted at v (v may be nil).
    void transplant(Node* u, Node* v) {
        if (u->parent == nil) {
            root = v;
        } else if (u == u->parent->left) {
            u->parent->left = v;
        } else {
            u->parent->right = v;
        }
        v->parent = u->parent; // may write nil->parent; deleteFixup relies on this
    }

    // Leftmost node of a non-empty subtree.
    Node* minimum(Node* node) {
        while (node->left != nil) {
            node = node->left;
        }
        return node;
    }

    // Restores the red-black properties after removing a black node; x carries the extra black.
    void deleteFixup(Node* x) {
        while (x != root && x->color == BLACK) {
            if (x == x->parent->left) {
                Node* w = x->parent->right; // sibling
                if (w->color == RED) {
                    w->color = BLACK;
                    x->parent->color = RED;
                    leftRotate(x->parent);
                    w = x->parent->right;
                }
                if (w->left->color == BLACK && w->right->color == BLACK) {
                    w->color = RED;
                    x = x->parent;
                } else {
                    if (w->right->color == BLACK) {
                        w->left->color = BLACK;
                        w->color = RED;
                        rightRotate(w);
                        w = x->parent->right;
                    }
                    w->color = x->parent->color;
                    x->parent->color = BLACK;
                    w->right->color = BLACK;
                    leftRotate(x->parent);
                    x = root;
                }
            } else {
                Node* w = x->parent->left; // sibling
                if (w->color == RED) {
                    w->color = BLACK;
                    x->parent->color = RED;
                    rightRotate(x->parent);
                    w = x->parent->left;
                }
                if (w->right->color == BLACK && w->left->color == BLACK) {
                    w->color = RED;
                    x = x->parent;
                } else {
                    if (w->left->color == BLACK) {
                        w->right->color = BLACK;
                        w->color = RED;
                        leftRotate(w);
                        w = x->parent->left;
                    }
                    w->color = x->parent->color;
                    x->parent->color = BLACK;
                    w->left->color = BLACK;
                    rightRotate(x->parent);
                    x = root;
                }
            }
        }
        x->color = BLACK;
    }

    // Unlinks and frees node z (which must be in the tree), then rebalances.
    void deleteNode(Node* z) {
        Node* y = z;
        Node* x;
        Color yOriginalColor = y->color;
        if (z->left == nil) {
            x = z->right;
            transplant(z, z->right);
        } else if (z->right == nil) {
            x = z->left;
            transplant(z, z->left);
        } else {
            y = minimum(z->right);
            yOriginalColor = y->color;
            x = y->right;
            if (y->parent == z) {
                x->parent = y;
            } else {
                transplant(y, y->right);
                y->right = z->right;
                y->right->parent = y;
            }
            transplant(z, y);
            y->left = z->left;
            y->left->parent = y;
            y->color = z->color;
        }
        if (yOriginalColor == BLACK) {
            deleteFixup(x);
        }
        delete z;
    }

    // First node found with the given key, or nil.
    Node* search(Node* node, T key) const {
        if (node == nil || key == node->key) {
            return node;
        }
        if (key < node->key) {
            return search(node->left, key);
        } else {
            return search(node->right, key);
        }
    }

    // Largest key strictly less than `key`, or -INF if there is none. Works whether or
    // not `key` itself is stored (the old version returned -INF for any absent key).
    T predecessor(Node* node, T key) const {
        T best = -INF;
        while (node != nil) {
            if (node->key < key) {
                best = node->key;   // candidate; any larger one lies to the right
                node = node->right;
            } else {
                node = node->left;
            }
        }
        return best;
    }

    // Smallest key strictly greater than `key`, or INF if there is none.
    T successor(Node* node, T key) const {
        T best = INF;
        while (node != nil) {
            if (key < node->key) {
                best = node->key;   // candidate; any smaller one lies to the left
                node = node->left;
            } else {
                node = node->right;
            }
        }
        return best;
    }

    // Number of keys ordered before `key` along its search path.
    // NOTE(review): insert() stores duplicates (equal keys go right), and rotations can
    // move an equal key into a left subtree, so with duplicates this may also count keys
    // equal to `key`.
    int rank(Node* node, T key) const {
        if (node == nil) return 0;
        if (key < node->key) {
            return rank(node->left, key);
        } else if (key > node->key) {
            return 1 + size(node->left) + rank(node->right, key);
        } else {
            return size(node->left);
        }
    }

    // Key at 0-based in-order position `rank`, or -1 if `rank` is out of range.
    T value(Node* node, int rank) const {
        if (node == nil) return -1;
        int leftSize = size(node->left);
        if (rank < leftSize) {
            return value(node->left, rank);
        } else if (rank > leftSize) {
            return value(node->right, rank - leftSize - 1);
        } else {
            return node->key;
        }
    }

    // Node count of the subtree (full traversal on every call).
    int size(Node* node) const {
        if (node == nil) return 0;
        return 1 + size(node->left) + size(node->right);
    }

    // Frees every real node of the subtree rooted at `node` (never the sentinel).
    void destroy(Node* node) {
        if (node == nil) return;
        destroy(node->left);
        destroy(node->right);
        delete node;
    }

    // Deep-copies the subtree `src` (whose empty links point at `srcNil`) into nodes that
    // use this tree's sentinel; `parent` becomes the copy's parent.
    Node* clone(const Node* src, const Node* srcNil, Node* parent) {
        if (src == srcNil) return nil;
        Node* copy = new Node(src->key, src->color, parent, nil, nil);
        copy->left = clone(src->left, srcNil, copy);
        copy->right = clone(src->right, srcNil, copy);
        return copy;
    }

public:
    RBTree() {
        nil = new Node(T(), BLACK);
        root = nil;
    }

    // Deep copy with its own sentinel.
    RBTree(const RBTree& other) : RBTree() {
        root = clone(other.root, other.nil, nil);
    }

    // Takes over other's nodes; other is left as a valid empty tree.
    RBTree(RBTree&& other) : RBTree() {
        std::swap(root, other.root);
        std::swap(nil, other.nil);
    }

    // Copy-and-swap: covers copy and move assignment, and is self-assignment safe.
    RBTree& operator=(RBTree other) {
        std::swap(root, other.root);
        std::swap(nil, other.nil);
        return *this;
    }

    // Releases all nodes and the sentinel (the original class leaked them).
    ~RBTree() {
        destroy(root);
        delete nil;
    }

    // Inserts key; equal keys are stored again (they descend to the right).
    void insert(T key) {
        Node* z = new Node(key, RED, nil, nil, nil);
        Node* y = nil;
        Node* x = root;
        while (x != nil) {
            y = x;
            if (z->key < x->key) {
                x = x->left;
            } else {
                x = x->right;
            }
        }
        z->parent = y;
        if (y == nil) {
            root = z;
        } else if (z->key < y->key) {
            y->left = z;
        } else {
            y->right = z;
        }
        insertFixup(z);
    }

    // Removes one node holding `key`, if any.
    void deleteNode(T key) {
        Node* z = search(root, key);
        if (z != nil) {
            deleteNode(z);
        }
    }

    // Largest stored key < key (key need not be stored), or -INF.
    T predecessor(T key) const {
        return predecessor(root, key);
    }

    // Smallest stored key > key (key need not be stored), or INF.
    T successor(T key) const {
        return successor(root, key);
    }

    int rank(T key) const {
        return rank(root, key);
    }

    // Key at 0-based sorted position `rank`, or -1 when out of range.
    T value(int rank) const {
        return value(root, rank);
    }
};

// ===================== Fenwick Tree (TreeArray) =====================
template <typename T>
class FenwickTree {
private:
    int n;           // number of elements; valid indices are 1..n
    vector<T> tree;  // 1-based Fenwick array (tree[0] is unused)

    // Lowest set bit of x; 0 for x == 0.
    static int lowbit(int x) {
        return x & -x;
    }

public:
    // Creates `size` zero-valued elements at indices 1..size (a negative size acts as 0).
    // `n` is declared before `tree`, so it is initialized first and can size the vector.
    FenwickTree(int size) : n(size > 0 ? size : 0), tree(n + 1, 0) {}

    // Adds `delta` to element `idx` (1-based). Indices outside 1..n are ignored.
    void update(int idx, T delta) {
        if (idx < 1) return;  // idx == 0 would loop forever because lowbit(0) == 0
        for (; idx <= n; idx += lowbit(idx)) {
            tree[idx] += delta;
        }
    }

    // Sum of elements 1..idx. idx <= 0 gives 0; idx > n is clamped to n so the loop
    // never reads past the end of `tree`.
    T query(int idx) const {
        if (idx > n) idx = n;
        T sum = 0;
        for (; idx > 0; idx -= lowbit(idx)) {
            sum += tree[idx];
        }
        return sum;
    }

    // Sum of elements l..r (inclusive); an empty range (l > r) sums to 0.
    T rangeQuery(int l, int r) const {
        if (l > r) return 0;
        return query(r) - query(l - 1);
    }

    // Adds z to every element in [x, y] (1-based, inclusive); parts outside 1..n are ignored.
    // NOTE: this performs one point update per element, i.e. O((y - x + 1) log n).
    void add(int x, int y, T z) {
        const int lo = max(x, 1);
        const int hi = min(y, n);
        for (int i = lo; i <= hi; ++i) {
            update(i, z);
        }
    }

    // Adds y to element x (1-based).
    void add(int x, T y) {
        update(x, y);
    }

    // Sum of elements x..y (1-based, inclusive).
    T getsum(int x, int y) const {
        return rangeQuery(x, y);
    }
};

// ===================== Segment Tree =====================
template <typename T>
class SegmentTree {
private:
    vector<T> tree;  // tree[node]: sum of node's segment, already including lazy[node]
    vector<T> lazy;  // lazy[node]: per-element add not yet applied to node's children
    int n;           // number of elements; valid indices are 0..n-1

    // Children of `node` are 2*node+1 and 2*node+2; fills tree[] for segment [start, end].
    void build(const vector<T>& arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
        } else {
            int mid = (start + end) / 2;
            build(arr, 2 * node + 1, start, mid);
            build(arr, 2 * node + 2, mid + 1, end);
            tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
        }
    }

    // Applies node's pending add to both children and clears it. Must only be called on
    // internal nodes (every caller returns early at leaves).
    void pushDown(int node, int start, int end) {
        if (lazy[node] != 0) {
            int mid = (start + end) / 2;
            tree[2 * node + 1] += lazy[node] * (mid - start + 1);
            lazy[2 * node + 1] += lazy[node];
            tree[2 * node + 2] += lazy[node] * (end - mid);
            lazy[2 * node + 2] += lazy[node];
            lazy[node] = 0;
        }
    }

    // Sum over [l, r] intersected with node's segment [start, end].
    T query(int node, int start, int end, int l, int r) {
        if (r < start || l > end) return 0;
        if (l <= start && end <= r) return tree[node];
        pushDown(node, start, end);
        int mid = (start + end) / 2;
        return query(2 * node + 1, start, mid, l, r) + query(2 * node + 2, mid + 1, end, l, r);
    }

    // Adds val to every element of [l, r] intersected with [start, end].
    void updateRange(int node, int start, int end, int l, int r, T val) {
        if (r < start || l > end) return;
        if (l <= start && end <= r) {
            tree[node] += val * (end - start + 1);
            lazy[node] += val;
            return;
        }
        pushDown(node, start, end);
        int mid = (start + end) / 2;
        updateRange(2 * node + 1, start, mid, l, r, val);
        updateRange(2 * node + 2, mid + 1, end, l, r, val);
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
    }

    // Adds val to element idx, which must lie in [start, end].
    void updatePoint(int node, int start, int end, int idx, T val) {
        if (start == end) {
            tree[node] += val;
            return;
        }
        // Bring the children up to date first: tree[node] is recomputed from them below,
        // and without this any pending lazy add on `node` would be silently dropped.
        pushDown(node, start, end);
        int mid = (start + end) / 2;
        if (idx <= mid) {
            updatePoint(2 * node + 1, start, mid, idx, val);
        } else {
            updatePoint(2 * node + 2, mid + 1, end, idx, val);
        }
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
    }

public:
    // Builds a sum tree over arr's values (0-based indices). Accepts temporaries too.
    SegmentTree(const vector<T>& arr) {
        n = static_cast<int>(arr.size());
        tree.resize(4 * n);
        lazy.resize(4 * n, 0);
        if (n > 0) {
            build(arr, 0, 0, n - 1);  // build on an empty range would index out of bounds
        }
    }

    // Adds z to every element in [x, y] (0-based, inclusive); out-of-range parts are ignored.
    void add(int x, int y, T z) {
        if (n == 0) return;
        updateRange(0, 0, n - 1, x, y, z);
    }

    // Adds y to element x (0-based); an out-of-range x is ignored instead of hitting
    // the first/last element.
    void add(int x, T y) {
        if (x < 0 || x >= n) return;
        updatePoint(0, 0, n - 1, x, y);
    }

    // Sum of elements x..y (0-based, inclusive); 0 for an empty tree.
    T getsum(int x, int y) {
        if (n == 0) return 0;
        return query(0, 0, n - 1, x, y);
    }
};

// ===================== Main Function =====================
int main() {
    // BST Example
    BST<int> b;
    b.insert(10);
    b.insert(20);
    b.insert(30);
    cout << "BST Predecessor of 25: " << b.predecessor(25) << endl;
    cout << "BST Successor of 25: " << b.successor(25) << endl;

    // AVL Example
    AVL<int> a;
    a.insert(10);
    a.insert(20);
    a.insert(30);
    cout << "AVL Predecessor of 25: " << a.predecessor(25) << endl;
    cout << "AVL Successor of 25: " << a.successor(25) << endl;

    // Fenwick Tree Example
    FenwickTree<int> f(10);
    f.add(1, 5, 2);
    cout << "Fenwick Tree Sum (1, 5): " << f.getsum(1, 5) << endl;

    // Segment Tree Example
    vector<int> arr = {1, 3, 5, 7, 9, 11};
    SegmentTree<int> S(arr);
    S.add(1, 3, 2);
    cout << "Segment Tree Sum (1, 3): " << S.getsum(1, 3) << endl;

	// Red-Black Tree Example
    RBTree<int> r;
    r.insert(10);
    r.insert(20);
    r.insert(30);
    cout << "RBTree Predecessor of 25: " << r.predecessor(25) << endl;
    cout << "RBTree Successor of 25: " << r.successor(25) << endl;
    cout << "RBTree Rank of 20: " << r.rank(20) << endl;
    cout << "RBTree Value at rank 1: " << r.value(1) << endl;

    return 0;
}