「Bzoj 3224」普通平衡树 (非旋Treap)

BZOJ 3224
平衡树模板题,注意相同元素的处理。

2017-8-22:
非旋转Treap:

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 100000 + 5;
struct Treap *null, *pit, *root;//自建的null节点,内存池指针 
struct Treap {
    int key, val, s;//键值(满足排序二叉树性质),附加值(满足堆性质),以本节点为根的子树大小
    Treap *lc, *rc;//Treap 左右结点 
    Treap() {};
    void init(int key) {this->key = key, val = rand(), lc = rc = null, s = 1;}//初始化 
    void maintain() {s = lc->s + rc->s + 1;}//更新,维护以本节点为根的子树大小
}pool[MAXN];//内存池 
Treap* newNode(int key) {
    pit->init(key);
    return pit++;
}
Treap* merge(Treap *a, Treap *b) {//合并两棵Treap,所有key(a)<key(b)才能合并
    if (a == null) return b;
    if (b == null) return a;
    if (a->val < b->val) {
        a->rc = merge(a->rc, b), a->maintain();
        return a;
    }
    if (a->val >= b->val) {
        b->lc = merge(a, b->lc), b->maintain();
        return b;
    }
}
void split_v(Treap *o, int v, Treap *&x, Treap *&y) {//用权值分开一棵Treap,分开的第一棵根为x,第二棵根为y
    if (o == null) x = y = null; else {
        if (o->key <= v) {
            x = o, split_v(o->rc, v, o->rc, y);
        } else y = o, split_v(o->lc, v, x, o->lc);
        o->maintain();
    }
}
void split_k(Treap *o, int k, Treap *&x, Treap *&y) {//按前k个分配分开一棵Treap,分开的第一棵根为x,第二棵根为y
    if (o == null) x = y = null; else {
        if (k <= o->lc->s) {
            y = o, split_k(o->lc, k, x, o->lc);
        } else x = o, split_k(o->rc, k - o->lc->s - 1, o->rc, y);
        o->maintain();
    }
}
void insert(int v) {//插入一个数 
    Treap *a, *b;
    split_v(root, v, a, b);
    root = merge(merge(a, newNode(v)), b);
}
void del(int v) {//删除一个数 
    Treap *a, *b;
    split_v(root, v, a, b);
    Treap *c, *d;
    split_v(a, v - 1, c, d);
    d = merge(d->lc, d->rc);
    a = merge(c, d), root = merge(a, b);
}
int kth(int k) {//查询排名为k的数
    Treap *a, *b;
    split_k(root, k - 1, a, b);
    Treap *c, *d;
    split_k(b, 1, c, d);
    int ret = c->key;
    b = merge(c, d), root = merge(a, b);
    return ret;
}
int rk(int x) {//查询x的排名
    Treap *a, *b;
    split_v(root, x - 1, a, b);
    int ret = a->s + 1;
    root = merge(a, b);
    return ret;
}
int pre(int x) {//求x的前驱 
    Treap *a, *b;
    split_v(root, x - 1, a, b);
    Treap *c, *d;
    split_k(a, a->s - 1, c, d);
    int ret = d->key;
    a = merge(c, d), root = merge(a, b); 
    return ret;
}
int succ(int x) {//求x的后继 
    Treap *a, *b;
    split_v(root, x, a, b);
    Treap *c, *d;
    split_k(b, 1, c, d);
    int ret = c->key;
    b = merge(c, d), root = merge(a, b); 
    return ret;
}
void initTreap() {
     srand(19260827);//置随机数种子 
     pit = pool;//指针指向内存池 
     null = newNode(0), null->s = 0;//初始化自建的null节点
     root = null;//初始化树根 
}
void clean() {}
void solve() {
    clean();
    initTreap();
    int Q;
    scanf("%d", &Q);
    while (Q--) {
        int opt, x;
        scanf("%d%d", &opt, &x);
        switch (opt) {
            case 1: insert(x); break;
            case 2: del(x); break;
            case 3: printf("%d\n", rk(x)); break;
            case 4: printf("%d\n", kth(x)); break;
            case 5: printf("%d\n", pre(x)); break;
            case 6: printf("%d\n", succ(x)); break;
        }
    }
}
int main() {
    solve();
    return 0;
}

Splay :

#include<cstdio> 
#include<cstring>
#include<algorithm>
#include<iostream>
#include<queue>
#include<vector>
#include<utility>
#include<cmath>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
#define fir first
#define sec second
#define mp make_pair
using namespace std;

namespace flyinthesky {

    const int MAXN = 200005, INF = 2000000005;

