洛谷 P3177 [HAOI2015]树上染色 题解【树形DP】【背包】

作者: wjyyy 分类: DP,图论, 发布时间: 2018-05-15 17:43

点击量:27

    毒瘤的树上DP。写了我两天。。。

题目描述

有一棵点数为 N 的树,树边有边权。给你一个在 0~ N 之内的正整数 K ,你要在这棵树中选择 K个点,将其染成黑色,并将其他 的N-K个点染成白色 。 将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间的距离的和的受益。问受益最大值是多少。

 

输入输出格式

输入格式:

第一行包含两个整数 N, K 。接下来 N-1 行每行三个正整数 fr, to, dis , 表示该树中存在一条长度为 dis 的边 (fr, to) 。输入保证所有点之间是联通的。

 

输出格式:

输出一个正整数,表示收益的最大值。

 

输入输出样例

输入样例#1

3 1
1 2 1
2 3 2

输出样例#1:

3

输入样例#2:

5 2
1 2 3
1 5 1
2 3 1
2 4 2

输出样例#2:

17

说明

对于30%的数据,N<=20

对于50%的数据,N<=100

对于100%的数据,N<=2000,0<=K<=N。

 

   整体思路是在DP数组f[x][i]中,f[x][i]代表在以x为根结点的子树中有i个黑点时,当前能获得的最大收益是多少。因此我们在DP过程中转移方程既与当前子树有关系,也与树的补集有关系。

 

   状态转移和普通树形DP一样,即分组背包,枚举子树中黑点的个数

for(int i=min(k,w[x]);i>=0;i--)
    for(int j=0;j<=min(i,w[p->n]);j++)
        if(f[x][i-j]!=-1)
            f[x][i]=max(f[x][i],f[x][i-j]+f[p->n][j]);

   循环里是需要上下界的,因为上下界超出可能会出现\(i<j\)或者\(j>k\)(k是黑点的个数)同时还有枚举顺序的注意点:

 

  1. 因为是二维背包压成一维,所以要倒过来枚举,而此处调用的是\(i-j\),为保证调用的数据递减而不影响后面的数据,j是升序枚举的***这里是格外要注意并想明白的。
  2. 上述代码第三行有一个
    f[x][i-j]!=-1

    DP数组f的初值就是-1,代表这个地方没有被使用过,也就是说,不能从这个地方转移过来。代表这里不合法。不像普通背包一样,可以直接使用上一层的状态,这里是有一些约束条件的,比如对于f[u][v],u中不可能有v个黑点,或者u没有v个孩子,那么f[u][v]没有被更新过,也就是f[u][v]不合法。每次遍历到的点需要将f[x][i]置0.

 

   其次,在更新时不能只管上下的转移,我们发现,在上面的DP过程中,基本没有出现边权,而题目却和边权是密切相关的。因而我们需要转移边权。

t=n-k;
for(int i=k-1;i>=0;i--)
{
    f[x][i+1]=max(f[x][i+1],f[x][i]+fa[x]*((i+1)*(k-i-1)+(w[x]-i-1)*(t-(w[x]-i-1))));
    //表示自己是黑色
    f[x][i]+=fa[x]*(i*(k-i)+(w[x]-i)*(t-(w[x]-i)));
    //表示自己是白色
}

   对自己是白色的转移,fa[x]是进入x的那条边的边权,我们可以分析一下,若以x为根结点的子树,有i个黑点,那么经过fa[x]的黑点边有\(i \times (k-i)\)(乘法原理)

 

   同理,共有t个白点(见上代码第一行),又因为该子树有w[x]-i个白点(w[x]表示以x为根结点的子树的重量(点的个数))所以经过fa[x]的白点边有\((t-w[x]-i) \times (w[x]-i) \)

 

   自己是黑点就把白点黑点的个数各-1+1即可。

 

\(tip\):记得开long long

Code:

总时间:1280ms 最慢测试点:252ms

