洛谷 P3177 [HAOI2015]树上染色 题解【树形DP】【背包】
点击量:391
毒瘤的树上DP。写了我两天。。。
题目描述
有一棵点数为 N 的树,树边有边权。给你一个在 0~ N 之内的正整数 K ,你要在这棵树中选择 K个点,将其染成黑色,并将其他 的N-K个点染成白色 。 将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间的距离的和的受益。问受益最大值是多少。
输入输出格式
输入格式:
第一行包含两个整数 N, K 。接下来 N-1 行每行三个正整数 fr, to, dis , 表示该树中存在一条长度为 dis 的边 (fr, to) 。输入保证所有点之间是联通的。
输出格式:
输出一个正整数,表示收益的最大值。
输入输出样例
输入样例#1
3 11 2 12 3 2输出样例#1:
3
输入样例#2:
5 2
1 2 3
1 5 1
2 3 1
2 4 2输出样例#2:
17
说明
对于30%的数据,N<=20
对于50%的数据,N<=100
对于100%的数据,N<=2000,0<=K<=N。
整体思路是在DP数组f[x][i]中,f[x][i]代表在以x为根结点的子树中有i个黑点时,当前能获得的最大收益是多少。因此我们在DP过程中转移方程既与当前子树有关系,也与树的补集有关系。
状态转移和普通树形DP一样,即分组背包,枚举子树中黑点的个数
for(int i=min(k,w[x]);i>=0;i--)
for(int j=0;j<=min(i,w[p->n]);j++)
if(f[x][i-j]!=-1)
f[x][i]=max(f[x][i],f[x][i-j]+f[p->n][j]);
循环里是需要上下界的,因为上下界超出可能会出现$ i<j$或者$ j>k$(k是黑点的个数)同时还有枚举顺序的注意点:
- 因为是二维背包压成一维,所以要倒过来枚举,而此处调用的是$ i-j$,为保证调用的数据递减而不影响后面的数据,j是升序枚举的***这里是格外要注意并想明白的。
- 上述代码第三行有一个
f[x][i-j]!=-1
DP数组f的初值就是-1,代表这个地方没有被使用过,也就是说,不能从这个地方转移过来。代表这里不合法。不像普通背包一样,可以直接使用上一层的状态,这里是有一些约束条件的,比如对于f[u][v],u中不可能有v个黑点,或者u没有v个孩子,那么f[u][v]没有被更新过,也就是f[u][v]不合法。每次遍历到的点需要将f[x][i]置0.
其次,在更新时不能只管上下的转移,我们发现,在上面的DP过程中,基本没有出现边权,而题目却和边权是密切相关的。因而我们需要转移边权。
t=n-k;
for(int i=k-1;i>=0;i--)
{
f[x][i+1]=max(f[x][i+1],f[x][i]+fa[x]*((i+1)*(k-i-1)+(w[x]-i-1)*(t-(w[x]-i-1))));
//表示自己是黑色
f[x][i]+=fa[x]*(i*(k-i)+(w[x]-i)*(t-(w[x]-i)));
//表示自己是白色
}
对自己是白色的转移,fa[x]是进入x的那条边的边权,我们可以分析一下,若以x为根结点的子树,有i个黑点,那么经过fa[x]的黑点边有$ i \times (k-i)$(乘法原理)
同理,共有t个白点(见上代码第一行),又因为该子树有w[x]-i个白点(w[x]表示以x为根结点的子树的重量(点的个数))所以经过fa[x]的白点边有$ (t-w[x]-i) \times (w[x]-i) $
自己是黑点就把白点黑点的个数各-1+1即可。
$ tip$:记得开long long
Code:
总时间:1280ms 最慢测试点:252ms
#include<cstdio>
#include<cstring>
long long min(long long x,long long y){return x<y?x:y;}
long long max(long long x,long long y){return x>y?x:y;}
struct node
{
long long n,v;
node *nxt;
node(long long n,long long v)
{
this->n=n;
this->v=v;
nxt=NULL;
}
node(){nxt=NULL;}
};
node head[2002],*tail[2002];
long long n,k,f[2002][2002],w[2002],fa[2002],t;//t=n-k
bool used[2002];
void dfs(long long x)
{
f[x][0]=0;
node *p=&head[x];
w[x]=true;
while(p->nxt!=NULL)
{
p=p->nxt;
if(used[p->n])
continue;
used[p->n]=true;
fa[p->n]=p->v;
dfs(p->n);
used[p->n]=false;
w[x]+=w[p->n];
for(int i=min(k,w[x]);i>=0;i--)
for(int j=0;j<=min(i,w[p->n]);j++)
if(f[x][i-j]!=-1)
f[x][i]=max(f[x][i],f[x][i-j]+f[p->n][j]);
}
for(int i=k-1;i>=0;i--)
{
f[x][i+1]=max(f[x][i+1],f[x][i]+fa[x]*((i+1)*(k-i-1)+(w[x]-i-1)*(t-(w[x]-i-1))));//自己是黑色
f[x][i]+=fa[x]*(i*(k-i)+(w[x]-i)*(t-(w[x]-i)));
}
return;
}
int main()
{
long long u,v,W;
scanf("%lld%lld",&n,&k);
t=n-k;
for(int i=1;i<=n;i++)
tail[i]=&head[i];
memset(used,0,sizeof(used));
memset(w,0,sizeof(w));
for(int i=1;i<n;i++)
{
scanf("%lld%lld%lld",&u,&v,&W);
tail[u]->nxt=new node(v,W);
tail[u]=tail[u]->nxt;
tail[v]->nxt=new node(u,W);
tail[v]=tail[v]->nxt;
}
memset(f,-1,sizeof(f));
used[1]=true;
dfs(1);
printf("%lld\n",f[1][k]);
return 0;
}
在某些地方可以加剪枝优化,可以做到
Code’:
总时间:488ms 最慢测试点:84ms
#include<cstdio>
#include<cstring>
long long min(long long x,long long y){return x<y?x:y;}
long long max(long long x,long long y){return x>y?x:y;}
struct node
{
long long n,v;
node *nxt;
node(long long n,long long v)
{
this->n=n;
this->v=v;
nxt=NULL;
}
node(){nxt=NULL;}
};
node head[2002],*tail[2002];
long long n,k,f[2002][2002],w[2002],fa[2002],t;//t=n-k
bool used[2002];
void dfs(long long x)
{
f[x][0]=0;
node *p=&head[x];
w[x]=true;
while(p->nxt!=NULL)
{
p=p->nxt;
if(used[p->n])
continue;
used[p->n]=true;
fa[p->n]=p->v;
dfs(p->n);
used[p->n]=false;
w[x]+=w[p->n];
for(int i=min(k,w[x]);i>=0;i--)
for(int j=0;j<=min(i,w[p->n]);j++)
if(f[x][i-j]!=-1)
f[x][i]=max(f[x][i],f[x][i-j]+f[p->n][j]);
}
for(int i=min(k,w[x])-1;i>=k-(n-w[x])-1;i--)//剪枝的地方
{
f[x][i+1]=max(f[x][i+1],f[x][i]+fa[x]*((i+1)*(k-i-1)+(w[x]-i-1)*(t-(w[x]-i-1))));
f[x][i]+=fa[x]*(i*(k-i)+(w[x]-i)*(t-(w[x]-i)));
}
return;
}
int main()
{
long long u,v,W;
scanf("%lld%lld",&n,&k);
t=n-k;
for(int i=1;i<=n;i++)
tail[i]=&head[i];
memset(used,0,sizeof(used));
memset(w,0,sizeof(w));
for(int i=1;i<n;i++)
{
scanf("%lld%lld%lld",&u,&v,&W);
tail[u]->nxt=new node(v,W);
tail[u]=tail[u]->nxt;
tail[v]->nxt=new node(u,W);
tail[v]=tail[v]->nxt;
}
memset(f,-1,sizeof(f));
used[1]=true;
dfs(1);
printf("%lld\n",f[1][k]);
return 0;
}
说点什么