Fish eating fruit 沈阳网络赛(树形dp)
Fish eating fruit
\[ Time Limit: 1000 ms \quad Memory Limit: 262144 kB \]
题意
大体的题意就是给出一棵树,求每一对点之间的距离,然后把该距离存在距离 \(\mod 3\) 的位置,输出总和。
思路
令两个 \(dp\) 数组和两个辅助 \(dp\) 的数组。
\(dp1[i][j]\) 表示从 \(i\) 为起点往下到各个点距离 \(\mod 3\) 后为 \(j\) 的距离总和。
\(cnt1[i][j]\) 表示以 \(i\) 为起点往下到各个点距离 \(\mod 3\) 后为 \(j\) 的节点个数。
\(dp2[i][j]\) 表示从 \(i\) 起点往上一步后到各个点距离 \(\mod 3\) 后为 \(j\) 的距离总和。
\(cnt2[i][j]\) 表示以 \(i\) 为起点往上一步后到各个点距离 \(\mod 3\) 后为 \(j\) 的节点个数。
对于两个 \(dp\) 分别跑一遍 \(dfs\)
对于 \(dp1\) 比较好处理,直接往下 \(dfs\)
以 \(u\) 开始的答案等于从 \(v\) 开始的答案加上这一条边 \(w\) 的贡献,可以得到
\[ dp1[u][(j+w)\%3] = \sum (dp1[v][j] + cnt1[v][j]*w)\\ cnt1[u][(j+w)\%3] = \sum cnt1[v][j] \]
对于 \(dp2\) 会比较麻烦,需要用 \(fa\) 节点向上的贡献在加上 \(fa\) 节点往下的贡献在减去 \(fa\) 节点往 \(u\) 走的贡献。这些节点就是 \(u\) 往上走一步后可以走到的所有节点。这样算出真实的节点数和距离总和,然后 \(u\) 才能开始转移。
设 \(faw\) 为从 \(u\) 到 \(fa\) 的路径长度
计算真实的节点数:
\[ c[j] = cnt2[fa][j]+cnt1[fa][j]\\ c[(j+faw)\%3] -= cnt1[u][j] \]
计算真实的距离总和:
\[ d[j] = dp2[fa][j]+dp1[fa][j] \\ d[(j+faw)\%3] -= dp1[u][0]+cnt1[u][j]*faw \]
则最后的 \(dp2\) 就可以利用 \(d\) 和 \(c\) 得到了
\[ dp2[u][(j+faw)\%3] = c[j]*faw+d[j] \\ cnt2[u][(j+faw)\%3] = c[j] \]
/*************************************************************** > File Name : a.cpp > Author : Jiaaaaaaaqi > Created Time : Mon 16 Sep 2019 08:55:33 PM CST ***************************************************************/ #include <map> #include <set> #include <list> #include <ctime> #include <cmath> #include <stack> #include <queue> #include <cfloat> #include <string> #include <vector> #include <cstdio> #include <bitset> #include <cstdlib> #include <cstring> #include <iostream> #include <algorithm> #include <unordered_map> #define lowbit(x) x & (-x) #define mes(a, b) memset(a, b, sizeof a) #define fi first #define se second #define pb push_back #define pii pair<int, int> typedef unsigned long long int ull; typedef long long int ll; const int maxn = 1e5 + 10; const int maxm = 1e5 + 10; const ll mod = 1e9 + 7; const ll INF = 1e18 + 100; const int inf = 0x3f3f3f3f; const double pi = acos(-1.0); const double eps = 1e-8; using namespace std; int n, m; int cas, tol, T; vector< pii > vv[maxn]; ll cnt1[maxn][3], cnt2[maxn][3]; ll dp1[maxn][3], dp2[maxn][3]; void dfs1(int u, int fa) { cnt1[u][0] = 1; for(auto i : vv[u]) { int v = i.fi, w = i.se; if(v == fa) continue; dfs1(v, u); dp1[u][(0+w)%3] += (cnt1[v][0]*w%mod+dp1[v][0])%mod; dp1[u][(1+w)%3] += (cnt1[v][1]*w%mod+dp1[v][1])%mod; dp1[u][(2+w)%3] += (cnt1[v][2]*w%mod+dp1[v][2])%mod; for(int j=0; j<3; j++) dp1[u][j] %= mod; cnt1[u][(0+w)%3] += cnt1[v][0]; cnt1[u][(1+w)%3] += cnt1[v][1]; cnt1[u][(2+w)%3] += cnt1[v][2]; } } void dfs2(int u, int fa) { if(u!=1) { int faw; for(auto i : vv[u]) { if(i.fi == fa) { faw = i.se; break; } } int c[3] = { 0 }; c[0] = cnt2[fa][0]+cnt1[fa][0]; c[1] = cnt2[fa][1]+cnt1[fa][1]; c[2] = cnt2[fa][2]+cnt1[fa][2]; c[(0+faw)%3] -= cnt1[u][0]; c[(1+faw)%3] -= cnt1[u][1]; c[(2+faw)%3] -= cnt1[u][2]; ll d[3] = { 0 }; d[0] = (dp2[fa][0]+dp1[fa][0])%mod; d[1] = (dp2[fa][1]+dp1[fa][1])%mod; d[2] = (dp2[fa][2]+dp1[fa][2])%mod; d[(0+faw)%3] = ((d[(0+faw)%3] - (cnt1[u][0]*faw%mod+dp1[u][0])%mod+mod)%mod+mod)%mod; d[(1+faw)%3] = ((d[(1+faw)%3] - (cnt1[u][1]*faw%mod+dp1[u][1])%mod+mod)%mod+mod)%mod; d[(2+faw)%3] = ((d[(2+faw)%3] - (cnt1[u][2]*faw%mod+dp1[u][2])%mod+mod)%mod+mod)%mod; dp2[u][(0+faw)%3] = (c[0]*faw%mod+d[0])%mod; dp2[u][(1+faw)%3] = (c[1]*faw%mod+d[1])%mod; dp2[u][(2+faw)%3] = (c[2]*faw%mod+d[2])%mod; cnt2[u][(0+faw)%3] += c[0]; cnt2[u][(1+faw)%3] += c[1]; cnt2[u][(2+faw)%3] += c[2]; } for(auto i : vv[u]) { int v = i.fi, w = i.se; if(v == fa) continue; dfs2(v, u); } } int main() { // freopen("in", "r", stdin); while(~scanf("%d", &n)) { for(int i=1; i<=n; i++) { vv[i].clear(); } mes(dp1, 0), mes(dp2, 0); mes(cnt1, 0), mes(cnt2, 0); for(int i=1, u, v, w; i<n; i++) { scanf("%d%d%d", &u, &v, &w); u++, v++; vv[u].pb(make_pair(v, w)); vv[v].pb(make_pair(u, w)); } dfs1(1, 0); dfs2(1, 0); // for(int i=1; i<=n; i++) { // for(int j=0; j<3; j++) { // printf("dp1[%d][%d] = %lld, cnt1[%d][%d] = %lld\n", i, j, dp1[i][j], i, j, cnt1[i][j]); // } // } // cout << "-----------------" << endl; // for(int i=1; i<=n; i++) { // for(int j=0; j<3; j++) { // printf("dp2[%d][%d] = %lld, cnt2[%d][%d] = %lld\n", i, j, dp2[i][j], i, j, cnt2[i][j]); // } // } ll ans0, ans1, ans2; ans0 = ans1 = ans2 = 0; for(int i=1; i<=n; i++) { ans0 = (ans0+dp1[i][0]+dp2[i][0])%mod; ans1 = (ans1+dp1[i][1]+dp2[i][1])%mod; ans2 = (ans2+dp1[i][2]+dp2[i][2])%mod; } printf("%lld %lld %lld\n", ans0, ans1, ans2); } return 0; }
转载于//www.cnblogs.com/Jiaaaaaaaqi/p/11530717.html
还没有评论,来说两句吧...