题目链接:点我啊╭(╯^╰)╮
题目大意:
树形图,求拓扑序数量
解题思路:
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j] 为
i
i
i 在子树中拓扑序排名为
j
j
j 的方案数
有
d
p
[
x
]
[
p
1
]
dp[x][p1]
dp[x][p1]、
d
p
[
y
]
[
p
2
]
dp[y][p2]
dp[y][p2],
x
x
x 为
y
y
y 的父亲,得到新的
d
p
[
x
]
[
p
3
]
dp[x][p3]
dp[x][p3]
则
x
x
x 原来排在
p
1
p1
p1,更新为
p
3
p3
p3,
y
y
y 原来排在
p
2
p2
p2、更新为
p
4
p4
p4
问题在于
p
3
p3
p3 的范围,和怎么求方案数
若
x
x
x 的拓扑序在
y
y
y 之前,则
p
3
<
p
4
p3<p4
p3<p4
p
1
p1
p1 左边的一定在
p
3
p3
p3 左边,
p
1
p1
p1 右边的一定在
p
3
p3
p3 右边,
p
2
p2
p2 右边的一定在
p
3
p3
p3 右边
而
p
2
p2
p2 左边的可以任意摆,则
p
1
−
1
≤
p
3
−
1
≤
p
1
−
1
+
p
2
−
1
p1−1≤p3−1≤p1−1+p2−1
p1−1≤p3−1≤p1−1+p2−1,得到
p
1
≤
p
3
≤
p
1
+
p
2
−
1
p1≤p3≤p1+p2−1
p1≤p3≤p1+p2−1
左边有
p
3
−
1
p3-1
p3−1 个点,有
p
1
−
1
p1−1
p1−1 个一定来自
x
x
x 的原序列,填坑的方案数为
C
p
3
−
1
p
1
−
1
C_{p3-1}^{p1-1}
Cp3−1p1−1
右边有
s
z
x
+
s
z
y
−
p
3
sz_x+sz_y-p3
szx+szy−p3 个点,有
s
z
x
−
p
1
sz_x-p1
szx−p1 个点一定来自
x
x
x 的原序列,填坑的方案数为
C
s
z
x
+
s
z
y
−
p
3
s
z
x
−
p
1
C_{sz_x+sz_y-p3}^{sz_x-p1}
Cszx+szy−p3szx−p1
d
p
[
x
]
[
p
3
]
+
=
C
p
3
−
1
p
1
−
1
×
C
s
z
x
+
s
z
y
−
p
3
s
z
x
−
p
1
×
d
p
[
x
]
[
p
1
]
×
d
p
[
y
]
[
p
2
]
dp[x][p3] += C_{p3-1}^{p1-1} \times C_{sz_x+sz_y-p3}^{sz_x-p1} \times dp[x][p1] \times dp[y][p2]
dp[x][p3]+=Cp3−1p1−1×Cszx+szy−p3szx−p1×dp[x][p1]×dp[y][p2]
转移是
n
3
n^3
n3的,代码如下:
for p1 in [1,sz_x]
for p2 in [1,sz_y]
for p3 in [p1,p1+p2-1]
p 2 p2 p2 在转移式中只出现了一次,因此调换顺序后:
for p1 in [1,sz_x]
for p3 in [p1,p1+sz_y-1]
for p2 in [p3-p1+1,sz_y]
前缀和优化即可,时间复杂度降为 n 2 n^2 n2
若
x
x
x 的拓扑序在
y
y
y 之后,则
p
3
>
p
4
p3>p4
p3>p4
p
1
+
p
2
≤
p
3
≤
p
1
+
s
z
x
p1+p2≤p3≤p1+szx
p1+p2≤p3≤p1+szx,原来的转移式如下:
for p1 in [1,sz_x]
for p2 in [1,sz_y]
for p3 in [p1+p2,p1+sz_x]
调换顺序后如下:
for p1 in [1,sz_x]
for p3 in [p1+1,p1+sz_x]
for p2 in [1,p3-p1]
#include<bits/stdc++.h>
#define rint register int
#define deb(x) cerr<<#x<<" = "<<(x)<<'\n';
using namespace std;
typedef long long ll;
using pii = pair <int,int>;
const int maxn = 1e3 + 5;
const ll mod = 1e9 + 7;
int T, n, a[maxn][maxn], sz[maxn];
ll dp[maxn][maxn], C[maxn][maxn];
ll f[maxn], sum[maxn][maxn];
vector <int> g[maxn];
void dfs(int u, int fa) {
sz[u] = 1, dp[u][1] = 1;
for(auto v : g[u]) {
if(sz[v]) continue;
dfs(v, u);
for(int i=1; i<=sz[u]; i++) f[i] = dp[u][i];
memset(dp[u], 0, sizeof(dp[u]));
for(int i=1; i<=sz[v]; i++) sum[v][i] = (sum[v][i-1] + dp[v][i]) % mod;
if(a[u][v])
for(int i=1; i<=sz[u]; i++)
for(int k=i; k<=i+sz[v]-1; k++)
// for(int j=k-i+1; j<=sz[v]; j++)
(dp[u][k] += C[k-1][i-1] * C[sz[u]+sz[v]-k][sz[u]-i] % mod * f[i]\
% mod * (sum[v][sz[v]] - sum[v][k-i] + mod) % mod) %= mod;
else
for(int i=1; i<=sz[u]; i++)
for(int k=i+1; k<=i+sz[v]; k++)
// for(int j=1; j<=k-i; j++)
(dp[u][k] += C[k-1][i-1] * C[sz[u]+sz[v]-k][sz[u]-i] % mod * f[i]\
% mod * sum[v][k-i] % mod) %= mod;
sz[u] += sz[v];
}
}
int main() {
for(int i=0; i<=1e3; i++) C[i][0] = 1;
for(int i=1; i<=1e3; i++)
for(int j=1; j<=i; j++)
C[i][j] = (C[i-1][j-1] + C[i-1][j]) % mod;
scanf("%d", &T);
while(T--) {
char str;
scanf("%d", &n);
memset(a, 0, sizeof(a));
memset(sz, 0, sizeof(sz));
memset(dp, 0, sizeof(dp));
for(int i=1; i<=n; i++) g[i].clear();
for(int i=1, x, y; i<n; i++) {
scanf("%d %c %d", &x, &str, &y);
x++, y++;
if(str == '<') a[x][y] = 1;
else a[y][x] = 1;
g[x].push_back(y);
g[y].push_back(x);
}
dfs(1, 0);
ll ans = 0;
for(int i=1; i<=n; i++) ans = (ans + dp[1][i]) % mod;
printf("%lld\n", ans);
}
}