洛谷 P2389/U19030 电脑班的裁员 题解【贪心】【DP】【堆】
点击量:148
手写堆吊打优先队列系列
题目背景
有一天,ZZY无意中发现了自己三年前出的题(luogu2389 电脑班的裁员),并觉得500的数据太水了,决定加强一波。
苦思冥想之后,ZZY确信自己已经想出了一种美妙的$O(n^2)$的DP解法,于是决定把数据加到1000。但不幸的是,ZZY手一滑,多按了3个0……
现在就连ZZY的标程也过不了这个变态数据了【假装过不了】,聪明的你能帮帮他吗?
———————————————原题目背景———————————————
隔壁的新初一电脑班刚考过一场试,又到了裁员时间,老师把这项工作交给了ZZY来进行。而ZZY最近忙着刷题,就把这重要的任务推给了你。
题目描述
ZZY有独特的裁员技巧:每个同学都有一个考试得分$ a_i(-10^{8}\le a_i\le10^{8})$,在n个同学(n≤1000000)中选出不大于m段($ m≤\frac n2$)相邻的同学留下,裁掉未被选中的同学,使剩下同学的得分和最大。
输入输出格式
输入格式:
第一行为n,m,第二行为第1~n位同学的得分。
输出格式:
一个数s,为最大得分和。
输入输出样例
输入样例#1:5 31 -1 1 -1 1输出样例#1:3说明
【数据范围】
对于20%的数据,n<=500;
对于50%的数据,n<=2000;
对于70%的数据,n<=200000;
对于100%的数据,n<=1000000,m<=n/2,|ai|<=1e8(记得开long long)。
前言:
因为出题人想到了神级$ O(N)$做法,所以把数据加强到1e6。因为priority_queue身为STL::,所以自带大常数。我曾经测试过随机数据,手写堆是比STL优先队列快3到5倍的。不过在考场上紧张的环境下,1e5以下的数据还是可以任性地使用优先队列的。
题解:
原作者题解:http://xxccxcxc.top/post/76/
翻译成人话就是在一个长为n的数列中选出可以相连的k≤m段,使得这k段之和最大。
首先是原题P2389,数据范围为500,$ O(n^3)$运气好或者m较小的情况下是可以过的。算法是DP,令f[i][j]表示第i组以j结尾的最大和。每个数a[i]可以从a[i-1]处转移同一组数列,也可以从a[1…i-1]处转移出新的一组数列,最后遍历一遍以每一个元素结束的第m组,输出最大的结果。
DP可以被优化到时间$ O(n^2),$空间$ O(m)$。就是令f[i][j][2]表示第i组到j结尾的,同时存储j是否被选,因为是否被选这一状态,使得我们对于a[i]的转移可以只从a[i-1]过来,那么第二维就可以滚掉了。
堆优化的做法是一种贪心,显然连在一起的正数和连在一起的负数一定是一起被选(它们总可以被归为一段),记得处理0。这样一来数列中一定是正负交替的(如果不交替还可以合),而根据贪心,两头的负数可以直接去掉,不去掉是会影响答案正确性的,负数不选,对k也没有影响,因为k是可以小于等于m的。
此时,如果正数的个数小于等于所求m,直接输出正数之和,否则,我们要减小正数的个数。
我们每次选绝对值最小的一项合并,对数列影响最小,例如$ a_{i-1},a_{i},a_{i+1}$中$ a_{i}$是整个数列中绝对值最小的元素,那把这三个元素合并,对大小的影响最小。如果$ a_{i}$是个正数,那么它被合并后结果是个负数,$ \Rightarrow $2负数+1正数=1负数,相当于消掉了一个正数,判断此时正数是否小于m;如果$ a_{i}$是个负数,那么它被合并后结果是个正数,2正数+1负数=1正数,也相当于消掉了一个正数,同样判断是否正数个数小于m。如果在首尾两端,要注意是两个数相合,当合并后的数是负数时要舍掉。
合并的方法用指针链表就可以了,方便遍历,把指针带入堆中可以打懒惰删除标记,也就是上面提到的删3个数进1个数,那么有两个数在堆中是还没有被处理到的,这样给它们打上删除标记,等它们到达堆顶时,就直接删除。
Code:($ O(n^3)$dp)
#include<cstdio>
#include<cstring>
int max(int x,int y)
{
return x>y?x:y;
}
int f[510][510];//第i段以j结尾
int a[510];
int main()
{
int n,k;
int ans=0;
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
for(int j=1;j<=k&&j<=i;j++)
f[j][i]=f[j][i-1];
for(int j=1;j<=k&&j<=i;j++)
{
for(int l=0;l<i;l++)
f[j][i]=max(f[j][i],f[j-1][l]);
f[j][i]+=a[i];
ans=max(ans,f[j][i]);
}
}
printf("%d\n",ans);
return 0;
}
Code:(手写堆万岁)
#include<cstdio>
#include<cstring>
long long labs(long long x)
{
return x>0?x:-x;
}
struct node
{
long long v;
bool del;
node *nxt,*lst;
node(long long v,node *lst)
{
del=false;
this->v=v;
this->lst=lst;
nxt=NULL;
}
node()
{
del=false;
lst=NULL;
nxt=NULL;
}
}head,*tail=NULL;
struct statu
{
node *pla;
long long v;
statu(long long v,node *pla)
{
this->v=v;
this->pla=pla;
}
statu()
{
pla=NULL;
}
friend bool operator <(statu a,statu b)
{
return a.v<b.v;
}
}heap[1010000];
int cnt=0;
void push(statu x)
{
heap[++cnt]=x;
int i=cnt;
while(i>1&&heap[i]<heap[i>>1])
{
statu t=heap[i];
heap[i]=heap[i>>1];
heap[i>>1]=t;
i>>=1;
}
return;
}
void pop()
{
heap[1]=heap[cnt--];
int i=1;
while((i<<1)<=cnt&&(heap[i<<1]<heap[i]||heap[i<<1|1]<heap[i]))
{
if((i<<1)==cnt||heap[i<<1]<heap[i<<1|1])
{
statu t=heap[i];
heap[i]=heap[i<<1];
heap[i<<1]=t;
i<<=1;
}
else
{
statu t=heap[i];
heap[i]=heap[i<<1|1];
heap[i<<1|1]=t;
i=i<<1|1;
}
}
return;
}
long long a[1010000];
int main()
{
tail=&head;
int n,k,len=0;
long long tmp=0;
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
if(a[i]*tmp>0||!a[i])
{
tmp+=a[i];
continue;
}
else if(i!=1)
{
if(tmp<=0&&tail==&head)
{
tmp=a[i];
continue;
}
tail->nxt=new node(tmp,tail);
tail=tail->nxt;
push(statu(labs(tmp),tail));
if(tmp>0)
len++;
}
tmp=a[i];
}
if(tmp>0)
{
tail->nxt=new node(tmp,tail);
tail=tail->nxt;
push(statu(labs(tmp),tail));
len++;//len是正数个数
}
while(len>k)
{
statu a=heap[1];
pop();
if(a.pla->del)
continue;
node *l=a.pla->lst;
node *r=a.pla->nxt;
a.pla->del=true;
long long v=a.pla->v;
len-=(v>0);
if(l!=&head)
{
v+=l->v;
l->del=true;
len-=(l->v>0);
l=l->lst;
}
if(r!=NULL)
{
v+=r->v;
r->del=true;
len-=(r->v>0);
r=r->nxt;
}
if(v<=0)//判断是否在两端
{
if(l==&head)
{
head.nxt=r;
r->lst=&head;
continue;
}
if(r==NULL)
{
l->nxt=NULL;
tail=l;
continue;
}
}
node *p=new node(v,l);
if(r!=NULL)
r->lst=p;
if(l!=NULL)
l->nxt=p;
p->nxt=r;
if(v>0)
len++;
push(statu(labs(v),p));
}
long long ans=0;
node *p=&head;
while(p->nxt!=NULL)
{
p=p->nxt;
if(p->v>0)
ans+=p->v;
}
printf("%lld\n",ans);
return 0;
}
说点什么