Treap、基础Splay、Splay的区间求和与区间修改



更新(2018.11.18)

  1. 优化Splay代码写法
  2. 随手去掉部分注释

前言

高考剩下5天,学思的孩子高考加油啦~(至今都把学思的社徽挂在书包上呢 ^_^)
回归正题,平衡树是为了解决BST不平衡而出现的数据结构,有Treap、Splay、SBT、AVL等等,本文写的是其中两种Treap和Splay的冰山一角
Treap: Tree + Heap,翻译为“树堆”,其key满足BST的性质,而通过随机生成weight值来维持Heap的性质,换句话说,就是“玄学”
Splay: 通过不断伸展使节点到根保持平衡,灵活度大,可以完成一些线段树不能完成的操作,如区间删除等等,但是代码较长

  1. 有旋Treap:
    http://www.voidcn.com/article/p-mlxsnjuo-xh.html
  2. Splay:
    https://hrbust-acm-team.gitbooks.io/acm-book/content/data_structure/ds_part5.html
    https://www.cnblogs.com/SYCstudio/p/7674387.html
    https://blog.csdn.net/DERITt/article/details/50485008
    http://dongxicheng.org/structure/splay-tree/
    http://sukixj.com/2017/12/02/splaygg/

平衡树·Treap - HihoCoder 1325

思路
Treap模板题

	#include <cstdio>
	#include <iostream>
	#include <cstring>
	using namespace std;
	const int N = 1e5 + 15;
	const int inf = 0x3f3f3f3f;

	// Deterministic multiplicative congruential generator.
	// State starts at 123456; each call multiplies the state by 482711
	// (in 64-bit arithmetic) and reduces it modulo 2147483647, then stores
	// and returns the reduced value. The fixed seed makes every run
	// produce the same sequence.
	inline int rand(){
	    static int seed = 123456;
	    const long long next = 482711LL * seed % 2147483647;
	    seed = static_cast<int>(next);
	    return seed;
	}

	// Array-backed treap. Index 0 is the null node; real nodes are 1..tot.
	struct Treap{
	    // key    : BST key of each node
	    // weight : random priority of each node, used to keep the tree balanced
	    // sz     : number of nodes in the subtree rooted at each node
	    // pre    : parent index (0 when the node has no parent)
	    // ch     : ch[x][0] is the left child, ch[x][1] the right child
	    // tot    : index of the most recently allocated node
	    // root   : index of the current root
	    int key[N], weight[N], sz[N], pre[N], ch[N][2];
	    int tot, root;

	    // Allocate the next slot, store its index in x and initialise it as a
	    // detached leaf holding pkey with a fresh random weight.
	    void newNode(int& x, int pkey){
	        x = ++tot;
	        key[x]    = pkey;
	        weight[x] = rand();
	        sz[x]     = 1;
	        pre[x]    = 0;
	        ch[x][0]  = 0;
	        ch[x][1]  = 0;
	    }

	    // Recompute the subtree size of x from its two children.
	    void pushUp(int x){
	        sz[x] = 1 + sz[ch[x][0]] + sz[ch[x][1]];
	    }

	    // Lift x above its parent. p == 1 is a right rotation (x is a left
	    // child), p == 0 is a left rotation (x is a right child).
	    void rotate(int x, int p){
	        const int parent = pre[x];
	        const int grand  = pre[parent];
	        const int inner  = ch[x][p];

	        // The subtree between x and its parent changes owner.
	        ch[parent][p ^ 1] = inner;
	        pre[inner] = parent;

	        // x takes the parent's place below the grandparent, or becomes
	        // the root when there is no grandparent.
	        pre[x] = grand;
	        if(grand == 0)  root = x;
	        else            ch[grand][ch[grand][1] == parent] = x;

	        // The old parent becomes a child of x.
	        ch[x][p] = parent;
	        pre[parent] = x;

	        // The old parent now sits below x, so refresh it first.
	        pushUp(parent);
	        pushUp(x);
	    }

	    // Insert pkey into the subtree rooted at x (fa is the parent of x).
	    // Keys that are not smaller than a node go to its right. On the way
	    // back up, a child whose weight exceeds its parent's is rotated
	    // above it, so weights form a max-heap.
	    void insert(int x, int pkey, int fa = 0){
	        if(x == 0){
	            // Reached an empty slot: create the node and hook it under fa.
	            newNode(x, pkey);
	            if(fa != 0){
	                const int side = (pkey < key[fa]) ? 0 : 1;
	                ch[fa][side] = x;
	            }
	            pre[x] = fa;
	            return;
	        }

	        const int dir = (pkey < key[x]) ? 0 : 1;
	        insert(ch[x][dir], pkey, x);
	        pushUp(x);

	        // Re-read the child: the recursive call may have rotated a
	        // different node into this position.
	        const int child = ch[x][dir];
	        if(weight[child] > weight[x]){
	            rotate(child, (key[child] < key[x]) ? 1 : 0);
	        }
	    }

	    // Reset the pool and insert the two sentinel keys inf and -inf.
	    void init(){
	        tot = 0;
	        root = 1;
	        insert(0, inf);
	        insert(root, -inf);
	    }

	    // Raise ans to the largest key <= pkey found in the subtree rooted
	    // at x; ans is left untouched when no such key is larger than it.
	    void query(int x, int pkey, int& ans){
	        while(x != 0){
	            if(key[x] <= pkey){
	                if(ans < key[x])    ans = key[x];
	                x = ch[x][1];
	            }else{
	                x = ch[x][0];
	            }
	        }
	    }
	};

	// Single treap instance at namespace scope, so its large arrays are not
	// placed on the stack.
	Treap tp;

	// Reads test cases until input ends. Each case is a count q followed by
	// q operations:
	//   I k : insert key k
	//   Q k : print the largest stored key that is <= k
	// Parsing stops at the first malformed or truncated record.
	int main(){
	    int q;
	    // Compare against the expected field count: scanf returns 0 (not
	    // EOF) on non-numeric input, which must also end the loop.
	    while(scanf("%d", &q) == 1){
	        tp.init();
	        while(q-- > 0){
	            char op[16];
	            int num;
	            // The field width keeps an over-long token inside the buffer.
	            if(scanf("%15s%d", op, &num) != 2)  return 0;
	            if(op[0] == 'I'){
	                tp.insert(tp.root, num);
	            }else{
	                int ans = -inf;
	                tp.query(tp.root, num, ans);
	                printf("%d\n", ans);
	            }
	        }
	    }
	    return 0;
	}