    int ch[MAXN][2], fa[MAXN], val[MAXN], cnt[MAXN], siz[MAXN], ncnt, rt, n;
    // splay 数组,父亲节点,节点值,节点值个数,节点子树大小,节点总数,根 

    bool rel(int x) {return ch[fa[x]][1] == x;} // x 点在父亲点的位置 (x 点是不是父亲点的右孩子) 
    void pushup(int x) {siz[x] = siz[ch[x][0]] + siz[ch[x][1]] + cnt[x];}
    void rotate(int x) { // 旋转操作 
        int y = fa[x], z = fa[y], k = rel(x), w = ch[x][k ^ 1];
        ch[z][rel(y)] = x, fa[x] = z; // 将爷爷节点与自己相连 
        ch[y][k] = w, fa[w] = y; // 将自己的孩子给父节点
        ch[x][k ^ 1] = y, fa[y] = x; // 将父节点接到自己孩子
        pushup(y), pushup(x); // 破坏了结构记得 pushup 
    }
    void splay(int x, int gl = 0) { // splay 操作 -> 将 x 节点转到 gl 节点孩子 ,并且可以用作更新信息 
        while (fa[x] != gl) { // 一直旋转到目标节点的儿子
            int y = fa[x], z = fa[y];
            if (z != gl) { // 爷爷节点不是目标节点,否则直接旋转 
                if (rel(x) == rel(y)) rotate(y); else rotate(x); // 自身、父亲、爷爷三点一线,则先旋转父亲节点 
            }
            rotate(x); // 旋转自身 
        }
        if (!gl) rt = x; // 如果旋转到根则更新 rt 
    }
    void insert(int x) { // 插入一个值为 x 的节点并转到根 
        int cur = rt, p = 0; // cur 当前节点, p 为 cur 父亲 
        while (cur && val[cur] != x) p = cur, cur = ch[cur][x > val[cur]]; // 找插入位置 
        if (cur) cnt[cur]++; else { // 找到插入位置,如果已经有数则 cnt 加一即可,否则新建节点 
            cur = ++ncnt;
            ch[p][x > val[p]] = cur;
            ch[cur][0] = ch[cur][1] = 0;
            fa[cur] = p; val[cur] = x;
            cnt[cur] = siz[cur] = 1; 
        }
        splay(cur); // splay 到根,以免加入节点后拉出一条链 + 更新信息 
    }
    void find(int x) { // 找一个值为 x 的节点转到根,不存在则返回当前数的前驱或后继 
        int cur = rt;
        while (ch[cur][x > val[cur]] && x != val[cur]) cur = ch[cur][x > val[cur]];
        splay(cur);  
    }
    int kth(int k) { // 找值第 k 大数 
        int cur = rt;
        while (1) {
            if (k <= siz[ch[cur][0]]) cur = ch[cur][0]; 
            else if (k > siz[ch[cur][0]] + cnt[cur]) k -= siz[ch[cur][0]] + cnt[cur], cur = ch[cur][1];
            else return cur;
        }
    }
    int rnk(int x) { // 值 x 的排名 
        find(x); // splay 值为 x 的节点到根后左子树的大小即为值 x 排名 
        return siz[ch[rt][0]]; 
    }
    int pre(int x) { // 值前驱 
        find(x); // 先将值为 x 的节点 splay 到根 
        if (val[rt] < x) return rt; // 找不到这样的节点则直接返回根 
        int cur = ch[rt][0]; // 找根左子树最右边的值即为前驱 
        while (ch[cur][1]) cur = ch[cur][1];
        return cur;
    }
    int succ(int x) { // 值后继 
        find(x); // 先将值为 x 的节点 splay 到根  
        if (val[rt] > x) return rt; // 找不到这样的节点则直接返回根 
        int cur = ch[rt][1]; // 找根右子树最左边的值即为前驱 
        while (ch[cur][0]) cur = ch[cur][0];
        return cur;
    }
    void remove(int x) { // 删除一个值为 x 的节点 
        int last = pre(x), next = succ(x);
        splay(last), splay(next, last); // 前驱转到根,后继转到根的右孩子 
        int del = ch[next][0];
        if (cnt[del] > 1) --cnt[del], splay(del); else ch[next][0] = 0, pushup(next), pushup(last);
        // 记得信息的更新上传 
    }

