Distance Sum (AOJ2636)
この記事は帰ってきた AOJ-ICPC Advent Calendar 2022 17 日目の記事です。
木クエリゴリゴリ問題です。
問題概要
頂点の木がある。頂点 の親は であり、その間の距離は である。 について、 を求めよ。
制約
解法
大体重心を求めたいという問題です。 頂点のうち、 頂点しか気にならないのでそれらからなる圧縮木を作ってみます。すると、重心の性質から求める はその圧縮木上の頂点としてよいことが分かります。(特に、 が簡単にわかる。)
さて、重心はある頂点から部分木のサイズが 以上の方向に移動することを繰り返すことで求まる。また、圧縮木上でも普通の木と変わらず、移動すると必ず部分木のサイズが 1 以上減る(不変だとすると、その頂点は圧縮木に不必要)
このことから、圧縮木の頂点を一つ増やしたとき、重心は高々一回移動するのみである。移動した際の答えの変化はセグ木等で適当に求まるので、圧縮木を動的に変化させることを考える。
圧縮木は、オイラーツアー順に頂点を見たときに隣合った頂点の LCA を全て追加することで求まる。 そのため、新しい頂点を追加する際に既にみた頂点のうちオイラーツアー順で隣接しているものとの LCA が追加する頂点の候補である。 今、頂点 を追加するとき、オイラーツアー順で手前の頂点を とし、 とする。 が圧縮木に含まれていないとき、 の元々の(圧縮木上での)親を とし、 と結びなおす。 これをオイラーツアー順で次に隣接している頂点 についても行う。 の親は に対する のうち、深い方である。
が圧縮木上にまだないことから、 の子供は高々一つしかない。よって、 と子供を結ぶ辺、 と親を結ぶ辺が正しく貼れる。
さて、この圧縮木をどう管理するかだが、HLD と map を用いてグラフを管理することが出来る。圧縮木上で の辺を貼る際、 から に向かって 1 進んだ頂点を とし、 G[U][W] = V として管理すればよい。後は頑張るとなんとかなる。
この持ち方をしていれば、「辺を切る」という操作は必要なくなるので楽(辺を切る際、必ず新しいところと結ぶため)
実装
#include <bits/stdc++.h> using namespace std; #define ov3(a, b, c, name, ...) name #define ll long long #define rep2(i, a, b) for(ll i = (a); i < (b); i++) #define rep1(i, n) rep2(i, 0, n) #define rep0(n) rep1(iiiii, n) #define rep(...) ov3(__VA_ARGS__, rep2, rep1, rep0)(__VA_ARGS__) #define si(c) (ll)(c.size()) #define pll pair<ll, ll> #define vl vector<ll> #define all(v) v.begin(), v.end() #define fore(e, v) for(auto &&e : v) template <typename S, typename T> bool chmax(S &x, const T &y) { return (x < y ? x = y, true : false); } template <typename S, typename T> bool chmin(S &x, const T &y) { return (x > y ? x = y, true : false); } template <typename T> ostream &operator<<(ostream &os, const vector<T> &v) { os << "{"; rep(i, si(v)) { if(i) os << ", "; os << v[i]; } return os << "}"; } template <typename T, typename F> struct segtree { int n, rn; vector<T> seg; F f; T id; segtree() = default; segtree(int rn, F f, T id) : rn(rn), f(f), id(id) { n = 1; while(n < rn) n <<= 1; seg.assign(n * 2, id); } void set(int i, T x) { seg[i + n] = x; } void build() { for(int k = n - 1; k > 0; k--) seg[k] = f(seg[k * 2], seg[k * 2 + 1]); } void update(int i, T x) { set(i, x); i += n; while(i >>= 1) seg[i] = f(seg[i * 2], seg[i * 2 + 1]); } T get(int i) { return seg[i + n]; } T prod(int l, int r) { T L = id, R = id; for(l += n, r += n; l < r; l >>= 1, r >>= 1) { if(l & 1) L = f(L, seg[l++]); if(r & 1) R = f(seg[--r], R); } return f(L, R); } template <typename C> int max_right(int l, const C &c) { assert(c(id)); if(l >= rn) return rn; l += n; T sum = id; do { while((l & 1) == 0) l >>= 1; if(!c(f(sum, seg[l]))) { while(l < n) { l <<= 1; auto nxt = f(sum, seg[l]); if(c(nxt)) { sum = nxt; l++; } } return l - n; } sum = f(sum, seg[l++]); } while((l & -l) != l); return rn; } }; template <typename G> struct HLD { int n; G &g; vector<int> sub, in, out, head, rev, par, d; HLD(G &g) : n(si(g)), g(g), sub(n), in(n), out(n), head(n), rev(n), par(n), d(n) {} void dfs1(int x, int p) { par[x] = p; sub[x] = 1; if(g[x].size() and g[x][0] == p) swap(g[x][0], g[x].back()); fore(e, g[x]) { if(e == p) continue; d[e] = d[x] + 1; dfs1(e, x); sub[x] += sub[e]; if(sub[g[x][0]] < sub[e]) swap(g[x][0], e); } } void dfs2(int x, int p, int &t) { in[x] = t++; rev[in[x]] = x; fore(e, g[x]) { if(e == p) continue; head[e] = (g[x][0] == e ? head[x] : e); dfs2(e, x, t); } out[x] = t; } void build() { int t = 0; head[0] = 0; dfs1(0, -1); dfs2(0, -1, t); } int la(int v, int k) { while(1) { int u = head[v]; if(in[v] - k >= in[u]) return rev[in[v] - k]; k -= in[v] - in[u] + 1; v = par[u]; } } int lca(int u, int v) { for(;; v = par[head[v]]) { if(in[u] > in[v]) swap(u, v); if(head[u] == head[v]) return u; } } template <typename T, typename Q, typename F> T query(int u, int v, const T &e, const Q &q, const F &f, bool edge = false) { T l = e, r = e; for(;; v = par[head[v]]) { if(in[u] > in[v]) swap(u, v), swap(l, r); if(head[u] == head[v]) break; l = f(q(in[head[v]], in[v] + 1), l); } return f(f(q(in[u] + edge, in[v] + 1), l), r); } int dist(int u, int v) { return d[u] + d[v] - 2 * d[lca(u, v)]; } int jump(int s, int t, int i) { if(!i) return s; int l = lca(s, t); int dst = d[s] + d[t] - d[l] * 2; if(dst < i) return -1; if(d[s] - d[l] >= i) return la(s, i); i -= d[s] - d[l]; return la(t, d[t] - d[l] - i); } }; template <typename T> struct edge { int to; T cost; edge(int to, T cost) : to(to), cost(cost) {} edge &operator=(const int &x) { to = x; return *this; } operator int() const { return to; } }; int main() { cin.tie(0), cout.tie(0); ios::sync_with_stdio(false); int n; cin >> n; vector<vector<edge<ll>>> g(n); rep(i, 1, n) { int x, c; cin >> x >> c; x--; g[x].emplace_back(i, c); } HLD hld(g); hld.build(); ll sum = 0; int now = 0; cout << sum << "\n"; vector<map<int, int>> G(n); set<int> s; s.emplace(0); vl par(n); auto connect = [&](int i, int j) { if(i == j) return; if(hld.d[i] < hld.d[j]) swap(i, j); par[i] = j; G[i][hld.jump(i, j, 1)] = j; G[j][hld.jump(j, i, 1)] = i; }; vl d(n), dp(n); auto dfs = [&](auto &&f, int x) -> void { fore(e, g[x]) { d[e] = d[x] + e.cost; dp[e] = e.cost; f(f, e); } }; dfs(dfs, 0); auto dist = [&](int i, int j) { int u = hld.lca(i, j); return d[i] + d[j] - d[u] * 2; }; auto f = [](int x, int y) { return x + y; }; segtree<int, decltype(f)> seg(n, f, 0); seg.update(0, 1); rep(i, 1, n) { int idx = hld.in[i]; auto it = s.lower_bound(idx); int mypar; if(it != begin(s)) { int l = hld.rev[*prev(it)]; int c = hld.lca(l, i); mypar = c; if(s.insert(hld.in[c]).second) { l = hld.rev[*s.lower_bound(hld.in[hld.jump(c, l, 1)])]; int p = par[l]; connect(l, c); connect(i, c); connect(c, p); } } if(it != end(s)) { int r = hld.rev[*it]; int c = hld.lca(r, i); if(hld.d[mypar] < hld.d[c]) mypar = c; if(s.insert(hld.in[c]).second) { r = hld.rev[*s.lower_bound(hld.in[hld.jump(c, r, 1)])]; int p = par[r]; connect(r, c); connect(i, c); connect(c, p); } } connect(i, mypar); s.insert(hld.in[i]); seg.update(hld.in[i], 1); sum += dist(now, i); int cnt = 0; while(true) { if(now == i) break; int nxt = hld.jump(now, i, 1); int to = G[now][nxt]; ll s; if(hld.d[nxt] > hld.d[now]) { s = seg.prod(hld.in[to], hld.out[to]); } else s = (i + 1) - seg.prod(hld.in[now], hld.out[now]); if(s > (i + 1 - s)) { sum -= dist(now, to) * (s * 2 - (i + 1)); now = to; } else { break; } } cout << sum << '\n'; } }