洛谷 P2389/U19030 电脑班的裁员 题解【贪心】【DP】【堆】

作者: wjyyy 分类: DP,,解题报告,贪心 发布时间: 2018-07-08 17:26

点击量:20

 

   手写堆吊打优先队列系列

 

题目背景

有一天,ZZY无意中发现了自己三年前出的题(luogu2389 电脑班的裁员),并觉得500的数据太水了,决定加强一波。

 

苦思冥想之后,ZZY确信自己已经想出了一种美妙的$O(n^2)$的DP解法,于是决定把数据加到1000。但不幸的是,ZZY手一滑,多按了3个0……

 

现在就连ZZY的标程也过不了这个变态数据了【假装过不了】,聪明的你能帮帮他吗?

 

———————————————原题目背景———————————————

 

隔壁的新初一电脑班刚考过一场试,又到了裁员时间,老师把这项工作交给了ZZY来进行。而ZZY最近忙着刷题,就把这重要的任务推给了你。

 

题目描述

ZZY有独特的裁员技巧:每个同学都有一个考试得分\(a_i(-10^{8}\le a_i\le10^{8})\),在n个同学(n≤1000000)中选出不大于m段(\(m≤\frac n2\))相邻的同学留下,裁掉未被选中的同学,使剩下同学的得分和最大。

 

输入输出格式

输入格式:

第一行为n,m,第二行为第1~n位同学的得分。

 

输出格式:

一个数s,为最大得分和。

 

输入输出样例

输入样例#1:
5 3
1 -1 1 -1 1
输出样例#1:
3

说明

【数据范围】

对于20%的数据,n<=500;

对于50%的数据,n<=2000;

对于70%的数据,n<=200000;

对于100%的数据,n<=1000000,m<=n/2,|ai|<=1e8(记得开long long)。

 

前言:

   因为出题人想到了神级\(O(N)\)做法,所以把数据加强到1e6。因为priority_queue身为STL::,所以自带大常数。我曾经测试过随机数据,手写堆是比STL优先队列快3到5倍的。不过在考场上紧张的环境下,1e5以下的数据还是可以任性地使用优先队列的。

 

题解:

   原作者题解:http://xxccxcxc.top/post/76/

 

   翻译成人话就是在一个长为n的数列中选出可以相连的k≤m段,使得这k段之和最大。

   首先是原题P2389,数据范围为500,\(O(n^3)\)运气好或者m较小的情况下是可以过的。算法是DP,令f[i][j]表示第i组以j结尾的最大和。每个数a[i]可以从a[i-1]处转移同一组数列,也可以从a[1…i-1]处转移出新的一组数列,最后遍历一遍以每一个元素结束的第m组,输出最大的结果。

 

   DP可以被优化到时间\(O(n^2),\)空间\(O(m)\)。就是令f[i][j][2]表示第i组j结尾的,同时存储j是否被选,因为是否被选这一状态,使得我们对于a[i]的转移可以只从a[i-1]过来,那么第二维就可以滚掉了。

 

   堆优化的做法是一种贪心,显然连在一起的正数和连在一起的负数一定是一起被选(它们总可以被归为一段),记得处理0。这样一来数列中一定是正负交替的(如果不交替还可以合),而根据贪心,两头的负数可以直接去掉,不去掉是会影响答案正确性的,负数不选,对k也没有影响,因为k是可以小于等于m的。

 

   此时,如果正数的个数小于等于所求m,直接输出正数之和,否则,我们要减小正数的个数。

 

   我们每次选绝对值最小的一项合并,对数列影响最小,例如\(a_{i-1},a_{i},a_{i+1}\)中\(a_{i}\)是整个数列中绝对值最小的元素,那把这三个元素合并,对大小的影响最小。如果\(a_{i}\)是个正数,那么它被合并后结果是个负数,\(\Rightarrow \)2负数+1正数=1负数,相当于消掉了一个正数,判断此时正数是否小于m;如果\(a_{i}\)是个负数,那么它被合并后结果是个正数,2正数+1负数=1正数,也相当于消掉了一个正数,同样判断是否正数个数小于m。如果在首尾两端,要注意是两个数相合,当合并后的数是负数时要舍掉。

 

   合并的方法用指针链表就可以了,方便遍历,把指针带入堆中可以打懒惰删除标记,也就是上面提到的删3个数进1个数,那么有两个数在堆中是还没有被处理到的,这样给它们打上删除标记,等它们到达堆顶时,就直接删除。

 

Code:(\(O(n^3)\)dp)