Order statistic set - SPOJ ORDERSET

思路
Treap模板题



平衡树·Splay - HihoCoder 1329

思路
Splay基础

#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
// Capacity of the node pool. Nodes are never recycled, so this bounds the
// number of distinct keys inserted between two init() calls (sentinels
// included).
const int N = 2e5 + 15;
// Magnitude of the two sentinel keys stored by init(); user keys are
// expected to lie strictly between -inf and inf.
const int inf = 0x3f3f3f3f;

// Splay tree over distinct int keys, stored in fixed arrays.
// Index -1 is the null node.
struct SplayTree{
    // ch[x][0] / ch[x][1] : left / right child, pre[x] : parent, key[x] : key
    int ch[N][2], pre[N], key[N];
    // root : current root (-1 when empty), tot : number of allocated nodes
    int root, tot;

    // Lift x above its parent. p == 1 is a right rotation (x is a left
    // child), p == 0 is a left rotation. Does not update root; splay() does.
    void rotate(int x, int p){
        const int y = pre[x], z = pre[y];
        const int mid = ch[x][p];       // subtree that moves from x to y
        ch[y][p^1] = mid;
        if(~mid)    pre[mid] = y;       // never write through the null index
        pre[x] = z;
        if(~z)      ch[z][ch[z][1] == y] = x;
        ch[x][p] = y;
        pre[y] = x;
    }

    // Rotate x upwards until its parent is goal (goal == -1 makes x the
    // root). x must currently be inside the subtree hanging below goal.
    void splay(int x, int goal){
        while(pre[x] != goal){
            if(pre[pre[x]] == goal)     rotate(x, ch[pre[x]][0] == x);
            else{
                int y = pre[x], z = pre[y];
                int p = (ch[z][0] == y);
                // Same side twice: rotate the parent first (zig-zig).
                if(ch[y][p^1] == x)     rotate(y, p), rotate(x, p);
                // Opposite sides: rotate x twice (zig-zag).
                else                    rotate(x, p^1), rotate(x, p);
            }
        }
        if(goal == -1)   root = x;
    }

