树上差分

BB

又是菜的一比的一天 QAQ

 

Introduction

树上差分,顾名思义,是将差分搬到树上

 

[USACO15DEC]最大流Max Flow - luogu P3128

Description
FJ给他的牛棚的N(2≤N≤50,000)个隔间之间安装了N-1根管道,隔间编号从1到N。所有隔间都被管道连通了。
FJ有K(1≤K≤100,000)条运输牛奶的路线,第i条路线从隔间si运输到隔间ti。一条运输路线会给它的两个端点处的隔间以及中间途径的所有隔间带来一个单位的运输压力,你需要计算压力最大的隔间的压力是多少。
Sample Input

5 10
3 4
1 5
4 2
5 4
5 4
5 4
3 5
4 3
4 3
1 3
3 5
5 4
1 5
3 4

Sample Output

9

Solution
裸的关于点的树上差分

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = (int)5e4 + 15;

int fa[N][25], dpt[N], val[N];
vector<int> graph[N];

void dfs(int u, int pre) {
    fa[u][0] = pre;
    for(int j = 1; j < 25; j++) {
        fa[u][j] = fa[fa[u][j - 1]][j - 1];
    }
    for(int v : graph[u]) {
        if(v == pre) {
            continue;
        }
        dpt[v] = dpt[u] + 1;
        dfs(v, u);
    }
}

inline int getlca(int u, int v) {
    if(dpt[u] > dpt[v]) {
        swap(u, v);
    }
    for(int j = 24, dis = dpt[v] - dpt[u]; j >= 0; j--) {
        if(dis >> j & 1) {
            v = fa[v][j];
        }
    }
    if(u == v) {
        return u;
    }
    for(int j = 24; j >= 0; j--) {
        if(fa[u][j] != fa[v][j]) {
            u = fa[u][j], v = fa[v][j];
        }
    }
    return fa[u][0];
}

void solve(int u, int pre) {
    for(int v : graph[u]) {
        if(v == pre) {
            continue;
        }
        solve(v, u);
    }
    val[fa[u][0]] += val[u];
}

int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        graph[u].push_back(v);
        graph[v].push_back(u);
    }
    dpt[1] = 1;
    dfs(1, 0);
    while(m--) {
        int u, v;
        scanf("%d%d", &u, &v);
        int lca = getlca(u, v);
        val[u]++;
        val[v]++;
        val[lca]--;
        val[fa[lca][0]]--;
    }
    solve(1, 0);
    int ans = *max_element(val + 1, val + 1 + n);
    printf("%d\n", ans);
}

 

运输计划 - luogu P2680

Description
公元 2044 年,人类进入了宇宙纪元。
L 国有 n 个星球,还有 n−1 条双向航道,每条航道建立在两个星球之间,这 n−1 条航道连通了 L 国的所有星球。
小 P 掌管一家物流公司, 该公司有很多个运输计划,每个运输计划形如:有一艘物流飞船需要从 ui 号星球沿最快的宇航路径飞行到 vi​ 号星球去。显然,飞船驶过一条航道是需要时间的,对于航道 j,任意飞船驶过它所花费的时间为 tj​,并且任意两艘飞船之间不会产生任何干扰。
为了鼓励科技创新, L 国国王同意小 P 的物流公司参与 L 国的航道建设,即允许小P 把某一条航道改造成虫洞,飞船驶过虫洞不消耗时间。
在虫洞的建设完成前小 P 的物流公司就预接了 m 个运输计划。在虫洞建设完成后,这 m 个运输计划会同时开始,所有飞船一起出发。当这 m 个运输计划都完成时,小 P 的物流公司的阶段性工作就完成了。
如果小 P 可以自由选择将哪一条航道改造成虫洞,试求出小 PPP 的物流公司完成阶段性工作所需要的最短时间是多少?
范围:1 <= n, m <= 300000
Sample Input

6 3
1 2 3
1 6 4
3 1 7
4 3 6
3 5 5
3 6
2 5
4 5

Sample Output