#include<cstdio>
#include<cstring>
int max(int x,int y)
{
    return x>y?x:y;
}
int f[510][510];//第i段以j结尾
int a[510];
int main()
{
    int n,k;
    int ans=0;
    scanf("%d%d",&n,&k);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&a[i]);
        for(int j=1;j<=k&&j<=i;j++)
            f[j][i]=f[j][i-1];
        for(int j=1;j<=k&&j<=i;j++)
        {
            for(int l=0;l<i;l++)
                f[j][i]=max(f[j][i],f[j-1][l]);
            f[j][i]+=a[i];
            ans=max(ans,f[j][i]);
        }

    }
    printf("%d\n",ans);
    return 0;
}

Code:(手写堆万岁)

#include<cstdio>
#include<cstring>
long long labs(long long x)
{
    return x>0?x:-x;
}
struct node
{
    long long v;
    bool del;
    node *nxt,*lst;
     node(long long v,node *lst)
    {
        del=false;
        this->v=v;
        this->lst=lst;
        nxt=NULL;
    }
    node()
    {
        del=false;
        lst=NULL;
        nxt=NULL;
    }
}head,*tail=NULL;
struct statu
{
    node *pla;
    long long v;
    statu(long long v,node *pla)
    {
        this->v=v;
        this->pla=pla;
    }
    statu()
    {
        pla=NULL;
    }
    friend bool operator <(statu a,statu b)
    {
        return a.v<b.v;
    }
}heap[1010000];
int cnt=0;
void push(statu x)
{
    heap[++cnt]=x;
    int i=cnt;
    while(i>1&&heap[i]<heap[i>>1])
    {
        statu t=heap[i];
        heap[i]=heap[i>>1];
        heap[i>>1]=t;
        i>>=1;
    }
    return;
}
void pop()
{
    heap[1]=heap[cnt--];
    int i=1;
    while((i<<1)<=cnt&&(heap[i<<1]<heap[i]||heap[i<<1|1]<heap[i]))
    {
        if((i<<1)==cnt||heap[i<<1]<heap[i<<1|1])
        {
            statu t=heap[i];
            heap[i]=heap[i<<1];
            heap[i<<1]=t;
            i<<=1;
        }
        else
        {
            statu t=heap[i];
            heap[i]=heap[i<<1|1];
            heap[i<<1|1]=t;
            i=i<<1|1;
        }
    }
    return;
}
long long a[1010000];
int main()
{
    tail=&head;
    int n,k,len=0;
    long long tmp=0;
    scanf("%d%d",&n,&k);
    for(int i=1;i<=n;i++)
    {
        scanf("%lld",&a[i]);
        if(a[i]*tmp>0||!a[i])
        {
            tmp+=a[i];
            continue;
        }
        else if(i!=1)
        {
            if(tmp<=0&&tail==&head)
            {
                tmp=a[i];
                continue;
            }
            tail->nxt=new node(tmp,tail);
            tail=tail->nxt;
            push(statu(labs(tmp),tail));
            if(tmp>0)
                len++;
        }
        tmp=a[i];
    }
    if(tmp>0)
    {
        tail->nxt=new node(tmp,tail);
        tail=tail->nxt;
        push(statu(labs(tmp),tail));
        len++;//len是正数个数
    }
    while(len>k)
    {
        statu a=heap[1];
        pop();
        if(a.pla->del)
            continue;
        node *l=a.pla->lst;
        node *r=a.pla->nxt;
        a.pla->del=true;
        long long v=a.pla->v;
        len-=(v>0);
        if(l!=&head)
        {
            v+=l->v;
            l->del=true;
            len-=(l->v>0);
            l=l->lst;
        }
        if(r!=NULL)
        {
            v+=r->v;
            r->del=true;
            len-=(r->v>0);
            r=r->nxt;
        }
        if(v<=0)//判断是否在两端
        {
            if(l==&head)
            {
                head.nxt=r;
                r->lst=&head;
                continue;
            }
            if(r==NULL)
            {
                l->nxt=NULL;
                tail=l;
                continue;
            }
        }
        node *p=new node(v,l);
        if(r!=NULL)
            r->lst=p;
        if(l!=NULL)
            l->nxt=p;
        p->nxt=r;
        if(v>0)
            len++;
        push(statu(labs(v),p));
    }
    long long ans=0;
    node *p=&head;
    while(p->nxt!=NULL)
    {
        p=p->nxt;
        if(p->v>0)
            ans+=p->v;
    }
    printf("%lld\n",ans);
    return 0;
}

 

说点什么

avatar
  Subscribe  
提醒
/* */