    // Allocate a detached node holding val and store its index in o.
    void newNode(int& o, int val){
        o = tot++;
        ch[o][0] = ch[o][1] = pre[o] = -1;
        key[o] = val;
    }

    // Insert val into the subtree referenced by o (fa is the parent of that
    // subtree; callers pass root and -1). The node holding val - new or
    // already present - ends up as the root. Written as a loop so that a
    // degenerate, chain-shaped tree cannot exhaust the call stack.
    void insert(int& o, int val, int fa){
        int* slot = &o;
        while(*slot != -1){
            const int cur = *slot;
            if(key[cur] == val){
                splay(cur, -1);
                return;
            }
            fa = cur;
            slot = &ch[cur][val >= key[cur]];
        }
        newNode(*slot, val);
        pre[*slot] = fa;
        splay(*slot, -1);
    }

    // Empty the tree, then insert the sentinels inf and -inf so that every
    // key strictly between them has both a predecessor and a successor.
    void init(){
        root = -1, tot = 0;
        insert(root, inf, -1);
        insert(root, -inf, -1);
    }

    // Index of the node with the largest key < val in the subtree rooted
    // at o, or -1 if there is none.
    int findPrev(int o, int val){
        int best = -1;
        while(~o){
            if(key[o] < val)    best = o, o = ch[o][1];
            else                o = ch[o][0];
        }
        return best;
    }

    // Index of the node with the smallest key > val in the subtree rooted
    // at o, or -1 if there is none.
    int findSucc(int o, int val){
        int best = -1;
        while(~o){
            if(key[o] > val)    best = o, o = ch[o][0];
            else                o = ch[o][1];
        }
        return best;
    }

    // Largest stored key <= val; the -inf sentinel value when none exists.
    // Compares with <= directly (no val + 1, which overflows at INT_MAX)
    // and splays the hit to the root so repeated queries stay amortised.
    int findNotExceed(int val){
        int best = -1;
        for(int o = root; ~o; ){
            if(key[o] <= val)   best = o, o = ch[o][1];
            else                o = ch[o][0];
        }
        if(best == -1)  return -inf;
        splay(best, -1);
        return key[best];
    }

    // Remove every key in [lval, rval].
    // The predecessor of lval is splayed to the root and the successor of
    // rval to the root's right child; the left subtree of that child then
    // holds exactly the keys in range and is cut off. Does nothing when
    // the range is reversed or when either bracketing node is missing.
    void delInterval(int lval, int rval){
        if(lval > rval)     return;
        const int l = findPrev(root, lval), r = findSucc(root, rval);
        if(l == -1 || r == -1)  return;
        splay(l, -1);
        splay(r, root);
        ch[r][0] = -1;
    }
};

// Single tree instance at namespace scope, so its large arrays are not
// placed on the stack.
SplayTree spt;

// Reads test cases until input ends. Each case is a count q followed by q
// operations:
//   I k   : insert key k
//   Q k   : print the largest stored key that is <= k
//   other : read a and b, delete every key in [a, b]
// Parsing stops at the first malformed or truncated record.
int main(){
    int q;
    // Compare against the expected field count: scanf returns 0 (not EOF)
    // on non-numeric input, which must also end the loop.
    while(scanf("%d", &q) == 1){
        spt.init();
        while(q-- > 0){
            char op[16];
            int l, r;
            // The field width keeps an over-long token inside the buffer.
            if(scanf("%15s", op) != 1)  return 0;
            if(op[0] == 'I'){
                if(scanf("%d", &l) != 1)    return 0;
                spt.insert(spt.root, l, -1);
            }else if(op[0] == 'Q'){
                if(scanf("%d", &l) != 1)    return 0;
                printf("%d\n", spt.findNotExceed(l));
            }else{
                if(scanf("%d%d", &l, &r) != 2)  return 0;
                spt.delInterval(l, r);
            }
        }
    }
    return 0;
}


平衡树·Splay2 - HihoCoder 1333

思路
Splay的区间修改和区间求和

#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
// Capacity of the node pool. Nodes are never recycled, so this bounds the
// number of distinct keys inserted between two init() calls (sentinels
// included).
const int N = 2e5 + 15;
// Magnitude of the two sentinel keys stored by init(); user keys are
// expected to lie strictly between -inf and inf.
const int inf = 0x3f3f3f3f;
typedef long long ll;

