Distance Sum (AOJ2636)

この記事は帰ってきた AOJ-ICPC Advent Calendar 2022 17 日目の記事です。

木クエリゴリゴリ問題です。

問題概要

 N 頂点の木がある。頂点  i の親は  P_i であり、その間の距離は  d_i である。  1 \le K \le N について、 \min_u \sum_{i = 1} ^ {K} dist(u, i) を求めよ。

制約

  •  N \le 2 \times 10 ^ 5
  •  d_i \le 2 \times 10 ^ 5

解法

大体重心を求めたいという問題です。 N 頂点のうち、 K 頂点しか気にならないのでそれらからなる圧縮木を作ってみます。すると、重心の性質から求める  u はその圧縮木上の頂点としてよいことが分かります。(特に、 1 \le u \le K が簡単にわかる。)

さて、重心はある頂点から部分木のサイズが  N / 2 以上の方向に移動することを繰り返すことで求まる。また、圧縮木上でも普通の木と変わらず、移動すると必ず部分木のサイズが 1 以上減る(不変だとすると、その頂点は圧縮木に不必要)

このことから、圧縮木の頂点を一つ増やしたとき、重心は高々一回移動するのみである。移動した際の答えの変化はセグ木等で適当に求まるので、圧縮木を動的に変化させることを考える。

圧縮木は、オイラーツアー順に頂点を見たときに隣合った頂点の LCA を全て追加することで求まる。 そのため、新しい頂点を追加する際に既にみた頂点のうちオイラーツアー順で隣接しているものとの LCA が追加する頂点の候補である。 今、頂点  K を追加するとき、オイラーツアー順で手前の頂点を  L とし、 LCA(K, L) = C とする。  C が圧縮木に含まれていないとき、 L の元々の(圧縮木上での)親を  P とし、 (P, C), (C, L) と結びなおす。 これをオイラーツアー順で次に隣接している頂点  R についても行う。  C の親は  L, R に対する  C のうち、深い方である。

 K が圧縮木上にまだないことから、 K の子供は高々一つしかない。よって、 K と子供を結ぶ辺、 K と親を結ぶ辺が正しく貼れる。

さて、この圧縮木をどう管理するかだが、HLD と map を用いてグラフを管理することが出来る。圧縮木上で  U \rightarrow V の辺を貼る際、 U から  V に向かって 1 進んだ頂点を  W とし、 G[U][W] = V として管理すればよい。後は頑張るとなんとかなる。

この持ち方をしていれば、「辺を切る」という操作は必要なくなるので楽(辺を切る際、必ず新しいところと結ぶため)

実装

#include <bits/stdc++.h>

using namespace std;

#define ov3(a, b, c, name, ...) name
#define ll long long
#define rep2(i, a, b) for(ll i = (a); i < (b); i++)
#define rep1(i, n) rep2(i, 0, n)
#define rep0(n) rep1(iiiii, n)
#define rep(...) ov3(__VA_ARGS__, rep2, rep1, rep0)(__VA_ARGS__)
#define si(c) (ll)(c.size())
#define pll pair<ll, ll>
#define vl vector<ll>
#define all(v) v.begin(), v.end()
#define fore(e, v) for(auto &&e : v)

template <typename S, typename T> bool chmax(S &x, const T &y) { return (x < y ? x = y, true : false); }

template <typename S, typename T> bool chmin(S &x, const T &y) { return (x > y ? x = y, true : false); }

template <typename T> ostream &operator<<(ostream &os, const vector<T> &v) {
    os << "{";
    rep(i, si(v)) {
        if(i) os << ", ";
        os << v[i];
    }
    return os << "}";
}

template <typename T, typename F> struct segtree {
    int n, rn;
    vector<T> seg;
    F f;
    T id;

    segtree() = default;

    segtree(int rn, F f, T id) : rn(rn), f(f), id(id) {
        n = 1;
        while(n < rn) n <<= 1;
        seg.assign(n * 2, id);
    }

    void set(int i, T x) { seg[i + n] = x; }

    void build() {
        for(int k = n - 1; k > 0; k--) seg[k] = f(seg[k * 2], seg[k * 2 + 1]);
    }

    void update(int i, T x) {
        set(i, x);
        i += n;
        while(i >>= 1) seg[i] = f(seg[i * 2], seg[i * 2 + 1]);
    }

    T get(int i) { return seg[i + n]; }

    T prod(int l, int r) {
        T L = id, R = id;
        for(l += n, r += n; l < r; l >>= 1, r >>= 1) {
            if(l & 1) L = f(L, seg[l++]);
            if(r & 1) R = f(seg[--r], R);
        }
        return f(L, R);
    }

    template <typename C> int max_right(int l, const C &c) {
        assert(c(id));
        if(l >= rn) return rn;
        l += n;
        T sum = id;
        do {
            while((l & 1) == 0) l >>= 1;
            if(!c(f(sum, seg[l]))) {
                while(l < n) {
                    l <<= 1;
                    auto nxt = f(sum, seg[l]);
                    if(c(nxt)) {
                        sum = nxt;
                        l++;
                    }
                }
                return l - n;
            }
            sum = f(sum, seg[l++]);
        } while((l & -l) != l);
        return rn;
    }
};

