Codeforces 1101D (树形DP + GCD)

Codeforces 1101D
题意:给出一棵无向带点权树,定义$dist(x,y)$为$x,y$之间的点(包含$x,y$)的个数,$g(x,y)$为$x,y$之间所有点权$gcd$,求最大的$dist(x,y)$,并且$g(x,y) \geq 2$

由于$2 \times 10^5$以内的数最多$7$种质数相乘,所以可以对点权分解质因数,指数不重要,只考虑质因子,因为质因子相同的$gcd$一定$\geq 2$。

所以我一开始想对每个质因子建一棵树,然后在树上做类似直径的DP。
这样的树的总大小不会超过$7n$,但是空间就比较爆炸,虽然可以用map之类的搞但是麻烦得要死还带了$log$。

其实不需要建那么多棵树,直接一次DFS,找到孩子如果和自己有公质因数,那么像树直径DP那样转移即可。

注意这里距离为点的个数而不是边的条数,DP 方程要特别注意

即某个点初始化长度即更新为1,然后转移时

ans = max(ans, st_dfs[u][k].second + st_dfs[v][j].second);
st_dfs[u][k].second = max(st_dfs[u][k].second, st_dfs[v][j].second + 1);

注意加一的位置,画画图更好。

知识点:
1、时间不够了也要冷静分析,一定相信自己能找出问题
2、点权边权的问题要考虑清楚再写

//==========================Head files==========================
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<queue>
#include<cmath>
#include<set>
#include<iostream>
#include<map>
#define LL long long
#define db double
#define mp make_pair
#define pr pair<int, int>
#define fir first
#define sec second
#define pb push_back
#define ms(i, j) memset(i, j, sizeof i)
using namespace std;
//==========================Templates==========================
inline int read() {
    int x = 0, f = 1; char c = getchar();
    while(c < '0' || c > '9'){if (c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9'){x = x * 10 + c - '0'; c = getchar();}
    return x * f;
}
inline LL readl() {
    LL x = 0, f = 1; char c = getchar();
    while(c < '0' || c > '9'){if (c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9'){x = x * 10 + c - '0'; c = getchar();}
    return x * f;
}
int power(int a, int b) {
    int ans = 1;
    while (b) {
        if(b & 1) ans = ans * a;
        b >>= 1; a = a * a;
    }
    return ans;
}
int power_mod(int a, int b, int mod) {
    a %= mod;
    int ans = 1;
    while (b) {
        if(b & 1) ans = (ans * a) % mod;
        b >>= 1, a = (a * a) % mod;
    }
    return ans;
}
LL powerl(LL a, LL b) {
    LL ans = 1ll;
    while (b) {
        if(b & 1ll) ans = ans * a;
        b >>= 1ll;a = a * a;
    }
    return ans;
}
LL power_modl(LL a, LL b, LL mod) {
    a %= mod;
    LL ans = 1ll;
    while (b) {
        if(b & 1ll) ans = (ans * a) % mod;
        b >>= 1ll, a = (a * a) % mod;
    }
    return ans;
}
LL gcdl(LL a, LL b) {return b == 0 ? a : gcdl(b, a % b);}
LL abssl(LL a) {return a > 0 ? a : -a;}
int gcd(int a, int b) {return b == 0 ? a : gcd(b, a % b);}
int abss(int a) {return a > 0 ? a : -a;}
//==========================Main body==========================
#define LD "%I64d"
#define D "%d"
#define pt printf
#define sn scanf
#define pty printf("YES\n")
#define ptn printf("NO\n")
//==========================Code here==========================
const LL MAXN = 200000 + 5;

LL n, a[MAXN], ans = 1ll;
vector<LL > st_a[MAXN];
vector<LL > G[MAXN];
vector<pair<LL, LL> > st_dfs[MAXN];

void dfs(LL u, LL fa) {
    for (LL i = 0; i < (LL)st_a[u].size(); ++i) st_dfs[u].push_back(make_pair(st_a[u][i], 1));
    for (LL i = 0; i < (LL)G[u].size(); ++i) {
        LL v = G[u][i];
        if (v != fa) {
            dfs(v, u);
            for (LL j = 0; j < (LL)st_dfs[v].size(); ++j) { // st_dfs[v][j]
                LL whw = st_dfs[v][j].first;
                for (LL k = 0; k < (LL)st_dfs[u].size(); ++k) { // st_dfs[u][k]
                    LL gg = st_dfs[u][k].first;
                    if (whw == gg) {
                        ans = max(ans, st_dfs[u][k].second + st_dfs[v][j].second);

                        //printf("ans=%d, u=%d, v=%d, whw=%d\n", ans, u, v, whw);

                        st_dfs[u][k].second = max(st_dfs[u][k].second, st_dfs[v][j].second + 1);
                    }
                }
            }
        }
    }
}

int main() {
    cin >> n;
    for (LL i = 1; i <= n; ++i) scanf("%lld", &a[i]);
    int fl = 0;
    for (LL i = 1; i <= n; ++i) if (a[i] != 1) fl = 1;
    if (!fl) return printf("0\n"), 0;
    for (LL u = 1; u <= n; ++u) {
        LL tmp = a[u];
        for (LL i = 2; i * i <= tmp; ++i) if (tmp % i == 0) {
            st_a[u].push_back(i);
            while (tmp % i == 0) tmp /= i;
        }
        if (tmp != 1) st_a[u].push_back(tmp);
        sort(st_a[u].begin(), st_a[u].end());
    }

    /*for (LL u = 1; u <= n; ++u) {
        for (LL i = 0; i < (LL)st_a[u].size(); ++i) printf("%d ", st_a[u][i]);
        puts("");
    }*/

    for (LL x, y, i = 1; i < n; ++i) {
        scanf("%lld%lld", &x, &y);
        G[x].push_back(y), G[y].push_back(x);
    }

    dfs(1, 0);

    cout << ans;

    return 0;
}

/*
5
2 2 2 2 2
1 2
2 3
1 4
4 5

3
2 2 2
1 2
2 3
*/
------ 本文结束 ------