// Splay tree keyed by distinct int keys, each carrying a value, with
// range sum, range add and range delete over key intervals.
// Index -1 is the null node.
struct SplayTree{
    // ch[x][0] / ch[x][1] : left / right child, pre[x] : parent
    int ch[N][2], pre[N];
    // key : node key            val : node value
    // sum : sum of val over the node's subtree
    // sz  : number of nodes in the node's subtree
    // lzy : amount already applied to the node itself (val and sum) but
    //       still owed to both of its children
    ll  sum[N], lzy[N], key[N], val[N], sz[N];
    // root : current root (-1 when empty), tot : number of allocated nodes
    int root, tot;

    // Add delta to every value in the subtree of o: o is updated at once,
    // its descendants lazily. A null o is ignored.
    void color(int o, ll delta){
        if(o == -1)     return;
        lzy[o] += delta;
        sum[o] += sz[o] * delta;
        val[o] += delta;
    }

    // Recompute sz and sum of o from its own value and its children.
    // o must not be holding a pending lazy amount.
    void pushUp(int o){
        sz[o] = 1, sum[o] = val[o];
        if(~ch[o][0]){
            sz[o] += sz[ch[o][0]];
            sum[o] += sum[ch[o][0]];
        }
        if(~ch[o][1]){
            sz[o] += sz[ch[o][1]];
            sum[o] += sum[ch[o][1]];
        }
    }

    // Hand the pending lazy amount of o to its children and clear it.
    void pushDown(int o){
        if(lzy[o]){
            if(~ch[o][0])    color(ch[o][0], lzy[o]);
            if(~ch[o][1])    color(ch[o][1], lzy[o]);
            lzy[o] = 0;
        }
    }

    // Lift x above its parent. p == 1 is a right rotation (x is a left
    // child), p == 0 is a left rotation. Pending amounts of both nodes are
    // pushed down first; their aggregates are rebuilt afterwards, lower
    // node first. Does not update root; splay() does.
    void rotate(int x, int p){
        const int y = pre[x], z = pre[y];
        pushDown(y), pushDown(x);
        const int mid = ch[x][p];       // subtree that moves from x to y
        ch[y][p^1] = mid;
        if(~mid)    pre[mid] = y;       // never write through the null index
        pre[x] = z;
        if(~z)      ch[z][ch[z][1] == y] = x;
        ch[x][p] = y;
        pre[y] = x;
        pushUp(y), pushUp(x);
    }

    // Rotate x upwards until its parent is goal (goal == -1 makes x the
    // root). x must currently be inside the subtree hanging below goal.
    void splay(int x, int goal){
        while(pre[x] != goal){
            if(pre[pre[x]] == goal)     rotate(x, ch[pre[x]][0] == x);
            else{
                int y = pre[x], z = pre[y];
                int p = (ch[z][0] == y);
                // Same side twice: rotate the parent first (zig-zig).
                if(ch[y][p^1] == x)     rotate(y, p), rotate(x, p);
                // Opposite sides: rotate x twice (zig-zag).
                else                    rotate(x, p^1), rotate(x, p);
            }
        }

        if(goal == -1)       root = x;
        else                 pushUp(goal);
    }

    // Allocate a detached node (pkey, pval) and store its index in o.
    void newNode(int& o, int pkey, ll pval){
        o = tot++;
        ch[o][0] = ch[o][1] = pre[o] = -1;
        lzy[o] = 0;
        key[o] = pkey;
        sum[o] = val[o] = pval;
        sz[o] = 1;
    }

    // Insert (pkey, pval) into the subtree referenced by o (fa is the
    // parent of that subtree; callers pass root and -1). An existing pkey
    // keeps its value. The node holding pkey ends up as the root.
    // Every node on the way down is pushed first so that a new node does
    // not inherit amounts that were added before it existed. Written as a
    // loop so that a chain-shaped tree cannot exhaust the call stack.
    void insert(int& o, int pkey, ll pval, int fa){
        int* slot = &o;
        while(*slot != -1){
            const int cur = *slot;
            pushDown(cur);
            if(key[cur] == pkey){
                splay(cur, -1);
                return;
            }
            fa = cur;
            slot = &ch[cur][pkey >= key[cur]];
        }
        newNode(*slot, pkey, pval);
        pre[*slot] = fa;
        splay(*slot, -1);
    }