    void clean() {
        ms(fa, 0), ms(cnt, 0), ms(val, 0), ms(siz, 0), rt = ncnt = 0;
    }
    int solve() {
        clean();
        scanf("%d", &n);
        insert(INF), insert(-INF);  // 插入最大最小值 
        while (n--) {
            int opt, x;
            scanf("%d%d", &opt, &x);
            switch (opt) {
                case 1: insert(x); break;
                case 2: remove(x); break;
                case 3: printf("%d\n", rnk(x)); break;
                case 4: printf("%d\n", val[kth(x + 1)]); break;
                case 5: printf("%d\n", val[pre(x)]); break;
                case 6: printf("%d\n", val[succ(x)]); break;
            }
        }
        return 0; 
    }
}
int main() {
    flyinthesky::solve();
    return 0;
}

旋转Treap:

#include<cstdio>    
#include<algorithm>    
#include<cstring>     
#include<vector>   
#define ms(i,j) memset(i,j, sizeof i);    
using namespace std;  
struct node
{
    node *ch[2];//左右孩子 
    int v, r;//值,优先级 
    int s;//附加值:以当前节点为根的结点数量
    int w;//附加值:和当前节点相同值的结点数 
    void mt()
    {
        s = w;
        if (ch[0]!=NULL) s += ch[0]->s;
        if (ch[1]!=NULL) s += ch[1]->s;
    } 
};
int ans;
struct treap
{
    node *root;
    void rotate(int d, node *&o)//d=0则左旋 d=1则右旋 
    {
        node *k = o->ch[d^1]; o->ch[d^1] = k->ch[d]; 
        k->ch[d] = o; o->mt(); k->mt(); o = k;
    }
    void insert(int x, node *&o)//插入一个数 
    {
        if (o==NULL) 
        {
            o = new node(); o->v = x; o->r = rand(); o->s = o->w = 1; o->ch[0] = o->ch[1] = NULL; //初值 
        } else     if (o->v==x) o->w++; //有相同直接w++
        else
        {
            int d = (x < o->v ? 0 : 1);
            insert(x, o->ch[d]);
            if (o->ch[d]->r > o->r) rotate(d^1, o);
        }
        o->mt();
    }
    void del(int x, node *&o)//删除一个数 
    {
        int d = (x < o->v ? 0 : 1);
        if (o->v==x)//找到 
        {
            if (o->w>1) {o->w--; o->s--;} else//不止一个数就直接w--,s-- 
            if (o->ch[0]==NULL) o = o->ch[1];
            else if (o->ch[1]==NULL) o = o->ch[0];
            else {
                int d2 = (o->ch[0]->r > o->ch[1]->r ? 1 : 0);
                rotate(d2,o); del(x, o->ch[d2]);
            }
        } else
        {
            del(x, o->ch[d]);
        }
        if (o!=NULL) o->mt();
    }
    int rank(int x, node *o)//求x的排名 
    {
        int tmp;
        if (o->ch[0]==NULL) tmp = 0;
        else  tmp=o->ch[0]->s;//求s 

        if (o->v==x) return tmp+1;//找到了 
        else if (o->v >x) return rank(x,o->ch[0]);
        else return tmp+o->w+rank(x,o->ch[1]);
    }
    int kth(int k, node *o)//求第k小 
    {
        if (o==NULL||o->s<k||k<=0) return -1;//不符合要求 
        int tmp;
        if (o->ch[0]==NULL) tmp = 0;
        else tmp = o->ch[0]->s;//求s

        if (k<=tmp) return kth(k,o->ch[0]);
        else if (k > tmp+o->w) return kth(k - tmp - o->w,o->ch[1]);
        else return o->v;//找到了 
    }
    void pred(int x, node *o)//求前驱 
    {
        if(o==NULL)return;
        if(o->v<x) {
            ans = o->v;
            pred(x, o->ch[1]);
        } else pred(x, o->ch[0]);
    }
    void succ(int x, node *o)//求后继 
    {
        if(o==NULL)return;
        if(o->v>x) {
            ans = o->v;
            succ(x, o->ch[0]);
        } else succ(x, o->ch[1]);
    }
};
treap tree;
int n;
int main()    
{     
    scanf("%d", &n);
    int opt,x;
    for (int i=1;i<=n;i++) 
    {
        scanf("%d%d", &opt, &x);
        switch(opt)
        {
            case 1: tree.insert(x, tree.root); break;
            case 2: tree.del(x, tree.root); break;
            case 3: printf("%d\n", tree.rank(x, tree.root)); break;
            case 4: printf("%d\n", tree.kth(x, tree.root)); break;
            case 5: tree.pred(x,tree.root); printf("%d\n", ans); break;
            case 6: tree.succ(x,tree.root); printf("%d\n", ans); break;
        }
    }
    return 0;    
}
------ 本文结束 ------