11

Solution - 树上差分 + 二分
(这题不容易想啊, wtcl QAQ
首先注意到答案是可以二分的,如果某一时间可以完成,花更多的时间当然也能完成
对时间进行二分,如果此时某一条路径的长度超过了这一时间,那么必然要选这条路径上的某一条边将其权值置为0,但是并不知道要选哪一条。由于只能选择一条边,故选择的边一定会被所有超过了这一时间的路径覆盖,因此在所有满足这一条件的边中,贪心的选择边权最大的,而最终如果最大路径 - 被选择的边 <= 二分时间,则说明这一时间内可以完成所有任务,否则不可以
最后的问题是如何判断某一条边被覆盖了k次,这当然用树上差分来做

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = (int)3e5 + 15;
const int inf = 0x3f3f3f3f;
const int BASE = 22;
struct edge {
    int v, w, nxt;
};
struct Node {
    int u, v, w;
};

edge e[N << 1];
Node inE[N];
int head[N], tot;
int fa[N][BASE];
int col[N], dpt[N], parw[N], dd[N];
vector<Node> vec;
bool used[N];

inline void init() {
    memset(head, -1, sizeof(head));
    tot = 0;
}

inline void addEdge(int u, int v, int w) {
    e[tot] = edge{v, w, head[u]};
    head[u] = tot++;
}

inline void initDFS(int u, int pre) {
    fa[u][0] = pre;
    dpt[u] = dpt[pre] + 1;
    for(int j = 1; j < BASE; j++) {
        fa[u][j] = fa[fa[u][j - 1]][j - 1];
    }
    for(int i = head[u]; ~i; i = e[i].nxt) {
        int v = e[i].v;
        if(v == pre) {
            continue;
        }
        dd[v] = dd[u] + e[i].w;
        parw[v] = e[i].w;
        initDFS(v, u);
    }
}

inline int getLca(int u, int v) {
    if(dpt[u] > dpt[v]) {
        swap(u, v);
    }
    for(int j = 0, dis = dpt[v] - dpt[u]; j < BASE; j++) {
        if(dis >> j & 1) {
            v = fa[v][j];
        }
    }
    if(u == v) {
        return u;
    }
    for(int j = BASE - 1; j >= 0; j--) {
        if(fa[u][j] != fa[v][j]) {
            u = fa[u][j], v = fa[v][j];
        }
    }
    return fa[u][0];
}

inline int getDis(int u, int v, int w) {
    return dd[u] + dd[v] - 2 * dd[w];
}

inline int dfs(int u, int pre, int cnt) {
    int ans = 0;
    for(int i = head[u]; ~i; i = e[i].nxt) {
        int v = e[i].v;
        if(v == pre) {
            continue;
        }
        ans = max(ans, dfs(v, u, cnt));
        col[u] += col[v];
    }
    if(col[u] == cnt) {
        ans = max(ans, parw[u]);
    }
    return ans;
}

inline bool check(int n, int m, int k) {
    memset(col + 1, 0, n * sizeof(int));
    int cnt = 0;
    int maxDis = 0;
    for(int i = 1; i <= m; i++) {
        int u = inE[i].u, v = inE[i].v, w = inE[i].w;
        int dis = getDis(u, v, w);
        if(dis > k) {
            col[u]++;
            col[v]++;
            col[w] -= 2;
            cnt++;
            maxDis = max(maxDis, dis);
        }
    }
    if(cnt == 0) {
        return true;
    } else {
        int p = dfs(1, 0, cnt);
        return maxDis - p <= k;
    }
}

int main() {
    init();
    int n, m;
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n - 1; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        addEdge(u, v, w);
        addEdge(v, u, w);
    }
    initDFS(1, 0);
    for(int i = 1; i <= m; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        inE[i] = Node{u, v, getLca(u, v)};
    }

    int l = 0, r = (int)1e9, ans = 0;
    while(l <= r) {
        int mid = (l + r) >> 1;
        if(check(n, m, mid)) {
            ans = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }
    printf("%d\n", ans);

    return 0;
}

 

Rikka with Intersections of Paths - Gym 102012G

Description
给定一棵n个点树,给定m条路径,问在这些路径中选择k条路径,使得他们至少交于一个顶点的方案数
范围:1 <= T <= 200, n, m <= 3*10^5
答案对10^9 + 7取模

1
3 6 2
1 2
1 3
1 1
2 2
3 3
1 2
1 3
2 3

Sample Output

10

Solution - 树上差分
一个结论:树上相交的两条路径必经过其中一条路径的lca
有了这一结论后,就可以枚举lca做。记cnt[i]为点i被路径经过的次数,lcaCnt[i]为点i作为路径的lca的次数,则ans = sum(C(cnt[i], k) - C(cnt[i] - lcaCnt[i], k))

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = (int)3e5 + 15;
const int inf = 0x3f3f3f3f;
const int BASE = 22;
const int MOD = (int)1e9 + 7;

vector<int> graph[N];
int fa[N][BASE];
int cnt[N], lcaCnt[N], dpt[N], parw[N];
int invMul[N], inv[N], mul[N];

inline int addMod(int a, int b) {
    return a + b >= MOD ? a + b - MOD : a + b;
}

inline void init() {
    invMul[0] = mul[0] = 1;
    inv[1] = invMul[1] = mul[1] = 1;
    for(int i = 2; i < N; i++) {
        mul[i] = 1LL * mul[i - 1] * i % MOD;
        inv[i] = 1LL * (MOD - MOD / i) * inv[MOD % i] % MOD;
        invMul[i] = 1LL * invMul[i - 1] * inv[i] % MOD;
    }
}

inline int comb(int n, int m) {
    return (!n || n < m ? 0 : 1LL * mul[n] * invMul[m] % MOD * invMul[n - m] % MOD);
}

inline void initDFS(int u, int pre) {
    fa[u][0] = pre;
    dpt[u] = dpt[pre] + 1;
    for(int j = 1; j < BASE; j++) {
        fa[u][j] = fa[fa[u][j - 1]][j - 1];
    }
    for(int v : graph[u]) {
        if(v == pre) {
            continue;
        }
        initDFS(v, u);
    }
}

inline int getLca(int u, int v) {
    if(dpt[u] > dpt[v]) {
        swap(u, v);
    }
    for(int j = 0, dis = dpt[v] - dpt[u]; j < BASE; j++) {
        if(dis >> j & 1) {
            v = fa[v][j];
        }
    }
    if(u == v) {
        return u;
    }
    for(int j = BASE - 1; j >= 0; j--) {
        if(fa[u][j] != fa[v][j]) {
            u = fa[u][j], v = fa[v][j];
        }
    }
    return fa[u][0];
}

inline int solve(int u, int pre, int k) {
    int ans = 0;
    for(int v : graph[u]) {
        if(v == pre) {
            continue;
        }
        ans = addMod(ans, solve(v, u, k));
        cnt[u] += cnt[v];
    }
    ans = addMod(ans, addMod(comb(cnt[u], k), MOD - comb(cnt[u] - lcaCnt[u], k)));
    return ans;
}


int main() {
    init();
    int t;
    scanf("%d", &t);
    while(t--) {
        int n, m, k;
        scanf("%d%d%d", &n, &m, &k);
        for(int i = 1; i <= n; i++) {
            graph[i].clear();
            cnt[i] = lcaCnt[i] = 0;
        }
        for(int i = 1; i <= n - 1; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            graph[u].push_back(v);
            graph[v].push_back(u);
        }
        initDFS(1, 0);
        while(m--) {
            int u, v;
            scanf("%d%d", &u, &v);
            int lca = getLca(u, v);
            cnt[u]++;
            cnt[v]++;
            cnt[lca]--;
            cnt[fa[lca][0]]--;
            lcaCnt[lca]++;
        }
        int ans = solve(1, 0, k);
        printf("%d\n", ans);
    }
    return 0;
}