    // Empty the tree, then insert the sentinels inf and -inf (value 0) so
    // that every key strictly between them has both a predecessor and a
    // successor.
    void init(){
        root = -1, tot = 0;
        insert(root, inf, 0, -1);
        insert(root, -inf, 0, -1);
    }

    // Index of the node with the largest key < pkey in the subtree rooted
    // at o, or -1 if there is none.
    int findPrev(int o, int pkey){
        int best = -1;
        while(~o){
            if(key[o] < pkey)   best = o, o = ch[o][1];
            else                o = ch[o][0];
        }
        return best;
    }

    // Index of the node with the smallest key > pkey in the subtree rooted
    // at o, or -1 if there is none.
    int findSucc(int o, int pkey){
        int best = -1;
        while(~o){
            if(key[o] > pkey)   best = o, o = ch[o][0];
            else                o = ch[o][1];
        }
        return best;
    }

    // Bracket the key range [lkey, rkey]: splay the predecessor of lkey to
    // the root and the successor of rkey to the root's right child, so
    // that ch[ch[root][1]][0] holds exactly the keys in range (or is -1
    // when the range is empty).
    // Returns false, leaving the tree untouched, when the range is reversed
    // or a bracketing node is missing; true otherwise.
    bool splayLR(int lkey, int rkey){
        if(lkey > rkey)     return false;
        const int l = findPrev(root, lkey), r = findSucc(root, rkey);
        if(l == -1 || r == -1)  return false;
        splay(l, -1);
        splay(r, root);
        // Nothing may remain pending above the range subtree.
        pushDown(l);
        pushDown(r);
        return true;
    }

    // Sum of the values of all keys in [lkey, rkey]; 0 for an empty or
    // invalid range.
    ll query(int lkey, int rkey){
        if(!splayLR(lkey, rkey))    return 0;
        const int o = ch[ch[root][1]][0];
        return ~o ? sum[o] : 0;
    }

    // Add pval to the value of every key in [lkey, rkey].
    void changeInterval(int lkey, int rkey, ll pval){
        if(!splayLR(lkey, rkey))    return;
        const int r = ch[root][1];
        color(ch[r][0], pval);
        // Keep the aggregates of the two bracketing nodes in step.
        pushUp(r);
        pushUp(root);
    }

    // Remove every key in [lkey, rkey] by cutting off the range subtree.
    void delInterval(int lkey, int rkey){
        if(!splayLR(lkey, rkey))    return;
        const int r = ch[root][1];
        ch[r][0] = -1;
        pushUp(r);
        pushUp(root);
    }
};

// Single tree instance at namespace scope, so its large arrays are not
// placed on the stack.
SplayTree spt;

// Reads test cases until input ends. Each case is a count q followed by q
// operations:
//   I id val : insert key id with value val
//   Q a b    : print the sum of the values of all keys in [a, b]
//   M a b d  : add d to the value of every key in [a, b]
//   other    : read a and b, delete every key in [a, b]
// Parsing stops at the first malformed or truncated record.
int main(){
    int q;
    // Compare against the expected field count: scanf returns 0 (not EOF)
    // on non-numeric input, which must also end the loop.
    while(scanf("%d", &q) == 1){
        spt.init();
        while(q-- > 0){
            char op[16];
            int l, r, id;
            ll  val;
            // The field width keeps an over-long token inside the buffer.
            if(scanf("%15s", op) != 1)  return 0;
            if(op[0] == 'I'){
                if(scanf("%d%lld", &id, &val) != 2)     return 0;
                spt.insert(spt.root, id, val, -1);
            }else if(op[0] == 'Q'){
                if(scanf("%d%d", &l, &r) != 2)          return 0;
                printf("%lld\n", spt.query(l, r));
            }else if(op[0] == 'M'){
                if(scanf("%d%d%lld", &l, &r, &val) != 3)    return 0;
                spt.changeInterval(l, r, val);
            }else{
                if(scanf("%d%d", &l, &r) != 2)          return 0;
                spt.delInterval(l, r);
            }
        }
    }
    return 0;
}