splay学习笔记1 单旋splay【splay】【平衡树】

作者: wjyyy 分类: splay,学习笔记 发布时间: 2018-06-12 17:10

点击量:27

   上午+下午写了3个多小时终于写完了。。。

 

文章背景:

   昨天写了整整一下午+一晚上的AVL树还是放弃了,今天学了splay以为代码量会小一点,结果……

 

   我们需要维护的操作有

  1. 插入数x
  2. 删除数x
  3. 查询x的排名(输出比x小的数的个数+1)
  4. 查询排名为x的数
  5. 求x的前驱
  6. 求x的后继

   其中数可能重复,需要重复计算。

 

   这篇文章中所提到的splay只写了单旋,没有加很多优化,理论上只能过这种模板题,在单调的情况下会退化成链。但是这种做法是splay优化的基础,其中左右儿子可以优化到一起,使部分函数的代码量直接下降一半,而双旋只需要增加一个函数即可。

 

   splay因为随时将正在处理的节点旋转到根,并且双旋会对保持平衡有帮助,tarjan证明,有双旋的这种二叉查找树数据结构的均摊复杂度是\(O(NlogN)\),而优化后代码量较小(200行左右?),易于实现。

 

写这篇文章的目的是巩固自己写splay过程中出现的问题,以及需要注意的细节。(300行真的有点长)

 

细节总结:

  1. 节点的所有有关指针初始化要置为NULL,这是在程序里需要用到的一个量,并防止指针越界。
  2. 对于树上的任意一个节点(目前能访问得到的),一旦被改动了,就要更新它的【左孩子、右孩子、父亲】,最重要的是它的子树大小weight。同时,与这个节点有关的son、father也需要根据情况调整。
  3. 基本所有的函数返回值都是一个指针,不然旋转后返回到的还是原来的节点,节点的关系就会紊乱。
  4. 对于多数访问,一旦出现NULL都很容易越界(如第一条),一定要判断是否有这个点,然后才能访问。

 

总结:

   在学习了AVL树的思想后,splay的单旋旋转显得简单许多,不用判断过多的条件,就直接旋转,不过这只是学习splay的第一步,还有许多优化和技巧需要摸索。

 

Code:

#include<cstdio>
#include<cstring>
int max(int x,int y)
{
    return x>y?x:y;
}
int min(int x,int y)
{
    return x<y?x:y;
}
struct node
{
    int key,num,weight;
    node *ls,*rs,*fa;
    node(int key)
    {
        this->key=key;
        ls=NULL;
        rs=NULL;
        fa=NULL;
        num=1,weight=1;
    }
    node()
    {
        ls=NULL;
        rs=NULL;
        fa=NULL;
    }
}*root;
int getw(node *rt)
{
    if(rt==NULL)
        return 0;
    return rt->weight;
}
node *zig(node *rt)
{
    node *tmp=rt->fa;
    if(tmp->fa)
    {
        if(tmp->fa->ls&&tmp->fa->ls==tmp)
            tmp->fa->ls=rt;
        else
            tmp->fa->rs=rt;
    }
    rt->fa=tmp->fa;
    tmp->ls=rt->rs;
    if(tmp->ls)
        tmp->ls->fa=tmp;
    rt->rs=tmp;
    tmp->fa=rt;
    //更新tmp和rt
    tmp->weight=getw(tmp->ls)+getw(tmp->rs)+tmp->num;
    rt->weight=getw(rt->ls)+getw(rt->rs)+rt->num;

    return rt;
}
node *zag(node *rt)
{
    node *tmp=rt->fa;
    if(tmp->fa)
    {
        if(tmp->fa->ls&&tmp->fa->ls==tmp)
            tmp->fa->ls=rt;
        else
            tmp->fa->rs=rt;
    }
    rt->fa=tmp->fa;
    tmp->rs=rt->ls;
    if(tmp->rs)
        tmp->rs->fa=tmp;
    rt->ls=tmp;
    tmp->fa=rt;
    tmp->weight=getw(tmp->ls)+getw(tmp->rs)+tmp->num;
    rt->weight=getw(rt->ls)+getw(rt->rs)+rt->num;

    return rt;
}
node *Insert(node *rt,int x)
{
    if(rt==NULL)
    {
        rt=new node(x);
        return rt;
    }
    if(rt->key==x)
        rt->num++;
    if(rt->key>x)
    {
        rt->ls=Insert(rt->ls,x);
        rt->ls->fa=rt;
        rt=zig(rt->ls);
    }
    if(rt->key<x)
    {
        rt->rs=Insert(rt->rs,x);
        rt->rs->fa=rt;

        rt=zag(rt->rs);
    }
    rt->weight=getw(rt->ls)+getw(rt->rs)+rt->num;
    return rt;
}
void mid_ord(node *rt)//中序遍历用于debug
{
    if(rt->ls)
        mid_ord(rt->ls);
    for(int i=1;i<=rt->num;i++)
        printf("%d ",rt->key);
    if(rt->rs)
        mid_ord(rt->rs);
    return;
}
node *Min(node *rt)
{
    if(rt->ls)
    {
        node *tmp=Min(rt->ls);
        rt->weight=getw(rt->ls)+getw(rt->rs)+rt->num;
        return tmp;
    }

    //没有左孩子,找到了后继
    if(rt->fa->ls==rt)
        rt->fa->ls=rt->rs;
    else
        rt->fa->rs=rt->rs;
    if(rt->rs)
        rt->rs->fa=rt->fa;
    return rt;
}

