973. K Closest Points to Origin

heap

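Distance between $(x_1, y_1)$ and $(x_2, y_2)$: $\sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2}$, e.g. from $(0, 1)$ to $(2, 3)$ it is $\sqrt{8}$. The distance to the origin is $\sqrt{x^2 + y^2}$. Because the square root is monotonic, comparing $x^2 + y^2$ is enough.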

Maintain a max-heap keyed by distance. Whenever its size exceeds k, pop the maximum (the farthest point).

At the end, the heap holds the k points closest to the origin.

Time: O(n log k). Each heap insertion (and eviction) costs O(log k) and is done n times.

space: O(k)

// O(NlogK)
// O(k)
class Solution {
    /**
     * Returns the k points closest to the origin, in no particular order.
     *
     * Keeps a max-heap (ordered by squared distance) of at most k points.
     * Whenever the heap grows past k, the farthest retained point is evicted.
     * Time O(n log k), space O(k).
     *
     * @param points array of [x, y] pairs
     * @param k      number of closest points to return; if k exceeds
     *               points.length, every point is returned
     * @return the retained points, sized exactly to the number kept
     */
    public int[][] kClosest(int[][] points, int k) {
        // Reverse order (p2 vs p1) makes the farthest point the head, so poll() evicts it.
        // Long.compare avoids the overflow risk of subtracting large distances.
        PriorityQueue<int[]> heap = new PriorityQueue<>(
            (p1, p2) -> Long.compare(getDistance(p2), getDistance(p1)));

        for (int[] point : points) {
            heap.offer(point);
            if (heap.size() > k) {
                heap.poll();
            }
        }

        // Size the result from the heap itself. A pre-sized new int[k][2] would leave
        // null/{0,0} filler rows whenever fewer than k points exist.
        return heap.toArray(new int[0][]);
    }

    /**
     * Squared Euclidean distance from the origin.
     * Computed in long because int squares overflow once |coordinate| exceeds 46340.
     */
    private long getDistance(int[] p) {
        long x = p[0];
        long y = p[1];
        return x * x + y * y;
    }
}

Quickselect with a random pivot

Time: O(n) on average, O(n^2) in the worst case.

space: O(1)

class Solution {
    // One generator per instance instead of allocating a new Random on every step.
    private final Random random = new Random();

    /**
     * Returns the k points closest to the origin, in no particular order.
     *
     * Uses quickselect to rearrange {@code points} in place so that its first k
     * entries are the k closest, then copies them out. Note: {@code points} is
     * reordered as a side effect.
     *
     * @param points array of [x, y] pairs
     * @param k      number of closest points to return; values above points.length
     *               are clamped, so every point is returned
     * @return a new array holding the selected points
     */
    public int[][] kClosest(int[][] points, int k) {
        int count = Math.min(k, points.length);
        // Selection is only needed when some points must be excluded. Skipping it also
        // avoids searching for index == points.length, which lies outside the array.
        if (count < points.length) {
            findKthSmallest(points, 0, points.length - 1, count);
        }
        return Arrays.copyOf(points, count);
    }

    /**
     * Partially orders nums so that index {@code ksmallest} becomes a partition boundary.
     * After the call, every point before that index is strictly closer than the
     * point at that index, and none after it is closer.
     * Iterative, so worst-case inputs cannot exhaust the call stack.
     */
    private void findKthSmallest(int[][] nums, int left, int right, int ksmallest) {
        while (left < right) {
            // nextInt's bound is exclusive; +1 makes 'right' an eligible pivot too.
            int randomPivot = left + random.nextInt(right - left + 1);
            int pivot = partition(nums, left, right, randomPivot);

            if (ksmallest == pivot) {
                return;
            } else if (ksmallest < pivot) {
                right = pivot - 1;
            } else {
                left = pivot + 1;
            }
        }
    }

    /**
     * Lomuto partition of nums[left..right] around nums[pivot].
     * Points strictly closer than the pivot end up before the returned index, and the
     * pivot lands on it.
     *
     * @return the pivot's final index
     */
    private int partition(int[][] nums, int left, int right, int pivot) {
        int[] pivotVal = nums[pivot];
        long pivotDistance = getDistance(pivotVal);
        swap(nums, pivot, right); // park the pivot at the end while scanning
        int storedIndex = left;

        for (int i = left; i < right; i++) {
            if (getDistance(nums[i]) < pivotDistance) {
                swap(nums, storedIndex, i);
                storedIndex++;
            }
        }
        swap(nums, storedIndex, right); // move the pivot into its final slot
        return storedIndex;
    }

    /** Swaps the rows at indices i and j. */
    private void swap(int[][] nums, int i, int j) {
        int[] temp = nums[i];
        nums[i] = nums[j];
        nums[j] = temp;
    }

    /**
     * Squared Euclidean distance from the origin.
     * Computed in long because int squares overflow once |coordinate| exceeds 46340.
     */
    private long getDistance(int[] p) {
        long x = p[0];
        long y = p[1];
        return x * x + y * y;
    }
}

Quickselect (iterative `while` version)

    // test
    // https://leetcode.com/problems/k-closest-points-to-origin/discuss/218691/O(N)-Java-using-Quick-Select(beats-100)

    // Theoretically, the average time complexity is O(N),
    // but, just like quicksort, in the worst case
    // this solution degenerates to O(N^2). In practice,
    // it took about 15 ms on LeetCode.
    
    // O(N)
    /**
     * Returns the K points closest to the origin using iterative quickselect.
     * Repeatedly partitions {@code points} in place until index K is a partition
     * boundary, then copies out the first K entries.
     *
     * @param points array of [x, y] pairs; reordered as a side effect
     * @param K      number of points to return
     * @return a new array holding points[0..K-1] after partitioning
     */
    public int[][] kClosestByQuickSelect(int[][] points, int K) {
        int lo = 0;
        int hi = points.length - 1;
        while (lo <= hi) {
            // Fix one pivot at its final position within points[lo..hi].
            int boundary = helper(points, lo, hi);
            if (boundary == K) {
                // Nothing before index K is farther than anything from K onward.
                break;
            }
            // We need exactly K points, so keep narrowing toward the side containing index K.
            if (boundary < K) {
                lo = boundary + 1;
            } else {
                hi = boundary - 1;
            }
        }
        return Arrays.copyOfRange(points, 0, K);
    }

    /**
     * Partitions A[l..r] around a randomly chosen pivot and returns the pivot's final index.
     * Uses hole-filling: the pivot is saved aside and elements are moved into
     * the vacated slot alternately from the right and the left.
     * Afterwards, elements left of the returned index compare <= pivot and
     * elements right of it compare >= pivot (ordering by {@code compare}).
     *
     * @return the pivot's final index within [l, r]
     */
    private int helper(int[][] A, int l, int r) {
        // Random pivot: always taking A[l] makes already-sorted input an O(N^2) worst case.
        int randomIndex = java.util.concurrent.ThreadLocalRandom.current().nextInt(l, r + 1);
        int[] pivot = A[randomIndex];
        A[randomIndex] = A[l];
        A[l] = pivot; // the pivot is saved in 'pivot'; slot l is now the first hole
        while (l < r) {
            // Skip right-side elements that are >= pivot (already on the correct side).
            // Stop at the first smaller one and move it into the hole at l.
            while (l < r && compare(A[r], pivot) >= 0) {
                r--;
            }
            A[l] = A[r];

            // Skip left-side elements that are <= pivot (already on the correct side).
            // Stop at the first larger one and move it into the hole at r.
            while (l < r && compare(A[l], pivot) <= 0) {
                l++;
            }
            A[r] = A[l];
        }
        // l == r is the final hole: the pivot's sorted position.
        A[l] = pivot;
        return l;
    }

    /**
     * Compares two points by squared distance from the origin.
     *
     * @return a negative number, zero, or a positive number as p1 is closer than,
     *         as close as, or farther than p2
     */
    private int compare(int[] p1, int[] p2) {
        // long arithmetic: int squares overflow beyond |coordinate| = 46340, and
        // subtracting two large distances can overflow even when each fits.
        long d1 = (long) p1[0] * p1[0] + (long) p1[1] * p1[1];
        long d2 = (long) p2[0] * p2[0] + (long) p2[1] * p2[1];
        return Long.compare(d1, d2);
    }

Last updated