51nod 1600 Simple KMP【KMP】【SAM】【树链剖分】
点击量:113
利用了 KMP 性质和 SAM 这一“数据结构”的特点并进行了动态维护。
题目描述
对于一个字符串 $S$,我们定义 $fail[i]$,表示最大的 $x$ 使得 $S[1..x]=S[i-x+1..i]$,满足 $(x<i)$。
显然对于一个字符串,如果我们将每个 $0\le i\le |S|$ 看成一个结点,除了 $i=0$ 以外 $i$ 向 $fail[i]$ 连边,这是一颗树的形状,根是 $0$。
我们定义这棵树是 $G(S)$,设 $f(S)$ 是 $G(S)$ 中除了 $0$ 号点以外所有点的深度之和,其中 $0$ 号点的深度为 $-1$。
定义 $key(S)$ 等于 $S$ 的所有非空子串 $S’$ 的 $f(S’)$ 之和。
给定一个字符串 $S$,现在你要实现以下几种操作:
1.在 $S$ 最后面加一个字符;
2.询问 $key(S)$。
善良的出题人不希望你的答案比
long long
大,所以你需要将答案对 $10^9+7$ 取模。输入格式
第一行一个正整数 $Q$,
第二行一个长度为 $Q$ 的字符串 $S$。
输出格式
输出 $Q$ 行,第 $i$ 行表示前 $i$ 个字符组成的字符串的答案。
输入样例
5 abaab
输出样例
0 0 1 4 9
数据范围与约定
$Q\le 10^5$,$S$ 中只含小写字母。
题解:
KMP 习惯相关,下文的 $nxt$ 表示题面中的 $fail$。下文中的 $n$ 表示题面中的 $Q$。
考虑单独一个长为 $n$ 的字符串 $S$。建立 $S$ 的 $fail$ 树,那么 $n$ 号点的深度就是 $nxt[n]$ 的深度 $+1$。$S[1,nxt[n]]=S[n-nxt[n]+1,n]$,因此 $S[1,nxt[nxt[n]]]=S[n-nxt[nxt[n]]+1,n]$。这样一直推下去,就是使得 $S[1,x]=S[n-x+1,n]$ 的 $x(x<n)$ 的个数。
回到原题。我们每次增加一个节点,令 $f(i)$ 表示以 $i$ 为右端点的子串所做的贡献和。那么当我们加到第 $i$ 个字符时,就是要对于每个 $j(1\le j\le i)$ 询问 $S[j,i]$ 在 $S[1,i-1]$ 中出现了几次。
这个问题可以用后缀自动机来解决。$S[j,i]$ 是 $S[1,i]$ 的后缀,因此 $S[1,i]$ 在 SAM 上的状态节点的 parent 树祖先都会做贡献。
而此时我们对每个 $S[1,i]$ 都要求计算,这样的话我们每插入一个字符就要统计一次它的所有祖先的 $|Right|\times len,(len_i=mx[i]-mx[par[i]])$。
这样的话我们每加入一个节点,对它的所有祖先的 $|Right|$ 都有 $1$ 的贡献。
但是,由于后缀自动机的特殊性,在一些情况下会有节点的分裂,这样的话就要改变树的形态和贡献计数了。因此我们可以先把整个字符串的后缀自动机建好,然后再对 parent 树树链剖分或建立 LCT 来动态维护即可。
注意贡献是 $|Right|\times len$,所以在 $|Right|$ 动态的同时,对于每个节点还需要维护一个系数 $len$。
时间复杂度 $O(n\log^2n)$。
Code:
#include<cstdio>
#include<cstring>
#define P 1000000007
char s[100010];
int n;
int ch[26][200010],mx[200010],par[200010],pcnt;
void build()
{
int p=pcnt=1;
for(int i=1;i<=n;++i)
{
int w=s[i]-'a';
int np=++pcnt;
mx[np]=mx[p]+1;
while(p&&!ch[w][p])
{
ch[w][p]=np;
p=par[p];
}
if(!p)
par[np]=1;
else
{
int q=ch[w][p];
if(mx[q]==mx[p]+1)
par[np]=q;
else
{
int nq=++pcnt;
mx[nq]=mx[p]+1;
while(p&&ch[w][p]==q)
{
ch[w][p]=nq;
p=par[p];
}
for(int j=0;j<26;++j)
ch[j][nq]=ch[j][q];
par[nq]=par[q];
par[np]=par[q]=nq;
}
}
p=np;
}
}
#define ls (k<<1)
#define rs (k<<1|1)
#define mid ((L+R)>>1)
int v[800100],V[800100],lazy[800100];
int b[200100];
void build(int k,int L,int R)
{
if(L==R)
{
V[k]=b[L];
return;
}
build(ls,L,mid);
build(rs,mid+1,R);
V[k]=(V[ls]+V[rs])%P;
}
void pushdown(int k,int L,int R)
{
if(L==R||!lazy[k])
return;
long long x=lazy[k];
lazy[k]=0;
lazy[ls]=(lazy[ls]+x)%P;
v[ls]=(v[ls]+x*V[ls])%P;
lazy[rs]=(lazy[rs]+x)%P;
v[rs]=(v[rs]+x*V[rs])%P;
}
void change(int k,int l,int r,int L,int R,int x)
{
if(r<L||l>R)
return;
if(l<=L&&r>=R)
{
lazy[k]=(lazy[k]+x)%P;
v[k]=(v[k]+(long long)x*V[k])%P;
return;
}
pushdown(k,L,R);
change(ls,l,r,L,mid,x);
change(rs,l,r,mid+1,R,x);
v[k]=(v[ls]+v[rs])%P;
}
int ask(int k,int l,int r,int L,int R)
{
if(r<L||l>R)
return 0;
if(l<=L&&r>=R)
return v[k];
pushdown(k,L,R);
return (ask(ls,l,r,L,mid)+ask(rs,l,r,mid+1,R))%P;
}
struct edge
{
int n,nxt;
edge(int n,int nxt)
{
this->n=n;
this->nxt=nxt;
}
edge(){}
}e[200100];
int head[200100],ecnt=-1;
void add(int from,int to)
{
e[++ecnt]=edge(to,head[from]);
head[from]=ecnt;
}
int tp[200100],dfn[200100],sz[200100],d[200100],hs[200100],cnt=0;
void dfs(int x)
{
sz[x]=1;
for(int i=head[x];~i;i=e[i].nxt)
{
d[e[i].n]=d[x]+1;
dfs(e[i].n);
sz[x]+=sz[e[i].n];
hs[x]=sz[hs[x]]>sz[e[i].n]?hs[x]:e[i].n;
}
}
void dfs2(int x,int t)
{
dfn[x]=++cnt;
tp[x]=t;
if(hs[x])
dfs2(hs[x],t);
for(int i=head[x];~i;i=e[i].nxt)
if(e[i].n!=hs[x])
dfs2(e[i].n,e[i].n);
}
void modify(int u,int v,int x)
{
while(tp[u]!=tp[v])
{
if(d[tp[u]]<d[tp[v]])
{
int t=u;
u=v;
v=t;
}
change(1,dfn[tp[u]],dfn[u],1,pcnt,x);
u=par[tp[u]];
}
if(dfn[u]<dfn[v])
change(1,dfn[u],dfn[v],1,pcnt,x);
else
change(1,dfn[v],dfn[u],1,pcnt,x);
}
int Ask(int u,int v)
{
int ans=0;
while(tp[u]!=tp[v])
{
if(d[tp[u]]<d[tp[v]])
{
int t=u;
u=v;
v=t;
}
ans=(ans+ask(1,dfn[tp[u]],dfn[u],1,pcnt))%P;
u=par[tp[u]];
}
if(dfn[u]<dfn[v])
ans=(ans+ask(1,dfn[u],dfn[v],1,pcnt))%P;
else
ans=(ans+ask(1,dfn[v],dfn[u],1,pcnt))%P;
return ans;
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d",&n);
scanf("%s",s+1);
build();
for(int i=2;i<=pcnt;++i)
add(par[i],i);
d[1]=1;
dfs(1);
dfs2(1,1);
for(int i=1;i<=pcnt;++i)
b[dfn[i]]=mx[i]-mx[par[i]];
build(1,1,pcnt);
int p=1,ans=0,sum=0;
for(int i=1;i<=n;++i)
{
int w=s[i]-'a';
while(p&&!ch[w][p])
p=par[p];
if(p)
p=ch[w][p];
else
p=1;
sum=(sum+Ask(1,p))%P;
ans=(ans+sum)%P;
printf("%d\n",ans);
modify(1,p,1);
}
return 0;
}
说点什么