前言
本文用于纪念自己过掉的第三到黑题(雾
前置芝士
- 树形 DP,树上背包
- 向量
- 凸包与闵可夫斯基和
- 差分
- 平衡树
题目大意
给定一棵 $n$ 个节点的带权无根树,节点编号 $1\sim n$。从中选出 $m$ 个节点 $a_1,a_2,a_3,\dots,a_m$,求最大值:
其中 $\mathrm{dis}(u,v)$ 表示 $u$ 到 $v$ 的简单路径的距离。
思路
暴力
首先肯定往树形 DP 想,点的贡献不好算,一个经典的 trick 就是考虑子树内边的贡献,用 $dp_{u,x}$ 表示以 $u$ 为根的子树中,选出 $x$ 个节点,边权对最终答案的贡献。然后做树上背包,状态转移方程如下:
初始化为 $dp_{u,0}=dp_{u,1}=0$,其余为 -INF。以 1 为根,则答案为 $dp_{1,m}$。
复杂度 $O(n^2)$。
代码
#include<bits/stdc++.h>
using namespace std;
#define dbg(x) cerr<<#x<<':'<<(x)<<' '
#define dbe(x) cerr<<#x<<':'<<(x)<<endl
#define eb emplace_back
#define ep emplace
#define endl '\n'
using ll=long long;
using vi=vector<int>;
using vl=vector<ll>;
using tp=tuple<int,int>;
int main(){
cin.tie(0)->sync_with_stdio(0);
int n,m;
cin>>n>>m;
vector<vector<tp>>G(n+1);
for(int i=1,u,v,w;i<n;i++){
cin>>u>>v>>w;
G[u].eb(v,w);
G[v].eb(u,w);
}
vector<vl>dp(n+1,vl(m+1,-1e18));
vi sz(n+1,1);
auto dfs=[&](int u,int fa,auto&&dfs)->void {
dp[u][0]=dp[u][1]=0;
for(auto [v,w]:G[u]){
if(v==fa)continue;
dfs(v,u,dfs);
for(int x=min(sz[u],m);x>=0;x--){
for(int y=min(sz[v],m-x);y>=0;y--){
dp[u][x+y]=max(dp[u][x+y],dp[u][x]+dp[v][y]+(ll)w*y*(m-y));
}
}
sz[u]+=sz[v];
}
};
dfs(1,0,dfs);
// for(int i=1;i<=n;i++){
// for(int j=0;j<=m;j++)cerr<<dp[i][j]<<" \n"[j==m];
// }
cout<<dp[1][m]<<endl;
return 0;
}正解
考虑如何优化。
状态转移方程中,$wy(m-y)$ 这一项的函数图像是一个上凸壳,而 $dp_{u,x}$ 的转移又是一个 $(\max,+)$ 卷积,所以 $x$ 为自变量,$dp_{u,x}$ 一定是一个上凸壳。 把刚才暴力代码得到的 $dp$ 数组输出(第 38 至 40 行),造几个样例都可以发现 $dp_{u,x}$ 是上凸壳。
所以现在要做的事就是把 $dp_{u,x}$ 和 $dp_{v,y}+wy(m-y)$ 这两个上凸壳合并,直接用闵可夫斯基和(两凸包合并所得凸包的边集,等于两凸包所有边依次排列的边集),把两个式子做差分,变成向量,由于是上凸壳,所以其差分数组是非严格单调递减的。可以用 set 维护,每次启发式合并。
但是,我们每次合并出来的是 $dp_{v,y}$,而不是下次合并需要的 $dp_{v,y}+wy(m-y)$,需要把 $wy(m-y)$ 加进去。其相邻两项做差分得到:
$y$ 每次增加 1,上式就增加一个常数 $-2w$,也就是,要在 $dp_{v,y}$ 中,加上一个首相为 $w(m-1)$,公差为 $-2w$ 的等差数列。这就用不了 set 了,需要手写平衡树,在节点打懒标记,首相和公差的懒标记都是可以合并的。这里用的 FHQ Treap。
最终时间复杂度 $O(n\log n)$。
代码
#include<bits/stdc++.h>
using namespace std;
#define dbg(x) cerr<<#x<<':'<<(x)<<' '
#define dbe(x) cerr<<#x<<':'<<(x)<<endl
#define eb emplace_back
#define ep emplace
#define endl '\n'
using ll=long long;
using vi=vector<int>;
using vl=vector<ll>;
using tp=tuple<int,int>;
struct Treap{
using ull=unsigned long long;
struct D{
int l,r;//左右孩子
ull rnd;
int sz;//子树大小
ll z,tgc,tgd;//节点值,首相tag,公差tag
};
vector<D>tr;
vi root;//树中每个节点在平衡树中对应的根
mt19937_64 rand;
Treap(int n,ull seed):tr(1),root(n+1),rand(seed){}
int makeD(){//新建节点
tr.eb(D{0,0,rand(),1,0,0});
return tr.size()-1;
}
void tag(int u,ll c,ll d){//打标记
if(!u)return;
tr[u].z+=c+tr[tr[u].l].sz*d;
tr[u].tgc+=c;
tr[u].tgd+=d;
}
void tagrt(int u,ll c,ll d){tag(root[u],c,d);}
void lazy(int u){//下传标记
int l=tr[u].l,r=tr[u].r;
ll c=tr[u].tgc,d=tr[u].tgd;//十年 OI 一场空,不开 long long 见祖宗。
tag(l,c,d),tag(r,c+(tr[l].sz+1)*d,d);
tr[u].tgc=tr[u].tgd=0;
}
void update(int u){tr[u].sz=tr[tr[u].l].sz+tr[tr[u].r].sz+1;}//更新节点信息
int merge(int l,int r){//平衡树合并模板
if(l==0||r==0)return l|r;
int rt=-1;
if(tr[l].rnd<tr[r].rnd){
rt=l,tr[l].r=merge(tr[l].r,r);
}else{
rt=r,tr[r].l=merge(l,tr[r].l);
}
update(rt);
return rt;
}
tuple<int,int>split(int u,ll k){//平衡树按值分裂模板
if(u==0)return {0,0};
lazy(u);
int l=-1,r=-1;
if(tr[u].z>=k){//注意是从大到小排序
l=u;
auto[x,y]=split(tr[u].r,k);
tr[u].r=x,r=y;
}else{
r=u;
auto[x,y]=split(tr[u].l,k);
tr[u].l=y,l=x;
}
update(u);
return {l,r};
}
void insert(int&u,int v){//将节点 v 插入到以 u 为根的子树中
tr[v].l=tr[v].r=0;
update(v);//记得 upt!!!
auto [l,r]=split(u,tr[v].z);
u=merge(merge(l,v),r);
}
void dsu(int u,int v){//启发式合并 u,v 对应的子树
int&ru=root[u],rv=root[v];
if(tr[ru].sz<tr[rv].sz)swap(ru,rv);
auto dfs=[&](int u,auto&&dfs)->void {
if(u==0)return;
lazy(u);//记得下传标记!!!
int l=tr[u].l,r=tr[u].r;
insert(ru,u);
dfs(l,dfs),dfs(r,dfs);
};dfs(rv,dfs);
}
void build(int u){root[u]=makeD();}//新建节点
vl to_vector_ll(int u){//返回 u 对应子树对应的序列
vl a;
auto dfs=[&](int u,auto&&dfs)->void {
if(u==0)return;
lazy(u);
dfs(tr[u].l,dfs);
a.eb(tr[u].z);
dfs(tr[u].r,dfs);
};dfs(root[u],dfs);
return a;
}
};
int main(){
cin.tie(0)->sync_with_stdio(0);
int n,m;
cin>>n>>m;
vector<vector<tp>>G(n+1);
for(int i=1,u,v,w;i<n;i++){
cin>>u>>v>>w;
G[u].eb(v,w),G[v].eb(u,w);
}
Treap tr(n,230224182508);
vi sz(n+1,1);//子树大小
auto dfs=[&](int u,int fa,auto&&dfs)->void {
tr.build(u);//初始化 dp[u][0]=dp[u][1]=0;
for(auto [v,w]:G[u]){
if(v==fa)continue;
dfs(v,u,dfs);
sz[u]+=sz[v];
tr.tagrt(v,(ll)w*(m-1),-2ll*w);//dp[v][y]+wy(m-y)
tr.dsu(u,v);//启发式合并
}
};dfs(1,0,dfs);
auto dp=tr.to_vector_ll(1);
ll ans=0;//dp[1][0] 一定是 0
for(int i=0;i<m;i++)ans+=dp[i];
cout<<ans<<endl;
return 0;
}警钟/hack 数据
十年 OI 一场空,不开 long long 见祖宗。而且跟 $O(n^2)$ 的暴力对拍拍不出来错。
#1
in:
5 5
3 2 5
1 3 1
4 3 3
5 2 2
ans:
54
#2
in:
7 4
5 2 5
4 2 4
1 5 1
7 5 2
3 7 2
6 7 3
ans:
55