node *Max(node *rt)
{
    if(rt->rs)
    {
        node *tmp=Max(rt->rs);
        rt->weight=getw(rt->ls)+getw(rt->rs)+rt->num;
        return tmp;
    }

    if(rt->fa->ls==rt)
        rt->fa->ls=rt->ls;
    else
        rt->fa->rs=rt->ls;
    if(rt->ls)
        rt->ls->fa=rt->fa;
    return rt;
}
node *Delete(node *rt,int x)
{
    if(rt->key==x)
    {
        if(rt->num>1)
        {
            rt->num--;
            rt->weight--;
            return rt;
        }
        else
        {
            if(!rt->ls&&!rt->rs)
                return NULL;
            //找前驱或后继
            if(rt->rs)
            {
                //找右子树最小的
                node *tmp=Min(rt->rs);
                tmp->ls=rt->ls;
                tmp->rs=rt->rs;
                tmp->fa=rt->fa;
                if(rt->fa)
                {
                    if(rt->fa->ls==rt)
                        rt->fa->ls=tmp;
                    else
                        rt->fa->rs=tmp;
                }
                if(tmp->ls)
                    tmp->ls->fa=tmp;
                if(tmp->rs)
                    tmp->rs->fa=tmp;

                tmp->weight=getw(tmp->ls)+getw(tmp->rs)+tmp->num;

                return tmp;
            }
            else
            {
                //找左子树最大的
                node *tmp=Max(rt->ls);
                tmp->ls=rt->ls;
                tmp->rs=rt->rs;
                tmp->fa=rt->fa;
                if(rt->fa)
                {
                    if(rt->fa->ls==rt)
                        rt->fa->ls=tmp;
                    else
                        rt->fa->rs=tmp;
                }
                if(tmp->ls)
                    tmp->ls->fa=tmp;
                if(tmp->rs)
                    tmp->rs->fa=tmp;

                tmp->weight=getw(tmp->ls)+getw(tmp->rs)+tmp->num;
                return tmp;
            }
        }
    }
    else
    {
        if(x<rt->key)
        {
            rt->ls=Delete(rt->ls,x);
            if(rt->ls)
                rt=zig(rt->ls);
        }
        else
        {
            rt->rs=Delete(rt->rs,x);
            if(rt->rs)
                rt=zag(rt->rs);
        }
    }
    rt->weight=getw(rt->ls)+getw(rt->rs)+rt->num;
    return rt;
}


int Rank(node *rt,int x)
{
    if(getw(rt->ls)<x&&x<=getw(rt->ls)+rt->num)
        return rt->key;
    if(getw(rt->ls)>=x)
        return Rank(rt->ls,x);

    return Rank(rt->rs,x-getw(rt->ls)-rt->num);
}

int Rankof(node *rt,int x)
{
    if(!rt)
        return 0;
    if(x<=rt->key)
        return Rankof(rt->ls,x);
    return getw(rt->ls)+rt->num+Rankof(rt->rs,x);
}
int getpre(node *rt,int x)
{
    if(rt==NULL)
        return -12345678;
    if(x<=rt->key)
        return getpre(rt->ls,x);
    return max(rt->key,getpre(rt->rs,x));
}
int getnxt(node *rt,int x)
{
    if(rt==NULL)
        return 12345678;
    if(x>=rt->key)
        return getnxt(rt->rs,x);
    return min(rt->key,getnxt(rt->ls,x));//递归技巧
}
int main()
{
    root=NULL;
    int n,a,b;
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    {
        scanf("%d%d",&a,&b);
        if(a==1)
            root=Insert(root,b);
        else if(a==2)
            root=Delete(root,b);
        else if(a==3)
            printf("%d\n",Rankof(root,b)+1);
        else if(a==4)
            printf("%d\n",Rank(root,b));
        else if(a==5)
            printf("%d\n",getpre(root,b));
        else
            printf("%d\n",getnxt(root,b));
        /*mid_ord(root);
        printf("\n");*///debug用
    }
    return 0;
}

 

说点什么

avatar
  Subscribe  
提醒
/* */