洛谷 P1972 [SDOI2009]HH的项链 题解【树状数组】【区间统计】
点击量:67
500000的数据要么是$ O(N)$要么是$ O(NlogN)$了吧。
题目描述
HH 有一串由各种漂亮的贝壳组成的项链。HH 相信不同的贝壳会带来好运,所以每次散步完后,他都会随意取出一段贝壳,思考它们所表达的含义。HH 不断地收集新的贝壳,因此,他的项链变得越来越长。有一天,他突然提出了一个问题:某一段贝壳中,包含了多少种不同的贝壳?这个问题很难回答……因为项链实在是太长了。于是,他只好求助睿智的你,来解决这个问题。
输入输出格式
输入格式:
第一行:一个整数N,表示项链的长度。
第二行:N 个整数,表示依次表示项链中贝壳的编号(编号为0 到1000000 之间的整数)。
第三行:一个整数M,表示HH 询问的个数。
接下来M 行:每行两个整数,L 和R(1 ≤ L ≤ R ≤ N),表示询问的区间。
输出格式:
M 行,每行一个整数,依次表示询问对应的答案。
输入输出样例
输入样例#1:61 2 3 4 3 531 23 52 6输出样例#1:224说明
数据范围:
对于100%的数据,N <= 500000,M <= 200000。
首先暴力是处理每个询问找到是否有used,统计区间不同种贝壳的个数。暴力脸白一点有50分——
Code:(50pts)
#include<cstdio>
#include<cstring>
int a[501000];
bool used[1001000];
int main()
{
int n,m,u,v;
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
scanf("%d",&m);
for(int i=1;i<=m;i++)
{
int ans=0;
memset(used,0,sizeof(used));
scanf("%d%d",&u,&v);
for(int j=u;j<=v;j++)
if(!used[a[j]])
{
used[a[j]]=true;
ans++;
}
printf("%d\n",ans);
}
return 0;
}
接下来介绍正解:
对于统计,我们每次只需要每种颜色中的一个做考虑。如果我们将询问的区间左端升序排列,从左向右让左端点递增选取区间。我们暂且认为做出贡献的颜色是最左边的一个,因为如果统计了最左边的那个,此时右边的任何一个相同的对答案都没有影响。
因为我们要维护区间和,并且在不断修改单点,树状数组是一种方便的选择,那么树状数组中维护的是这个区间对颜色种类所做出的贡献,也就是在当前情况下做出贡献的点,所代表的数值是1,没有做出贡献的点代表的值是0。
同时,如果一个点被舍弃了,也就是左端点在递增时越过它了,那么它现在就没有贡献了,如果它后面有相同的颜色的点,那么贡献由它来做,否则这个颜色就没有贡献了。
在一开始的情况下,区间权值为1的点,都是各种颜色第一次出现的位置,当区间扫过这个位置,后面出现的才能跟上,这样既防止了重复计算,也减少了冗余枚举,时间复杂度是$ O(n+m)logn$。查询时直接返回区间求和就可以了。
Code:(100pts)
#include<cstdio>
#include<cstring>
#include<algorithm>
int c[501000];
struct ques
{
int l,r;
int i,ans;
friend bool operator <(ques x,ques y)
{
return x.l<y.l;
}
}q[233333];
int lowbit(int x)
{
return x&(-x);
}
int n;
void add(int num,int x)
{
while(num<=n)
{
c[num]+=x;
num+=lowbit(num);
}
return;
}
int ask(int x)
{
int ans=0;
while(x)
{
ans+=c[x];
x-=lowbit(x);
}
return ans;
}
int query(int l,int r)
{
return ask(r)-ask(l-1);
}
bool cmp(ques x,ques y)
{
return x.i<y.i;
}
int lst[1001000],nxt[501000];
//lst是上一次这个颜色出现的位置,def=0
//nxt[i]是下一次i所在位置颜色出现的位置,如果lst[nxt[i]]==i,def=0
int main()
{
int col[501000];
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%d",&col[i]);
if(lst[col[i]]==0)
add(i,1);
nxt[lst[col[i]]]=i;
lst[col[i]]=i;
}
int m;
scanf("%d",&m);
for(int i=1;i<=m;i++)
{
scanf("%d%d",&q[i].l,&q[i].r);
q[i].i=i;
}
std::sort(q+1,q+m+1);
int cnt=1;
for(int i=1;i<=n;i++)
{
while(q[cnt].l==i)
{
q[cnt].ans=query(i,q[cnt].r);
cnt++;
}
add(i,-1);
if(nxt[i])
add(nxt[i],1);
}
std::sort(q+1,q+m+1,cmp);//正序输出
for(int i=1;i<=m;i++)
printf("%d\n",q[i].ans);
return 0;
}
说点什么