(re-root) 834. Sum of Distances in Tree

T: O(n)
S: O(n)
```java
import java.util.ArrayList;
import java.util.List;

class Solution {
    public int[] sumOfDistancesInTree(int n, int[][] edges) {
        List<List<Integer>> tree = buildTree(n, edges);
        int[] subTreeCount = new int[n];
        int[] result = new int[n];
        calSubTreeCount(tree, subTreeCount, 0, -1);
        result[0] = dfs0(tree, subTreeCount, 0, -1);
        dfs(tree, subTreeCount, result, 0, -1, n);
        return result;
    }
    // Pass 1 (post-order): subTreeCount[v] = number of nodes in v's subtree, v itself included.
    private int calSubTreeCount(List<List<Integer>> tree, int[] subTreeCount, int current, int parent) {
        int sum = 1;
        for (int next : tree.get(current)) {
            if (next == parent) {
                continue;
            }
            sum += calSubTreeCount(tree, subTreeCount, next, current);
        }
        subTreeCount[current] = sum;
        return sum;
    }
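    // Pass 2 (post-order): returns the sum of distances from current to every node in its subtree,
    // so dfs0 called on root 0 returns f(0).
    private int dfs0(List<List<Integer>> tree, int[] subTreeCount, int current, int parent) {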
        int sum = 0;
        for (int next : tree.get(current)) {
            if (next == parent) {
                continue;
            }
            sum += dfs0(tree, subTreeCount, next, current);
        }
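        // Each of the subTreeCount[current] - 1 descendants (current itself excluded) is one edge
        // farther from current than from the child whose subtree contains it.
        sum += subTreeCount[current] - 1;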
        return sum;
    }
    // Pass 3 (pre-order): result[current] is already known; derive each child's answer from it.
    private void dfs(List<List<Integer>> tree, int[] subTreeCount, int[] result, int current, int parent, int n) {
        for (int next : tree.get(current)) {
            if (next == parent) {
                continue;
            }
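            // Re-root from current to next: subTreeCount[next] nodes get 1 closer, the other n - subTreeCount[next] get 1 farther.
            result[next] = result[current] + n - 2 * subTreeCount[next];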
            dfs(tree, subTreeCount, result, next, current, n);
        }
    }
    private List<List<Integer>> buildTree(int n, int[][] edges) {
        List<List<Integer>> tree = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            tree.add(new ArrayList<>());
        }
        for (int[] edge : edges) {
            tree.get(edge[0]).add(edge[1]);
            tree.get(edge[1]).add(edge[0]);
        }
        return tree;
    }
}

/**
re-root:

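Let f(x) be the sum of distances from node x to every other node.
Using the re-rooting idea, we can derive the following formula
(2 is a child of 0, and the root moves from 0 to 2):
f(2) = f(0) + a - b = f(0) + n - 2b
b = subtree(2), the number of nodes in 2's subtree; each of them gets 1 closer
a = n - b, the number of nodes outside that subtree; each of them gets 1 farther

0. Build the adjacency list first. Since the graph is a tree, every DFS passes in the parent
   and skips it, so each traversal only moves top-down and never walks back up.

1. calSubTreeCount -> T: O(n)
   The formula needs the node count of every subtree (including the subtree root itself).
2. dfs0 computes f(0), the sum of distances from node 0 -> T: O(n)

3. dfs applies the formula top-down to compute every other result -> T: O(n)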

T: O(n)
S: O(n)
 */
```
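To see the formula f(2) = f(0) + n - 2b on concrete numbers, here is a small brute-force sketch. The class `RerootCheck` and its helpers `buildTree`/`distSum` are hypothetical illustrations, not part of the solution above. It runs a BFS from a node on the six-node tree with edges 0-1, 0-2, 2-3, 2-4, 2-5 and compares f(2), computed directly, with the value the formula gives. Rooted at 0, subtree(2) = {2, 3, 4, 5}, so b = 4.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class RerootCheck {
    // Undirected adjacency list, same shape as buildTree in the solution.
    static List<List<Integer>> buildTree(int n, int[][] edges) {
        List<List<Integer>> tree = new ArrayList<>();
        for (int i = 0; i < n; i++) tree.add(new ArrayList<>());
        for (int[] e : edges) {
            tree.get(e[0]).add(e[1]);
            tree.get(e[1]).add(e[0]);
        }
        return tree;
    }

    // Brute force: BFS from src and add up the distance to every node (O(n) per source).
    static int distSum(List<List<Integer>> tree, int src) {
        int[] dist = new int[tree.size()];
        Arrays.fill(dist, -1);
        ArrayDeque<Integer> queue = new ArrayDeque<>();
        dist[src] = 0;
        queue.add(src);
        int sum = 0;
        while (!queue.isEmpty()) {
            int cur = queue.poll();
            sum += dist[cur];
            for (int next : tree.get(cur)) {
                if (dist[next] < 0) {
                    dist[next] = dist[cur] + 1;
                    queue.add(next);
                }
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        int n = 6;
        List<List<Integer>> tree = buildTree(n, new int[][]{{0, 1}, {0, 2}, {2, 3}, {2, 4}, {2, 5}});
        int f0 = distSum(tree, 0);   // f(0) by brute force
        int f2 = distSum(tree, 2);   // f(2) by brute force
        int b = 4;                   // subtree(2) = {2, 3, 4, 5} when rooted at 0
        System.out.println("f(0) = " + f0);                   // f(0) = 8
        System.out.println("f(2) = " + f2);                   // f(2) = 6
        System.out.println("formula = " + (f0 + n - 2 * b));  // formula = 6  (8 + 6 - 8)
    }
}
```

The brute force costs O(n) per node, O(n^2) overall, which is exactly what the re-root pass avoids. It is only useful as a cross-check on small trees.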
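On a path-shaped tree, calSubTreeCount, dfs0 and dfs each nest about n calls deep. Depending on the JVM's thread stack size, that may overflow for large n. Below is a minimal sketch of the same three passes without recursion, assuming the same `(n, edges)` input (the class name `SumOfDistancesIterative` is hypothetical):

* One BFS from node 0 records each node's parent and a top-down visiting order.
* Walking that order backwards accumulates subtree sizes.
* Walking it forwards applies `result[next] = result[current] + n - 2 * subTreeCount[next]`.

It swaps dfs0 for a sum of BFS depths, which gives the same f(0).

```java
import java.util.ArrayList;
import java.util.List;

class SumOfDistancesIterative {
    public int[] sumOfDistancesInTree(int n, int[][] edges) {
        // Step 0: adjacency list.
        List<List<Integer>> tree = new ArrayList<>();
        for (int i = 0; i < n; i++) tree.add(new ArrayList<>());
        for (int[] e : edges) {
            tree.get(e[0]).add(e[1]);
            tree.get(e[1]).add(e[0]);
        }

        // BFS from root 0: order[] lists nodes top-down; parent[] replaces the recursive `parent` argument.
        int[] order = new int[n];
        int[] parent = new int[n];
        int[] depth = new int[n];
        parent[0] = -1;
        int head = 0, tail = 0;
        order[tail++] = 0;
        while (head < tail) {
            int cur = order[head++];
            for (int next : tree.get(cur)) {
                if (next == parent[cur]) continue;
                parent[next] = cur;
                depth[next] = depth[cur] + 1;
                order[tail++] = next;
            }
        }

        // Step 1 (bottom-up): children come after parents in order[], so a reverse walk finishes each subtree first.
        int[] subTreeCount = new int[n];
        for (int i = n - 1; i >= 0; i--) {
            int v = order[i];
            subTreeCount[v] += 1; // count v itself
            if (parent[v] >= 0) subTreeCount[parent[v]] += subTreeCount[v];
        }

        // Step 2: f(0) is the sum of every node's depth below root 0.
        int[] result = new int[n];
        for (int d : depth) result[0] += d;

        // Step 3 (top-down): a parent's answer is final before its children's, so apply the re-root formula in BFS order.
        for (int i = 1; i < n; i++) {
            int v = order[i];
            result[v] = result[parent[v]] + n - 2 * subTreeCount[v];
        }
        return result;
    }
}
```

For the six-node tree with edges 0-1, 0-2, 2-3, 2-4, 2-5, this returns `[8, 12, 6, 10, 10, 10]`. Time and space stay O(n). The explicit `order` array trades the call stack for heap memory.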
