Top K problem – Sort, Heap, and Quickselect

The Top K problem covers questions that ask for the Kth largest/smallest element, or the top K largest/smallest items, of an unsorted array. Sorting first and returning the Kth item works, but it costs O(nlogn) time, or O(n^2) with simple sorting algorithms. A size-K heap brings this down to O(nlogk), and quickselect achieves O(n) on average.

Let’s take Leetcode 215. Kth Largest Element in an Array as an example.

Find the kth largest element in an unsorted array. Note that it is the kth largest element in the sorted order, not the kth distinct element.
Example 1:
Input: [3,2,3,1,2,4,5,5,6] and k = 4
Output: 4
Note: 
You may assume k is always valid, 1 ≤ k ≤ array's length.

Note: This post assumes familiarity with heap operations and Quicksort. Detailed info about them is available here: Visualized + Python! Intro to Sorting Algorithms – Selections Sort, Insertion Sort, Bucket Sort, Bubble Sort, Mergesort, Quicksort, Heapsort and their complexity analysis.

Sort – O(nlogn)

Any sorting algorithm can be applied before returning the Kth item. Efficient comparison sorts such as Mergesort or Heapsort take O(nlogn), while simpler ones such as Insertion Sort take O(n^2).

Note: bucket sort can reach O(n) when the values fall within a small, known range.

        return sorted(nums, reverse=True)[k - 1]
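The bucket sort note can be made concrete with a counting variant that uses one bucket per integer value. This is a minimal sketch, assuming the inputs are integers inside a known range; the bounds `lo` and `hi` and the name `kth_largest_counting` are hypothetical choices for illustration. It runs in O(n + range) time and uses O(range) extra space, so it only pays off when the range is small.

```python
from typing import List


def kth_largest_counting(nums: List[int], k: int, lo: int = -10**4, hi: int = 10**4) -> int:
    """Kth largest via counting: one bucket per integer value in [lo, hi].

    Assumes (hypothetically) every item is an int with lo <= item <= hi
    and that 1 <= k <= len(nums).
    """
    counts = [0] * (hi - lo + 1)
    for num in nums:
        counts[num - lo] += 1          # bucket index is the value shifted by lo
    # Walk the buckets from the largest value down, consuming k items.
    remaining = k
    for offset in range(hi - lo, -1, -1):
        remaining -= counts[offset]
        if remaining <= 0:
            return offset + lo
    raise ValueError("k is larger than len(nums)")


print(kth_largest_counting([3, 2, 3, 1, 2, 4, 5, 5, 6], 4))  # prints 4
```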

Heap – O(nlogk)

We can maintain a Min-heap of size K whose root, heap[0], is the smallest of the K largest items seen so far; once every item has been processed, heap[0] is the answer. More specifically, we start with an empty heap and insert() items until its size reaches K. For each remaining item, we ignore it if it is smaller than or equal to heap[0], since it cannot be among the top K largest elements. If it is larger than heap[0], we overwrite heap[0] with it and call heapify(0) to sift it down and restore the Min-heap property.

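The heap never holds more than K items, so its height is O(logk) and each insert() or heapify(0) costs O(logk). There are at most n such operations (K inserts plus at most n-K replacements), so the total time complexity is O(nlogk) with O(k) extra space.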

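from typing import List

class Solution:
    def findKthLargest(self, nums: List[int], k: int) -> int:
        heap = []  # size-K min-heap; heap[0] is the smallest of the K largest seen so far

        def heapify(i):
            # Sift heap[i] down until neither child is smaller than it.
            c1 = 2 * i + 1
            c2 = 2 * i + 2
            smallest = i
            if c1 < len(heap) and heap[c1] < heap[smallest]:
                smallest = c1
            if c2 < len(heap) and heap[c2] < heap[smallest]:
                smallest = c2
            if smallest != i:
                heap[i], heap[smallest] = heap[smallest], heap[i]
                heapify(smallest)

        def insert(num):
            # Append num at the end and sift it up while its parent is larger.
            heap.append(num)
            cur = len(heap) - 1
            par = (cur - 1) // 2
            while par >= 0 and heap[par] > heap[cur]:
                heap[cur], heap[par] = heap[par], heap[cur]
                cur = par
                par = (cur - 1) // 2

        for num in nums:
            if len(heap) < k:      # before size == K: just insert
                insert(num)
            elif num > heap[0]:    # after size == K: num beats the current Kth largest
                heap[0] = num      # replace the root...
                heapify(0)         # ...and restore the min-heap property
        return heap[0]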
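For comparison, Python's standard `heapq` module already maintains a binary min-heap on a plain list, so the same size-K idea can be sketched more compactly. `heapq.heapreplace(heap, item)` pops the smallest item and pushes `item` in one step, which plays the role of `heap[0] = num` followed by `heapify(0)` above. The function name `kth_largest_heap` is hypothetical.

```python
import heapq
from typing import List


def kth_largest_heap(nums: List[int], k: int) -> int:
    """Return the kth largest item using a size-k min-heap (assumes 1 <= k <= len(nums))."""
    heap: List[int] = []
    for num in nums:
        if len(heap) < k:
            heapq.heappush(heap, num)       # fill the heap until it holds k items
        elif num > heap[0]:
            heapq.heapreplace(heap, num)    # drop the current kth largest, keep num
    return heap[0]  # smallest of the k largest items is the kth largest


print(kth_largest_heap([3, 2, 3, 1, 2, 4, 5, 5, 6], 4))  # prints 4
```

If brevity matters more than showing the mechanics, `heapq.nlargest(k, nums)[-1]` returns the same value.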

Quickselect – O(n)

Quickselect utilises partition() from Quicksort. As described in Leetcode:

One chooses a pivot and defines its position in a sorted array in a linear time using so-called partition algorithm.

However, instead of recursing into both (left, p-1) and (p+1, right) after partition(left, right) returns the pivot's final index p, quickselect compares p with the target index. The partition() below arranges items in descending order, so the Kth largest item belongs at index k-1: if k-1 < p, it must be in (left, p-1); if k-1 > p, it must be in (p+1, right); and if k-1 == p, nums[p] is the answer!

Quicksort recurses into both sides and takes O(nlogn) on average. Quickselect instead discards one side after every partition, roughly half of the array in the average case, which leads to O(n) time on average. It can still degrade to O(n^2) in the worst case when the pivots are consistently bad, for example when the first element is always chosen as the pivot and the input is already sorted. (Wikipedia)
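To see why discarding one side yields linear time, assume (as an idealization) that every partition() halves the range it scans. The total work is then a geometric series, while consistently bad pivots that remove only one item per round give an arithmetic series:

```latex
% Idealized average case: each partition() halves the remaining range
T_{\text{avg}}(n) \approx n + \frac{n}{2} + \frac{n}{4} + \cdots
  < n \sum_{i=0}^{\infty} \frac{1}{2^{i}} = 2n = O(n)

% Worst case: each partition() removes only the pivot
T_{\text{worst}}(n) = n + (n-1) + \cdots + 1 = \frac{n(n+1)}{2} = O(n^2)
```

Quicksort cannot make this saving because it recurses into both halves, so each of its roughly log n levels of recursion still scans O(n) items in total.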

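from typing import List

class Solution:
    def findKthLargest(self, nums: List[int], k: int) -> int:
        left, right = 0, len(nums) - 1

        def partition(l, r):
            # Use nums[l] as the pivot and partition nums[l..r] in descending order:
            # items >= pivot move to its left, items <= pivot to its right.
            pivot = l
            while l < r:
                while l <= r and nums[l] >= nums[pivot]:
                    l += 1
                while l <= r and nums[r] <= nums[pivot]:
                    r -= 1
                if l < r:
                    nums[l], nums[r] = nums[r], nums[l]
            # Every item up to index r is >= the pivot and every item after r is <= it,
            # so r is the pivot's position in descending sorted order.
            nums[r], nums[pivot] = nums[pivot], nums[r]
            return r

        while True:
            p = partition(left, right)
            if p == k - 1:        # the pivot is exactly the Kth largest
                return nums[p]
            if p > k - 1:         # the answer lies left of the pivot
                right = p - 1
            else:                 # the answer lies right of the pivot
                left = p + 1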
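Because the code above always uses the leftmost element as the pivot, an already sorted input triggers the O(n^2) worst case. A common remedy, not part of the original solution, is to pick the pivot at random, which makes the expected running time O(n) on every input. The sketch below is a hedged variant: `kth_largest_quickselect` is a hypothetical name, and it swaps in a three-way (Dutch national flag) partition instead of the two-pointer partition() above, so arrays with many equal values do not degrade either.

```python
import random
from typing import List


def kth_largest_quickselect(nums: List[int], k: int) -> int:
    """Randomized quickselect for the kth largest item (assumes 1 <= k <= len(nums)).

    Works on a copy, so the caller's list keeps its original order.
    """
    a = list(nums)
    target = k - 1                  # index of the answer in descending sorted order
    left, right = 0, len(a) - 1
    while True:
        pivot = a[random.randint(left, right)]   # random pivot value from a[left..right]
        # Three-way partition of a[left..right] in descending order:
        #   a[left..lt-1] > pivot, a[lt..gt] == pivot, a[gt+1..right] < pivot
        lt, i, gt = left, left, right
        while i <= gt:
            if a[i] > pivot:
                a[lt], a[i] = a[i], a[lt]
                lt += 1
                i += 1
            elif a[i] < pivot:
                a[i], a[gt] = a[gt], a[i]
                gt -= 1
            else:
                i += 1
        if target < lt:             # answer is among the items larger than pivot
            right = lt - 1
        elif target > gt:           # answer is among the items smaller than pivot
            left = gt + 1
        else:                       # target falls inside the block equal to pivot
            return pivot


print(kth_largest_quickselect([3, 2, 3, 1, 2, 4, 5, 5, 6], 4))  # prints 4
```

The random pivot changes only which path the search takes, never the returned value, so the function is deterministic from the caller's point of view.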

Related Leetcode practices

347. Top K Frequent Elements
692. Top K Frequent Words
1471. The k Strongest Values in an Array

Feel free to leave a comment here :p
