51nod 1600 Simple KMP【KMP】【SAM】【树链剖分】

作者: wjyyy 分类: KMP,后缀自动机,树链剖分,线段树,解题报告 发布时间: 2019-05-05 20:24

点击量:46

利用了 KMP 性质和 SAM 这一“数据结构”的特点并进行了动态维护。

题目描述

对于一个字符串 $S$,我们定义 $fail[i]$,表示最大的 $x$ 使得 $S[1..x]=S[i-x+1..i]$,满足 $(x<i)$。

显然对于一个字符串,如果我们将每个 $0\le i\le |S|$ 看成一个结点,除了 $i=0$ 以外 $i$ 向 $fail[i]$ 连边,这是一颗树的形状,根是 $0$。

我们定义这棵树是 $G(S)$,设 $f(S)$ 是 $G(S)$ 中除了 $0$ 号点以外所有点的深度之和,其中 $0$ 号点的深度为 $-1$。

定义 $key(S)$ 等于 $S$ 的所有非空子串 $S’$ 的 $f(S’)$ 之和。

给定一个字符串 $S$,现在你要实现以下几种操作:

1.在 $S$ 最后面加一个字符;

2.询问 $key(S)$。

善良的出题人不希望你的答案比 long long 大,所以你需要将答案对 $10^9+7$ 取模。

输入格式

第一行一个正整数 $Q$,

第二行一个长度为 $Q$ 的字符串 $S$。

输出格式

输出 $Q$ 行,第 $i$ 行表示前 $i$ 个字符组成的字符串的答案。

输入样例

5
abaab

输出样例

0
0
1
4
9

数据范围与约定

$Q\le 10^5$,$S$ 中只含小写字母。

题解:

KMP 习惯相关,下文的 $nxt$ 表示题面中的 $fail$。下文中的 $n$ 表示题面中的 $Q$。

考虑单独一个长为 $n$ 的字符串 $S$。建立 $S$ 的 $fail$ 树,那么 $n$ 号点的深度就是 $nxt[n]$ 的深度 $+1$。$S[1,nxt[n]]=S[n-nxt[n]+1,n]$,因此 $S[1,nxt[nxt[n]]]=S[n-nxt[nxt[n]]+1,n]$。这样一直推下去,就是使得 $S[1,x]=S[n-x+1,n]$ 的 $x(x<n)$ 的个数。

回到原题。我们每次增加一个节点,令 $f(i)$ 表示以 $i$ 为右端点的子串所做的贡献和。那么当我们加到第 $i$ 个字符时,就是要对于每个 $j(1\le j\le i)$ 询问 $S[j,i]$ 在 $S[1,i-1]$ 中出现了几次。

这个问题可以用后缀自动机来解决。$S[j,i]$ 是 $S[1,i]$ 的后缀,因此 $S[1,i]$ 在 SAM 上的状态节点的 parent 树祖先都会做贡献。

而此时我们对每个 $S[1,i]$ 都要求计算,这样的话我们每插入一个字符就要统计一次它的所有祖先的 $|Right|\times len,(len_i=mx[i]-mx[par[i]])$。

这样的话我们每加入一个节点,对它的所有祖先的 $|Right|$ 都有 $1$ 的贡献。

但是,由于后缀自动机的特殊性,在一些情况下会有节点的分裂,这样的话就要改变树的形态和贡献计数了。因此我们可以先把整个字符串的后缀自动机建好,然后再对 parent 树树链剖分或建立 LCT 来动态维护即可。

注意贡献是 $|Right|\times len$,所以在 $|Right|$ 动态的同时,对于每个节点还需要维护一个系数 $len$。

时间复杂度 $O(n\log^2n)$。

Code:

#include<cstdio>
#include<cstring>
#define P 1000000007
char s[100010];
int n;
int ch[26][200010],mx[200010],par[200010],pcnt;
void build()
{
    int p=pcnt=1;
    for(int i=1;i<=n;++i)
    {
        int w=s[i]-'a';
        int np=++pcnt;
        mx[np]=mx[p]+1;
        while(p&&!ch[w][p])
        {
            ch[w][p]=np;
            p=par[p];
        }
        if(!p)
            par[np]=1;
        else
        {
            int q=ch[w][p];
            if(mx[q]==mx[p]+1)
                par[np]=q;
            else
            {
                int nq=++pcnt;
                mx[nq]=mx[p]+1;
                while(p&&ch[w][p]==q)
                {
                    ch[w][p]=nq;
                    p=par[p];
                }
                for(int j=0;j<26;++j)
                    ch[j][nq]=ch[j][q];
                par[nq]=par[q];
                par[np]=par[q]=nq;
            }
        }
        p=np;
    }
}
#define ls (k<<1)
#define rs (k<<1|1)
#define mid ((L+R)>>1)
int v[800100],V[800100],lazy[800100];
int b[200100];
void build(int k,int L,int R)
{
    if(L==R)
    {
        V[k]=b[L];
        return;
    }
    build(ls,L,mid);
    build(rs,mid+1,R);
    V[k]=(V[ls]+V[rs])%P;
}
void pushdown(int k,int L,int R)
{
    if(L==R||!lazy[k])
        return;
    long long x=lazy[k];
    lazy[k]=0;
    lazy[ls]=(lazy[ls]+x)%P;
    v[ls]=(v[ls]+x*V[ls])%P;
    lazy[rs]=(lazy[rs]+x)%P;
    v[rs]=(v[rs]+x*V[rs])%P;
}
void change(int k,int l,int r,int L,int R,int x)
{
    if(r<L||l>R)
        return;
    if(l<=L&&r>=R)
    {
        lazy[k]=(lazy[k]+x)%P;
        v[k]=(v[k]+(long long)x*V[k])%P;
        return;
    }
    pushdown(k,L,R);
    change(ls,l,r,L,mid,x);
    change(rs,l,r,mid+1,R,x);
    v[k]=(v[ls]+v[rs])%P;
}
int ask(int k,int l,int r,int L,int R)
{
    if(r<L||l>R)
        return 0;
    if(l<=L&&r>=R)
        return v[k];
    pushdown(k,L,R);
    return (ask(ls,l,r,L,mid)+ask(rs,l,r,mid+1,R))%P;
}
struct edge
{
    int n,nxt;
    edge(int n,int nxt)
    {
        this->n=n;
        this->nxt=nxt;
    }
    edge(){}
}e[200100];
int head[200100],ecnt=-1;
void add(int from,int to)
{
    e[++ecnt]=edge(to,head[from]);
    head[from]=ecnt;
}
int tp[200100],dfn[200100],sz[200100],d[200100],hs[200100],cnt=0;
void dfs(int x)
{
    sz[x]=1;
    for(int i=head[x];~i;i=e[i].nxt)
    {
        d[e[i].n]=d[x]+1;
        dfs(e[i].n);
        sz[x]+=sz[e[i].n];
        hs[x]=sz[hs[x]]>sz[e[i].n]?hs[x]:e[i].n;
    }
}
void dfs2(int x,int t)
{
    dfn[x]=++cnt;
    tp[x]=t;
    if(hs[x])
        dfs2(hs[x],t);
    for(int i=head[x];~i;i=e[i].nxt)
        if(e[i].n!=hs[x])
            dfs2(e[i].n,e[i].n);
}
void modify(int u,int v,int x)
{
    while(tp[u]!=tp[v])
    {
        if(d[tp[u]]<d[tp[v]])
        {
            int t=u;
            u=v;
            v=t;
        }
        change(1,dfn[tp[u]],dfn[u],1,pcnt,x);
        u=par[tp[u]];
    }
    if(dfn[u]<dfn[v])
        change(1,dfn[u],dfn[v],1,pcnt,x);
    else
        change(1,dfn[v],dfn[u],1,pcnt,x);
}
int Ask(int u,int v)
{
    int ans=0;
    while(tp[u]!=tp[v])
    {
        if(d[tp[u]]<d[tp[v]])
        {
            int t=u;
            u=v;
            v=t;
        }
        ans=(ans+ask(1,dfn[tp[u]],dfn[u],1,pcnt))%P;
        u=par[tp[u]];
    }
    if(dfn[u]<dfn[v])
        ans=(ans+ask(1,dfn[u],dfn[v],1,pcnt))%P;
    else
        ans=(ans+ask(1,dfn[v],dfn[u],1,pcnt))%P;
    return ans;
}
int main()
{
    memset(head,-1,sizeof(head));
    scanf("%d",&n);
    scanf("%s",s+1);
    build();
    for(int i=2;i<=pcnt;++i)
        add(par[i],i);
    d[1]=1;
    dfs(1);
    dfs2(1,1);
    for(int i=1;i<=pcnt;++i)
        b[dfn[i]]=mx[i]-mx[par[i]];
    build(1,1,pcnt);
    int p=1,ans=0,sum=0;
    for(int i=1;i<=n;++i)
    {
        int w=s[i]-'a';
        while(p&&!ch[w][p])
            p=par[p];
        if(p)
            p=ch[w][p];
        else
            p=1;
        sum=(sum+Ask(1,p))%P;
        ans=(ans+sum)%P;
        printf("%d\n",ans);
        modify(1,p,1);
    }
    return 0;
}

说点什么

avatar
  Subscribe  
提醒
/* */