803. Bricks Falling When Hit (backward union find)

class Solution {
    /*
        Reverse-time union find.

        Removing bricks would require splitting union-find components, which the
        structure cannot do. Instead, every brick that is actually hit is removed
        up front, the surviving bricks are unioned, and the hit bricks are then
        put back in reverse chronological order. When a brick is restored, the
        growth of the component containing the virtual top node (WALL), minus the
        restored brick itself, is exactly the number of bricks that hit dropped.

        T: O((m*n + h) * α(m*n)), h = hits.length (union by size + path halving)
        S: O(m*n + h)
    */
    private static final int[][] DIRS = {{0, 1}, {1, 0}, {0, -1}, {-1, 0}};
    /** Union-find id of the virtual node every top-row brick is attached to. */
    private static final int WALL = 0;
    /** Grid value of a brick that is currently present. */
    private static final int BRICK = 1;
    /** Grid value temporarily used for a brick removed by a hit. */
    private static final int HIT = 2;

    /**
     * Returns, for each hit, how many bricks fall as a consequence of it.
     * The hit brick itself is not counted. A hit on an empty cell, or on a cell
     * already emptied by an earlier hit, drops nothing.
     *
     * grid is modified while the method runs but is restored to its original
     * contents before returning.
     *
     * @param grid m x n grid, 1 = brick, 0 = empty
     * @param hits hits in chronological order, each {row, col}
     * @return dropped[i] = number of bricks that fall after hits[i]
     */
    public int[] hitBricks(int[][] grid, int[][] hits) {
        int m = grid.length;
        int n = grid[0].length;
        UnionFind uf = new UnionFind(m * n + 1);

        // Remove every brick that is really hit. Only the FIRST hit on a cell
        // removes a brick; a hit on empty space or a repeated hit on the same
        // cell removes nothing, so remember which hits were effective.
        boolean[] removesBrick = new boolean[hits.length];
        for (int k = 0; k < hits.length; k++) {
            int x = hits[k][0];
            int y = hits[k][1];
            if (grid[x][y] == BRICK) {
                grid[x][y] = HIT;
                removesBrick[k] = true;
            }
        }

        // Connect the bricks that survive all hits.
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                if (grid[i][j] == BRICK) {
                    unionAround(grid, i, j, uf);
                }
            }
        }

        int[] droppedBricks = new int[hits.length];
        // Size of the WALL component (includes the WALL node itself, which
        // cancels out in the differences below).
        int leftBricks = uf.getSize(WALL);

        // Replay hits backwards, putting each removed brick back.
        for (int k = hits.length - 1; k >= 0; k--) {
            if (!removesBrick[k]) {
                continue; // empty cell or duplicate hit: nothing dropped
            }
            int x = hits[k][0];
            int y = hits[k][1];
            grid[x][y] = BRICK;
            unionAround(grid, x, y, uf);

            int newLeftBricks = uf.getSize(WALL);
            // Bricks newly attached to the top, minus the restored brick itself.
            // If the restored brick does not reach the top, the size is
            // unchanged and the difference (-1) clamps to 0.
            droppedBricks[k] = Math.max(newLeftBricks - leftBricks - 1, 0);
            leftBricks = newLeftBricks;
        }

        return droppedBricks;
    }

    /** Maps cell (i, j) to its union-find id; ids start at 1 because WALL is 0. */
    private int toId(int i, int j, int n) {
        return i * n + j + 1;
    }

    /**
     * Unions cell (i, j) with every 4-neighbour that currently holds a brick,
     * and with WALL if the cell is in the top row.
     */
    private void unionAround(int[][] grid, int i, int j, UnionFind uf) {
        int m = grid.length;
        int n = grid[0].length;

        int cur = toId(i, j, n);
        for (int[] dir : DIRS) {
            int x = i + dir[0];
            int y = j + dir[1];
            if (x >= 0 && x < m && y >= 0 && y < n && grid[x][y] == BRICK) {
                uf.union(cur, toId(x, y, n));
            }
        }
        if (i == 0) {
            uf.union(cur, WALL);
        }
    }

    /**
     * Disjoint-set forest with union by size and path halving.
     * Union by size keeps trees O(log n) deep and find is iterative, so large
     * grids cannot overflow the call stack.
     */
    private static final class UnionFind {
        private final int[] parent;
        private final int[] size;

        UnionFind(int n) {
            parent = new int[n];
            size = new int[n];
            Arrays.fill(size, 1);
            for (int i = 0; i < n; i++) {
                parent[i] = i;
            }
        }

        /** Returns the root of x's set, halving the path along the way. */
        public int find(int x) {
            while (parent[x] != x) {
                parent[x] = parent[parent[x]];
                x = parent[x];
            }
            return x;
        }

        /** Merges the sets containing i and j, attaching the smaller tree to the larger. */
        public void union(int i, int j) {
            int x = find(i);
            int y = find(j);
            if (x == y) {
                return;
            }
            if (size[x] > size[y]) {
                int tmp = x;
                x = y;
                y = tmp;
            }
            parent[x] = y;
            size[y] += size[x];
        }

        /** Returns the number of elements in x's set. */
        public int getSize(int x) {
            return size[find(x)];
        }
    }
}

Last updated