#include<cstdio>
#include<cstring>
long long min(long long x,long long y){return x<y?x:y;}
long long max(long long x,long long y){return x>y?x:y;}
struct node
{
    long long n,v;
    node *nxt;
    node(long long n,long long v)
    {
        this->n=n;
        this->v=v;
        nxt=NULL;
    }
    node(){nxt=NULL;}
};
node head[2002],*tail[2002];
long long n,k,f[2002][2002],w[2002],fa[2002],t;//t=n-k
bool used[2002];
void dfs(long long x)
{
    f[x][0]=0;
    node *p=&head[x];
    w[x]=true;
    while(p->nxt!=NULL)
    {
        p=p->nxt;
        if(used[p->n])
            continue;
        used[p->n]=true;
        fa[p->n]=p->v;
        dfs(p->n);
        used[p->n]=false;
        w[x]+=w[p->n];
        for(int i=min(k,w[x]);i>=0;i--)
            for(int j=0;j<=min(i,w[p->n]);j++)
                if(f[x][i-j]!=-1)
                    f[x][i]=max(f[x][i],f[x][i-j]+f[p->n][j]);
    }
    for(int i=k-1;i>=0;i--)
    {
        f[x][i+1]=max(f[x][i+1],f[x][i]+fa[x]*((i+1)*(k-i-1)+(w[x]-i-1)*(t-(w[x]-i-1))));//自己是黑色
        f[x][i]+=fa[x]*(i*(k-i)+(w[x]-i)*(t-(w[x]-i)));
    }
    return;
}
int main()
{
    long long u,v,W;
    scanf("%lld%lld",&n,&k);
    t=n-k;
    for(int i=1;i<=n;i++)
        tail[i]=&head[i];
    memset(used,0,sizeof(used));
    memset(w,0,sizeof(w));
    for(int i=1;i<n;i++)
    {
        scanf("%lld%lld%lld",&u,&v,&W);
        tail[u]->nxt=new node(v,W);
        tail[u]=tail[u]->nxt;
        tail[v]->nxt=new node(u,W);
        tail[v]=tail[v]->nxt;
    }
    memset(f,-1,sizeof(f));
    used[1]=true;
    dfs(1);
    printf("%lld\n",f[1][k]);
    return 0;
}

在某些地方可以加剪枝优化,可以做到

Code’:

总时间:488ms 最慢测试点:84ms

#include<cstdio>
#include<cstring>
long long min(long long x,long long y){return x<y?x:y;}
long long max(long long x,long long y){return x>y?x:y;}
struct node
{
    long long n,v;
    node *nxt;
    node(long long n,long long v)
    {
        this->n=n;
        this->v=v;
        nxt=NULL;
    }
    node(){nxt=NULL;}
};
node head[2002],*tail[2002];
long long n,k,f[2002][2002],w[2002],fa[2002],t;//t=n-k
bool used[2002];
void dfs(long long x)
{
    f[x][0]=0;
    node *p=&head[x];
    w[x]=true;
    while(p->nxt!=NULL)
    {
        p=p->nxt;
        if(used[p->n])
            continue;
        used[p->n]=true;
        fa[p->n]=p->v;
        dfs(p->n);
        used[p->n]=false;
        w[x]+=w[p->n];
        for(int i=min(k,w[x]);i>=0;i--)
            for(int j=0;j<=min(i,w[p->n]);j++)
                if(f[x][i-j]!=-1)
                    f[x][i]=max(f[x][i],f[x][i-j]+f[p->n][j]);
    }
    for(int i=min(k,w[x])-1;i>=k-(n-w[x])-1;i--)//剪枝的地方
    {
        f[x][i+1]=max(f[x][i+1],f[x][i]+fa[x]*((i+1)*(k-i-1)+(w[x]-i-1)*(t-(w[x]-i-1))));
        f[x][i]+=fa[x]*(i*(k-i)+(w[x]-i)*(t-(w[x]-i)));
    }
    return;
}
int main()
{
    long long u,v,W;
    scanf("%lld%lld",&n,&k);
    t=n-k;
    for(int i=1;i<=n;i++)
        tail[i]=&head[i];
    memset(used,0,sizeof(used));
    memset(w,0,sizeof(w));
    for(int i=1;i<n;i++)
    {
        scanf("%lld%lld%lld",&u,&v,&W);
        tail[u]->nxt=new node(v,W);
        tail[u]=tail[u]->nxt;
        tail[v]->nxt=new node(u,W);
        tail[v]=tail[v]->nxt;
    }
    memset(f,-1,sizeof(f));
    used[1]=true;
    dfs(1);
    printf("%lld\n",f[1][k]);
    return 0;
}

 

说点什么

avatar
  Subscribe  
提醒
/* */