Find Kth Smallest Pair Distance - Analysis

Question:

This is a problem from LeetCode:

Given an integer array, return the k-th smallest distance among all the pairs. The distance of a pair (A, B) is defined as the absolute difference between A and B.

Example:

Input:
nums = [1,3,1]
k = 1
Output: 0 
Explanation:
Here are all the pairs:
(1,3) -> 2
(1,1) -> 0
(3,1) -> 2
Then the 1st smallest distance pair is (1,1), and its distance is 0.

My Problem

I solved it with a naive approach that enumerates all O(n^2) pairs: I compute every distance, sort them, and take the k-th smallest. Here is a better solution. It is not my code; I found it on the LeetCode discussion forum. But I am having trouble understanding a crucial part of it.
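For reference, the naive approach described above might look roughly like the sketch below. The class and method names (NaivePairDistance, kthSmallest) are made up for illustration and are not from the original post; k is assumed to be 1-based and valid.

```java
import java.util.Arrays;

class NaivePairDistance {
    // Brute force: materialize all n*(n-1)/2 distances, sort them,
    // and pick the k-th smallest (k is 1-based).
    static int kthSmallest(int[] nums, int k) {
        int n = nums.length;
        int[] dist = new int[n * (n - 1) / 2];
        int idx = 0;
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                dist[idx++] = Math.abs(nums[i] - nums[j]);
            }
        }
        Arrays.sort(dist);
        return dist[k - 1];
    }

    public static void main(String[] args) {
        // The example from the problem statement: nums = [1,3,1], k = 1.
        System.out.println(kthSmallest(new int[] {1, 3, 1}, 1));
    }
}
```

Because of the sort, this costs O(n^2 log n) time and O(n^2) memory, which is what the binary-search solution avoids.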

The code below is basically doing a binary search over distances. low starts as the minimum distance and high as the maximum distance. It calculates mid as in a usual binary search, then calls countPairs(a, mid) to find the number of pairs with absolute difference less than or equal to mid, and adjusts low and high accordingly.

But WHY must the binary search result be one of the distances? Initially, low and high are taken from the array, but mid is calculated from them, so it may not be an actual distance. In the end we return low, whose value changes during the binary search based on mid + 1. Why is mid + 1 guaranteed to be one of the distances?

import java.util.Arrays;

class Solution {
    // Returns the index of the first element in a[low..high] that is greater than key
    private int upperBound(int[] a, int low, int high, int key) {
        if (a[high] <= key) return high + 1;
        while (low < high) {
            int mid = low + (high - low) / 2;
            if (key >= a[mid]) {
                low = mid + 1;
            } else {
                high = mid;
            }
        }
        return low;
    }

    // Returns number of pairs with absolute difference less than or equal to mid.
    private int countPairs(int[] a, int mid) {
        int n = a.length, res = 0;
        for (int i = 0; i < n; i++) {
            res += upperBound(a, i, n - 1, a[i] + mid) - i - 1;
        }
        return res;
    }

    public int smallestDistancePair(int a[], int k) {
        int n = a.length;
        Arrays.sort(a);

        // Minimum absolute difference
        int low = a[1] - a[0];
        for (int i = 1; i < n - 1; i++)
            low = Math.min(low, a[i + 1] - a[i]);

        // Maximum absolute difference
        int high = a[n - 1] - a[0];

        // Binary search for the smallest distance x with countPairs(a, x) >= k
        while (low < high) {
            int mid = low + (high - low) / 2;
            if (countPairs(a, mid) < k)
                low = mid + 1;
            else
                high = mid;
        }

        return low;
    }
}
Fleabitten answered 11/2, 2018 at 19:16 Comment(0)

This type of binary search will find the first value x for which countPairs(a,x) >= k. (The topcoder tutorial explains this well.)

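Therefore, when the function terminates with the final value low, we know that countPairs(a,low-1) < k <= countPairs(a,low): the number of pairs changes when the distance changes from low-1 to low, and therefore there must be a pair with distance low. (If low never moved from its starting value, it is still the minimum adjacent difference of the sorted array, which is a pair distance by construction.)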

For example, suppose we have a target of 100 and know that:

countPairs(a,9) = 99
countPairs(a,10) = 100

There must be a pair of numbers with distance exactly 10, because if there was no such pair, then the number of pairs with distance less than or equal to 10 would be the same as the number of pairs with distance less than or equal to 9.
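To see this jump concretely, here is a small sketch that runs the same style of binary search and prints the count just below and at the returned value. It is not the code from the question: countPairs here uses a two-pointer sweep instead of upperBound, low starts at 0 rather than at the minimum adjacent difference, and the names CountJumpDemo and firstAtLeast are made up for illustration.

```java
import java.util.Arrays;

class CountJumpDemo {
    // Number of pairs (i < j) in the sorted array a with a[j] - a[i] <= x.
    // Assumes x >= 0.
    static int countPairs(int[] a, int x) {
        int count = 0;
        for (int left = 0, right = 0; right < a.length; right++) {
            while (a[right] - a[left] > x) left++;
            count += right - left;
        }
        return count;
    }

    // Smallest x in [0, max - min] with countPairs(a, x) >= k.
    // The loop runs until the interval is exhausted (no early exit).
    static int firstAtLeast(int[] a, int k) {
        int low = 0, high = a[a.length - 1] - a[0];
        while (low < high) {
            int mid = low + (high - low) / 2;
            if (countPairs(a, mid) < k) low = mid + 1;
            else high = mid;
        }
        return low;
    }

    public static void main(String[] args) {
        int[] a = {1, 6, 1, 20, 9};
        Arrays.sort(a);
        int totalPairs = a.length * (a.length - 1) / 2;
        for (int k = 1; k <= totalPairs; k++) {
            int x = firstAtLeast(a, k);
            int below = (x == 0) ? 0 : countPairs(a, x - 1);
            System.out.println("k=" + k + " x=" + x
                    + " count(x-1)=" + below + " count(x)=" + countPairs(a, x));
        }
    }
}
```

For every k, count(x-1) is strictly smaller than count(x), so at least one pair has distance exactly x.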

Note that this only applies because the loop is run until the interval under test is completely exhausted. If the code had instead used an early termination condition that quit the loop if the exact target value was found, then it could return incorrect answers.
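As a hedged illustration of that warning, the sketch below compares a hypothetical early-exit variant with the exhaustive loop. Neither method is from the question's code; the names EarlyExitDemo, withEarlyExit, and exhaustive are made up, and countPairs is a two-pointer version assumed equivalent to the one in the question.

```java
class EarlyExitDemo {
    // Number of pairs (i < j) in the sorted array a with a[j] - a[i] <= x.
    // Assumes x >= 0.
    static int countPairs(int[] a, int x) {
        int count = 0;
        for (int left = 0, right = 0; right < a.length; right++) {
            while (a[right] - a[left] > x) left++;
            count += right - left;
        }
        return count;
    }

    // Hypothetical broken variant: stops as soon as the count equals k.
    static int withEarlyExit(int[] a, int k) {
        int low = 0, high = a[a.length - 1] - a[0];
        while (low < high) {
            int mid = low + (high - low) / 2;
            int c = countPairs(a, mid);
            if (c == k) return mid; // mid need not be an actual distance
            if (c < k) low = mid + 1;
            else high = mid;
        }
        return low;
    }

    // Correct variant: keeps narrowing until low == high.
    static int exhaustive(int[] a, int k) {
        int low = 0, high = a[a.length - 1] - a[0];
        while (low < high) {
            int mid = low + (high - low) / 2;
            if (countPairs(a, mid) < k) low = mid + 1;
            else high = mid;
        }
        return low;
    }

    public static void main(String[] args) {
        int[] sorted = {1, 1, 6, 9, 20}; // pair distances: 0 3 5 5 8 8 11 14 19 19
        System.out.println("early exit: " + withEarlyExit(sorted, 2));
        System.out.println("exhaustive: " + exhaustive(sorted, 2));
    }
}
```

With k = 2 the early-exit variant stops at 4, because exactly two pairs have distance at most 4, but no pair has distance 4; the exhaustive loop keeps narrowing down to 3, the true second-smallest distance.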

Accordion answered 11/2, 2018 at 20:48 Comment(6)
I am aware that this type of binary search finds the first value x for which countPairs(a,x) >= k (we only update low when countPairs(a, mid) < k). However, why must the value low that the function terminates with be the distance of some pair? I don't understand the logic in your second paragraph. Could you please elaborate? - Fleabitten
I've added an example; it may help to try to construct a counterexample. - Accordion
Your example is helpful and I understand it, because countPairs(a,x) and countPairs(a,x + 1) are 1 apart. That doesn't seem to be the case in binary search, where we are jumping more than 1 at a time. Following your example: say k = 100, lo = 0, hi = 18, and we have mid = 9. Since countPairs(a,9) = 99 < 100 we will update lo = mid + 1. Now we have k = 100, lo = 10, hi = 18. The next time we search we are looking at countPairs(a,14), not countPairs(a,10). - Fleabitten
I agree with everything you say, but you need to carry on a few more iterations. We will look at countPairs(a,14) and set high to 14. Then we will look at countPairs(a,12) and set high to 12. Then we will look at countPairs(a,11) and set high to 11. Finally we will look at countPairs(a,10) and set high to 10. The point is that the binary search finds the first place where countPairs(a,x) is at the target, which in your example will be when x is equal to 10. - Accordion
Ah, I see, so lo and hi will eventually be 1 apart. At that point mid will be lo, so we will check countPairs(a,lo) = x. If x is less than k we will return mid + 1, which is hi in the last iteration, since we know countPairs(a,lo) < k and countPairs(a,hi) >= k from the last iteration, right? I think I almost get it, but how do we know that the point where lo and hi meet, which is the final mid + 1 value, is guaranteed to be a distance created by a pair? (I think I got confused as I typed this paragraph; I thought I understood it for a moment.) - Fleabitten
Amazing tutorial. - Spiccato

Just out of interest, we can solve this problem in O(n log n + m log m) time, where m is the range, using a Fast Fourier Transform.

First sort the input. Now consider that each of the attainable distances between numbers can be achieved by subtracting one difference-prefix-sum from another. For example:

input:            1 3 7
diff-prefix-sums:  2 6
difference between 7 and 3 is 6 - 2

Now let's add the total (the rightmost prefix sum) to each side of the equation:

ps[r] - ps[l]       = D
ps[r] + (T - ps[l]) = D + T

For the question's example, sorted as 1 1 3, let's list the adjacent differences:

1 1 3
 0 2

and the prefix sums:

p     => 0 0 2
T - p => 2 2 0  // 2-0, 2-0, 2-2

We need to efficiently determine and order the counts of all the different achievable differences. This is akin to multiplying the polynomial with coefficients [1, 0, 2] by the polynomial with coefficients [2, 0, 0] (we don't need the zero coefficient in the second set, since it only generates degrees less than or equal to T), which we can accomplish in O(m log m) time, where m is the degree, with a Fast Fourier Transform.

The resultant coefficients would be:

  1 0 2
* 
  2 0 0
=> 
  x^2 + 2
*
  2x^2

= 2x^4 + 4x^2

=> 2 0 4 0 0

We discard counts of degrees lower than T, and display our ordered results:

2 * 4 = 2 * (T + 2) => 2 diffs of 2
4 * 2 = 4 * (T + 0) => 4 diffs of 0

We overcounted diffs of 0. Perhaps someone could suggest a convenient way to calculate the zero overcount. I spent some time on it but haven't yet found one.

In any case, the count of zero differences is readily available from the duplicate counts of each distinct value, which still allows us to return the kth difference in O(n log n + m log m) total time.
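A minimal sketch of the counting idea, with one loud caveat: it does not use an FFT. It computes the per-distance pair counts with a direct O(m^2) correlation of the value histogram, which is the same quantity the polynomial multiplication would produce faster. Zero differences are taken from the duplicate counts as c*(c-1)/2 per distinct value. The names DistanceHistogramSketch and kthSmallest are made up for illustration.

```java
class DistanceHistogramSketch {
    // Returns the k-th smallest pair distance (k is 1-based) by building a
    // histogram of values and counting pairs per distance d.
    // NOTE: the inner correlation loop is the naive O(m^2) stand-in for the
    // FFT-based multiplication described in the answer (m = max - min).
    static int kthSmallest(int[] nums, int k) {
        int min = Integer.MAX_VALUE, max = Integer.MIN_VALUE;
        for (int v : nums) {
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        int m = max - min;
        long[] cnt = new long[m + 1];
        for (int v : nums) cnt[v - min]++;

        // Distance 0: pairs of equal values, c choose 2 for each distinct value.
        long seen = 0;
        for (long c : cnt) seen += c * (c - 1) / 2;
        if (seen >= k) return 0;

        // Distance d > 0: sum over v of cnt[v] * cnt[v + d].
        for (int d = 1; d <= m; d++) {
            long pairs = 0;
            for (int v = 0; v + d <= m; v++) pairs += cnt[v] * cnt[v + d];
            seen += pairs;
            if (seen >= k) return d;
        }
        throw new IllegalArgumentException("k exceeds the number of pairs");
    }

    public static void main(String[] args) {
        System.out.println(kthSmallest(new int[] {1, 3, 1}, 1));
        System.out.println(kthSmallest(new int[] {1, 3, 1}, 2));
    }
}
```

For the input 1 3 1 the histogram is [2, 0, 1], giving 1 pair at distance 0 and 2 pairs at distance 2, which matches the pair list in the question.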

Sketchbook answered 12/2, 2018 at 13:16 Comment(0)
