splay学习笔记1 单旋splay【splay】【平衡树】
点击量:205
上午+下午写了3个多小时终于写完了。。。
文章背景:
昨天写了整整一下午+一晚上的AVL树还是放弃了,今天学了splay以为代码量会小一点,结果……
我们需要维护的操作有
- 插入数x
- 删除数x
- 查询x的排名(输出比x小的数的个数+1)
- 查询排名为x的数
- 求x的前驱
- 求x的后继
其中数可能重复,需要重复计算。
这篇文章中所提到的splay只写了单旋,没有加很多优化,理论上只能过这种模板题,在单调的情况下会退化成链。但是这种做法是splay优化的基础,其中左右儿子可以优化到一起,使部分函数的代码量直接下降一半,而双旋只需要增加一个函数即可。
splay因为随时将正在处理的节点旋转到根,并且双旋会对保持平衡有帮助,tarjan证明,有双旋的这种二叉查找树数据结构的均摊复杂度是$ O(NlogN)$,而优化后代码量较小(200行左右?),易于实现。
写这篇文章的目的是巩固自己写splay过程中出现的问题,以及需要注意的细节。(300行真的有点长)
细节总结:
- 节点的所有有关指针初始化要置为NULL,这是在程序里需要用到的一个量,并防止指针越界。
- 对于树上的任意一个节点(目前能访问得到的),一旦被改动了,就要更新它的【左孩子、右孩子、父亲】,最重要的是它的子树大小weight。同时,与这个节点有关的son、father也需要根据情况调整。
- 基本所有的函数返回值都是一个指针,不然旋转后返回到的还是原来的节点,节点的关系就会紊乱。
- 对于多数访问,一旦出现NULL都很容易越界(如第一条),一定要判断是否有这个点,然后才能访问。
总结:
在学习了AVL树的思想后,splay的单旋旋转显得简单许多,不用判断过多的条件,就直接旋转,不过这只是学习splay的第一步,还有许多优化和技巧需要摸索。
Code:
#include<cstdio>
#include<cstring>
int max(int x,int y)
{
return x>y?x:y;
}
int min(int x,int y)
{
return x<y?x:y;
}
struct node
{
int key,num,weight;
node *ls,*rs,*fa;
node(int key)
{
this->key=key;
ls=NULL;
rs=NULL;
fa=NULL;
num=1,weight=1;
}
node()
{
ls=NULL;
rs=NULL;
fa=NULL;
}
}*root;
int getw(node *rt)
{
if(rt==NULL)
return 0;
return rt->weight;
}
node *zig(node *rt)
{
node *tmp=rt->fa;
if(tmp->fa)
{
if(tmp->fa->ls&&tmp->fa->ls==tmp)
tmp->fa->ls=rt;
else
tmp->fa->rs=rt;
}
rt->fa=tmp->fa;
tmp->ls=rt->rs;
if(tmp->ls)
tmp->ls->fa=tmp;
rt->rs=tmp;
tmp->fa=rt;
//更新tmp和rt
tmp->weight=getw(tmp->ls)+getw(tmp->rs)+tmp->num;
rt->weight=getw(rt->ls)+getw(rt->rs)+rt->num;
return rt;
}
node *zag(node *rt)
{
node *tmp=rt->fa;
if(tmp->fa)
{
if(tmp->fa->ls&&tmp->fa->ls==tmp)
tmp->fa->ls=rt;
else
tmp->fa->rs=rt;
}
rt->fa=tmp->fa;
tmp->rs=rt->ls;
if(tmp->rs)
tmp->rs->fa=tmp;
rt->ls=tmp;
tmp->fa=rt;
tmp->weight=getw(tmp->ls)+getw(tmp->rs)+tmp->num;
rt->weight=getw(rt->ls)+getw(rt->rs)+rt->num;
return rt;
}
node *Insert(node *rt,int x)
{
if(rt==NULL)
{
rt=new node(x);
return rt;
}
if(rt->key==x)
rt->num++;
if(rt->key>x)
{
rt->ls=Insert(rt->ls,x);
rt->ls->fa=rt;
rt=zig(rt->ls);
}
if(rt->key<x)
{
rt->rs=Insert(rt->rs,x);
rt->rs->fa=rt;
rt=zag(rt->rs);
}
rt->weight=getw(rt->ls)+getw(rt->rs)+rt->num;
return rt;
}
void mid_ord(node *rt)//中序遍历用于debug
{
if(rt->ls)
mid_ord(rt->ls);
for(int i=1;i<=rt->num;i++)
printf("%d ",rt->key);
if(rt->rs)
mid_ord(rt->rs);
return;
}
node *Min(node *rt)
{
if(rt->ls)
{
node *tmp=Min(rt->ls);
rt->weight=getw(rt->ls)+getw(rt->rs)+rt->num;
return tmp;
}
//没有左孩子,找到了后继
if(rt->fa->ls==rt)
rt->fa->ls=rt->rs;
else
rt->fa->rs=rt->rs;
if(rt->rs)
rt->rs->fa=rt->fa;
return rt;
}
node *Max(node *rt)
{
if(rt->rs)
{
node *tmp=Max(rt->rs);
rt->weight=getw(rt->ls)+getw(rt->rs)+rt->num;
return tmp;
}
if(rt->fa->ls==rt)
rt->fa->ls=rt->ls;
else
rt->fa->rs=rt->ls;
if(rt->ls)
rt->ls->fa=rt->fa;
return rt;
}
node *Delete(node *rt,int x)
{
if(rt->key==x)
{
if(rt->num>1)
{
rt->num--;
rt->weight--;
return rt;
}
else
{
if(!rt->ls&&!rt->rs)
return NULL;
//找前驱或后继
if(rt->rs)
{
//找右子树最小的
node *tmp=Min(rt->rs);
tmp->ls=rt->ls;
tmp->rs=rt->rs;
tmp->fa=rt->fa;
if(rt->fa)
{
if(rt->fa->ls==rt)
rt->fa->ls=tmp;
else
rt->fa->rs=tmp;
}
if(tmp->ls)
tmp->ls->fa=tmp;
if(tmp->rs)
tmp->rs->fa=tmp;
tmp->weight=getw(tmp->ls)+getw(tmp->rs)+tmp->num;
return tmp;
}
else
{
//找左子树最大的
node *tmp=Max(rt->ls);
tmp->ls=rt->ls;
tmp->rs=rt->rs;
tmp->fa=rt->fa;
if(rt->fa)
{
if(rt->fa->ls==rt)
rt->fa->ls=tmp;
else
rt->fa->rs=tmp;
}
if(tmp->ls)
tmp->ls->fa=tmp;
if(tmp->rs)
tmp->rs->fa=tmp;
tmp->weight=getw(tmp->ls)+getw(tmp->rs)+tmp->num;
return tmp;
}
}
}
else
{
if(x<rt->key)
{
rt->ls=Delete(rt->ls,x);
if(rt->ls)
rt=zig(rt->ls);
}
else
{
rt->rs=Delete(rt->rs,x);
if(rt->rs)
rt=zag(rt->rs);
}
}
rt->weight=getw(rt->ls)+getw(rt->rs)+rt->num;
return rt;
}
int Rank(node *rt,int x)
{
if(getw(rt->ls)<x&&x<=getw(rt->ls)+rt->num)
return rt->key;
if(getw(rt->ls)>=x)
return Rank(rt->ls,x);
return Rank(rt->rs,x-getw(rt->ls)-rt->num);
}
int Rankof(node *rt,int x)
{
if(!rt)
return 0;
if(x<=rt->key)
return Rankof(rt->ls,x);
return getw(rt->ls)+rt->num+Rankof(rt->rs,x);
}
int getpre(node *rt,int x)
{
if(rt==NULL)
return -12345678;
if(x<=rt->key)
return getpre(rt->ls,x);
return max(rt->key,getpre(rt->rs,x));
}
int getnxt(node *rt,int x)
{
if(rt==NULL)
return 12345678;
if(x>=rt->key)
return getnxt(rt->rs,x);
return min(rt->key,getnxt(rt->ls,x));//递归技巧
}
int main()
{
root=NULL;
int n,a,b;
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%d%d",&a,&b);
if(a==1)
root=Insert(root,b);
else if(a==2)
root=Delete(root,b);
else if(a==3)
printf("%d\n",Rankof(root,b)+1);
else if(a==4)
printf("%d\n",Rank(root,b));
else if(a==5)
printf("%d\n",getpre(root,b));
else
printf("%d\n",getnxt(root,b));
/*mid_ord(root);
printf("\n");*///debug用
}
return 0;
}
… [Trackback]
[…] Info on that Topic: wjyyy.top/399.html […]
… [Trackback]
[…] Read More on to that Topic: wjyyy.top/399.html […]
… [Trackback]
[…] Find More Information here to that Topic: wjyyy.top/399.html […]
… [Trackback]
[…] Read More Info here on that Topic: wjyyy.top/399.html […]