splay学习笔记2 双旋splay+二进制优化【splay】【平衡树】

作者: wjyyy 分类: splay,学习笔记 发布时间: 2018-06-14 09:55

点击量:155

在网上找了一下,发现很多人都没有存父节点指针,今天实现了一下,感觉确实简洁许多。

 

   不存父指针,有一个关键操作就是赋值,这一操作可以用=号来实现,也可以在函数调用时取址,直接更改,这里我使用的是=号实现。我们的splay是从上往下找,也就是抽象为让父亲把要找的儿子/孙子强行拉到根节点处。这里的实现方法就是把某个儿子或某个孙子拉到自己这里来。

 

   同时,我用了一个*son[2]来存孩子,这样根据位运算可以很容易找到旋转的方式,可以将六种旋转方式的核心压到同一个函数里。也就是二进制优化。

 

   需要注意的地方:如果删除了一个叶子节点,它是无法旋转到根的,这里要记得特判,不然会越界。

 

   双旋splay,需要判断孙子的方向,首先儿子的方向是确定的,如果儿子是我们要找的节点,那么就做单旋;如果儿子还不是,我们就进入双旋的判断。如果是zig-zig或zag-zag,那么让当前节点做两次zig或zag,如果是zig-zag或zag-zig,要先对子节点做一遍zig或zag,再自己做一遍zag或zig。

 

    注:zig是顺时针,zag是逆时针。

 其他和单旋splay思想是一样的。

 

Code:

#include<cstdio>
#include<cstring>
int max(int x,int y)
{
    return x>y?x:y;
}
int min(int x,int y)
{
    return x<y?x:y;
}
struct node
{
    int key,num;
    int weight;
    node *son[2];

    node(int key)
    {
        this->key=key;
        num=1;
        weight=1;
        son[0]=NULL;
        son[1]=NULL;
    }
    node()
    {
        son[0]=NULL;
        son[1]=NULL;
    }
};
int getw(node *rt)//访问树的重量
{
    if(!rt)
        return 0;
    return rt->weight;
}
void maintainw(node* &rt)//维护节点重量
{
    rt->weight=rt->num+getw(rt->son[0])+getw(rt->son[1]);
}
int getdir(node *rt,int x)//如果相等返回-1,如果小了返回0,大了返回1,与旋转方向相照应
{
    if(x==rt->key)
        return -1;
    if(x>rt->key)
        return 1;
    return 0;
}
node *root=NULL;
//dir为0表示zig,dir为1表示zag
node *Rotate(node *rt,int dir)
{
    node *tmp=rt->son[dir];
    rt->son[dir]=tmp->son[dir^1];//交换儿子
    tmp->son[dir^1]=rt;//交换身份
    maintainw(rt);
    maintainw(tmp);
    return tmp;
}
node *splay(node *rt,int x)
{
    int dir=getdir(rt,x);

    if(dir==-1)
        return rt;

    if(rt->son[dir]->key==x)//如果该子树的根就是要访问的点
    {
        rt=Rotate(rt,dir);//只做单旋
        return rt;
    }
    int dir2=getdir(rt->son[dir],x);
    rt->son[dir]->son[dir2]=splay(rt->son[dir]->son[dir2],x);//先把x拉到孙子处
    if(dir!=dir2)//旋两次不同的
    {
        rt->son[dir]=Rotate(rt->son[dir],dir2);
        rt=Rotate(rt,dir);
    }
    else
    {
        rt=Rotate(rt,dir);
        rt=Rotate(rt,dir);
    }
    return rt;
}
node *Insert(node *rt,int x)
{
    if(!rt)//找到插入点
    {
        rt=new node(x);
        return rt;
    }
    int dir=getdir(rt,x);
    if(dir==-1)//如果这个数已经被访问过
    {
        rt->num++;
        maintainw(rt);
        return rt;
    }
    rt->son[dir]=Insert(rt->son[dir],x);//
    maintainw(rt->son[dir]);
    maintainw(rt);
    return rt;
}
int getmx(node *rt)//delete时找的前驱
{
    if(!rt->son[1])
        return rt->key;
    return getmx(rt->son[1]);
}
int getmn(node *rt)
{
    if(!rt->son[0])
        return rt->key;
    return getmn(rt->son[0]);
}
int delnum;//用于存储用哪个点代替,用于下一步的splay
node *Delete(node *rt,int x)
{
    int dir=getdir(rt,x);
    if(dir==-1)
    {
        if(rt->num>1)
        {
            rt->num--;
            rt->weight--;
            delnum=rt->key;
            return rt;
        }
        if(!rt->son[0]&&!rt->son[1])
        {
            delnum=66666666;//不旋转
            return NULL;
        }

        if(rt->son[0])
        {
            int k=getmx(rt->son[0]);
            rt->son[0]=splay(rt->son[0],k);
            rt->son[0]->son[1]=rt->son[1];
            delnum=k;
            return rt->son[0];
        }
        else
        {
            int k=getmn(rt->son[1]);
            rt->son[1]=splay(rt->son[1],k);
            rt->son[1]->son[0]=rt->son[0];
            delnum=k;
            return rt->son[1];
        }
    }
    rt->son[dir]=Delete(rt->son[dir],x);
    maintainw(rt);
    return rt;
}
int mx=-12345678;
void getpre(node *rt,int x)
{
    if(!rt)
        return;
    if(rt->key<x&&rt->key>mx)
        mx=rt->key;
    int dir=getdir(rt,x);
    if(dir<=0)
        getpre(rt->son[0],x);
    else
        getpre(rt->son[1],x);
    return;
}
int mn=12345678;
void getsuc(node *rt,int x)
{
    if(!rt)
        return;
    if(rt->key>x&&rt->key<mn)
        mn=rt->key;
    int dir=getdir(rt,x);
    if(dir==-1||dir==1)
        getsuc(rt->son[1],x);
    else
        getsuc(rt->son[0],x);
    return;
}
int Rank(node *rt,int x)//第x名
{
    if(x-getw(rt->son[0])<=0)
        return Rank(rt->son[0],x);
    if(x-getw(rt->son[0])<=rt->num)
        return rt->key;
    return Rank(rt->son[1],x-rt->num-getw(rt->son[0]));
}
int Rankof(node *rt,int x)
{
    if(!rt)
        return 0;
    if(rt->key>=x)
        return Rankof(rt->son[0],x);
    int k=getw(rt->son[0])+rt->num;
    return k+Rankof(rt->son[1],x);
}
int main()
{
    int n,op,a,t;
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    {
        scanf("%d%d",&op,&a);
        if(op==1)
        {
            root=Insert(root,a);
            root=splay(root,a);
        }
        else if(op==2)
        {
            root=Delete(root,a);
            if(delnum!=66666666)
                root=splay(root,delnum);
        }
        else if(op==3)
        {
            printf("%d\n",Rankof(root,a)+1);
            root=splay(root,a);
        }
        else if(op==4)
        {
            t=Rank(root,a);
            root=splay(root,t);
            printf("%d\n",t);
        }
        else if(op==5)
        {
            mx=-12345678;
            getpre(root,a);
            root=splay(root,mx);
            printf("%d\n",mx);
        }
        else if(op==6)
        {
            mn=12345678;
            getsuc(root,a);
            root=splay(root,mn);
            printf("%d\n",mn);
        }
    }
    return 0;
}

 

说点什么

avatar
  Subscribe  
提醒
/* */