template <typename G> struct HLD {
    int n;
    G &g;
    vector<int> sub, in, out, head, rev, par, d;

    HLD(G &g) : n(si(g)), g(g), sub(n), in(n), out(n), head(n), rev(n), par(n), d(n) {}

    void dfs1(int x, int p) {
        par[x] = p;
        sub[x] = 1;
        if(g[x].size() and g[x][0] == p) swap(g[x][0], g[x].back());
        fore(e, g[x]) {
            if(e == p) continue;
            d[e] = d[x] + 1;
            dfs1(e, x);
            sub[x] += sub[e];
            if(sub[g[x][0]] < sub[e]) swap(g[x][0], e);
        }
    }

    void dfs2(int x, int p, int &t) {
        in[x] = t++;
        rev[in[x]] = x;
        fore(e, g[x]) {
            if(e == p) continue;
            head[e] = (g[x][0] == e ? head[x] : e);
            dfs2(e, x, t);
        }
        out[x] = t;
    }

    void build() {
        int t = 0;
        head[0] = 0;
        dfs1(0, -1);
        dfs2(0, -1, t);
    }

    int la(int v, int k) {
        while(1) {
            int u = head[v];
            if(in[v] - k >= in[u]) return rev[in[v] - k];
            k -= in[v] - in[u] + 1;
            v = par[u];
        }
    }

    int lca(int u, int v) {
        for(;; v = par[head[v]]) {
            if(in[u] > in[v]) swap(u, v);
            if(head[u] == head[v]) return u;
        }
    }

    template <typename T, typename Q, typename F> T query(int u, int v, const T &e, const Q &q, const F &f, bool edge = false) {
        T l = e, r = e;
        for(;; v = par[head[v]]) {
            if(in[u] > in[v]) swap(u, v), swap(l, r);
            if(head[u] == head[v]) break;
            l = f(q(in[head[v]], in[v] + 1), l);
        }
        return f(f(q(in[u] + edge, in[v] + 1), l), r);
    }

    int dist(int u, int v) { return d[u] + d[v] - 2 * d[lca(u, v)]; }

    int jump(int s, int t, int i) {
        if(!i) return s;
        int l = lca(s, t);
        int dst = d[s] + d[t] - d[l] * 2;
        if(dst < i) return -1;
        if(d[s] - d[l] >= i) return la(s, i);
        i -= d[s] - d[l];
        return la(t, d[t] - d[l] - i);
    }
};

template <typename T> struct edge {
    int to;
    T cost;

    edge(int to, T cost) : to(to), cost(cost) {}

    edge &operator=(const int &x) {
        to = x;
        return *this;
    }

    operator int() const { return to; }
};

int main() {
    cin.tie(0), cout.tie(0);
    ios::sync_with_stdio(false);

    int n;
    cin >> n;
    vector<vector<edge<ll>>> g(n);
    rep(i, 1, n) {
        int x, c;
        cin >> x >> c;
        x--;
        g[x].emplace_back(i, c);
    }
    HLD hld(g);
    hld.build();

    ll sum = 0;
    int now = 0;
    cout << sum << "\n";

    vector<map<int, int>> G(n);
    set<int> s;
    s.emplace(0);

    vl par(n);

    auto connect = [&](int i, int j) {
        if(i == j) return;
        if(hld.d[i] < hld.d[j]) swap(i, j);
        par[i] = j;
        G[i][hld.jump(i, j, 1)] = j;
        G[j][hld.jump(j, i, 1)] = i;
    };

    vl d(n), dp(n);
    auto dfs = [&](auto &&f, int x) -> void {
        fore(e, g[x]) {
            d[e] = d[x] + e.cost;
            dp[e] = e.cost;
            f(f, e);
        }
    };
    dfs(dfs, 0);
    auto dist = [&](int i, int j) {
        int u = hld.lca(i, j);
        return d[i] + d[j] - d[u] * 2;
    };

    auto f = [](int x, int y) { return x + y; };
    segtree<int, decltype(f)> seg(n, f, 0);
    seg.update(0, 1);

    rep(i, 1, n) {
        int idx = hld.in[i];
        auto it = s.lower_bound(idx);
        int mypar;
        if(it != begin(s)) {
            int l = hld.rev[*prev(it)];
            int c = hld.lca(l, i);
            mypar = c;
            if(s.insert(hld.in[c]).second) {
                l = hld.rev[*s.lower_bound(hld.in[hld.jump(c, l, 1)])];
                int p = par[l];
                connect(l, c);
                connect(i, c);
                connect(c, p);
            }
        }
        if(it != end(s)) {
            int r = hld.rev[*it];
            int c = hld.lca(r, i);
            if(hld.d[mypar] < hld.d[c]) mypar = c;
            if(s.insert(hld.in[c]).second) {
                r = hld.rev[*s.lower_bound(hld.in[hld.jump(c, r, 1)])];
                int p = par[r];
                connect(r, c);
                connect(i, c);
                connect(c, p);
            }
        }
        connect(i, mypar);
        s.insert(hld.in[i]);
        seg.update(hld.in[i], 1);
        sum += dist(now, i);

        int cnt = 0;

        while(true) {
            if(now == i) break;
            int nxt = hld.jump(now, i, 1);
            int to = G[now][nxt];
            ll s;
            if(hld.d[nxt] > hld.d[now]) {
                s = seg.prod(hld.in[to], hld.out[to]);
            } else
                s = (i + 1) - seg.prod(hld.in[now], hld.out[now]);
            if(s > (i + 1 - s)) {
                sum -= dist(now, to) * (s * 2 - (i + 1));
                now = to;
            } else {
                break;
            }
        }
        cout << sum << '\n';
    }
}