
The k-way Merge problem

Merging k sorted arrays

We can think of the k-way merge problem as a generalization of the 2-way merge problem. That is: given k sorted containers, return a single merged sorted container.

In the 2-way merge problem, we had a binary choice when picking the next value to write to the merged container. Now we have a k-way choice among k or fewer elements at each step (fewer once some containers are exhausted). We need an efficient way to pick the smallest value among all current candidates, and that is exactly what a heap does. Since the containers are sorted in non-decreasing order, a min heap lets us pick the smallest value at each step in O(log k) time.

The crux of the algorithm is:

Algorithm

Initialize a min heap with the first element of every non-empty container.
While the heap is not empty:
   Pop the heap top and write it to the merged container.
   If the popped element has a successor in its own container, push that successor onto the heap.

That's it.

The code should actually be simpler than the 2-way merge, because there are no if/else conditions to handle arrays of unequal length: an exhausted container simply stops contributing candidates to the heap.
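The pseudocode does not depend on the container type; all it needs is a way to get the next element of each input. As a minimal sketch, assuming the inputs are arbitrary sorted Python iterables, we can keep one iterator per input instead of an index. The name kway_merge is hypothetical; in real code, the standard library's heapq.merge already provides a lazy k-way merge.

```python
import heapq
from typing import Any, Iterable, Iterator, List

_EXHAUSTED = object()  # sentinel: safer than None, which could be a real value


def kway_merge(sources: List[Iterable[Any]]) -> Iterator[Any]:
    """Lazily merge k sorted iterables (hypothetical helper, not a library API)."""
    heap = []
    # Seed the heap with the first element of every non-empty source.
    # The source index i sits between value and iterator, so ties on value are
    # broken by index and the iterators themselves are never compared.
    for i, src in enumerate(sources):
        it = iter(src)
        first = next(it, _EXHAUSTED)
        if first is not _EXHAUSTED:
            heap.append((first, i, it))
    heapq.heapify(heap)

    while heap:
        value, i, it = heapq.heappop(heap)
        yield value  # write the smallest current candidate to the output
        nxt = next(it, _EXHAUSTED)  # successor in the same source, if any
        if nxt is not _EXHAUSTED:
            heapq.heappush(heap, (nxt, i, it))


if __name__ == "__main__":
    data = [[1, 2, 3], [5, 6], [9, 9, 10], []]
    print(list(kway_merge(data)))    # [1, 2, 3, 5, 6, 9, 9, 10]
    print(list(heapq.merge(*data)))  # same result from the standard library
```

Advancing an iterator instead of an index means the same loop handles lists, generators, or lines read from sorted files.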

Code
"""
The same problem, but with a different signature:
https://www.geeksforgeeks.org/problems/merge-k-sorted-arrays/1
"""
#User function Template for python3
from typing import List
from heapq import heapify, heappop, heappush
class Solution:
    #Function to merge k sorted arrays.
    def mergeKArrays(self, arr : List[List[int]]) -> List[int]:
        """
        Given k sorted arrays, return a single merged array.
        Input :
         arr : a list of integer arrays
        """
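        #Total number of elements across all arrays (size of the merged result)
        len_res = 0

        #Priority Queue- min heap
        pq = []
        #Arrays can be of unequal length, and some may be empty
        #Initialize min heap
        for array in arr :
            len_res += len(array)
            if not array :
                continue  # an empty array contributes nothing; array[0] would raise IndexError
            #Store the current index and a reference to the array for getting the next value in the array
            pq.append((array[0] , 0 , array))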

        #heapify pq
        heapify(pq)

        #allocate result array
        res = [None] * len_res

        #write pointer
        write_ptr = 0 
        #Pick the lowest value from the current choices and write it to res
        while pq :
            curr_val, curr_ptr, curr_arr = heappop(pq)
            res[write_ptr] = curr_val
            if curr_ptr < len(curr_arr) -1  : 
                heappush(pq, (curr_arr[curr_ptr+1] , curr_ptr+1,curr_arr ))
            write_ptr+=1
        return res

if __name__ == "__main__" :
    print( Solution().mergeKArrays([[1,2,3] , [5,6] , [9,9,10]]) )
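A rough cost estimate for mergeKArrays, writing N for the total number of elements and assuming every array is non-empty (so N >= k): heapify on k entries is linear, and each of the N elements is popped once and pushed at most once on a heap that never holds more than k entries. The O(N) result array is not counted as extra space.

```latex
T(N, k) = \underbrace{O(k)}_{\text{heapify}}
        + \underbrace{N \cdot O(\log k)}_{\text{one pop and at most one push per element}}
        = O(N \log k) \quad (N \ge k),
\qquad S_{\text{extra}}(k) = O(k)
```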

Merging k sorted linked lists

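The same approach works for linked lists. The successor of a popped node is simply node.next, so no position index is needed to advance; the list index is kept only as a tie breaker.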

Code
from typing import List, Optional
from heapq import heapify,heappush,heappop
#Definition for singly-linked list.
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next
class Solution:
    def mergeKLists(self, lists: List[Optional[ListNode]]) -> Optional[ListNode]:
        pq=[]
        for idx, head in enumerate(lists) : 
            if not head : continue
            #Put idx between the value and the node object to avoid comparing nodes in case of a tie,
            #which would raise: TypeError: '<' not supported between instances of 'ListNode' and 'ListNode'
            heappush(pq,(head.val,idx, head)) 

        #Allocate Write container and write pointer
        dummy = ListNode(0)
        write_ptr = dummy

        while pq :
            #Pick the lowest valued node from available choices.
            _,idx,curr_node = heappop(pq)

            #write to write container and move write pointer ahead
            write_ptr.next = curr_node
            write_ptr = write_ptr.next

            #For curr_node, put its next node in the heap  
            nxt_node = curr_node.next 
            if nxt_node : heappush(pq,(nxt_node.val, idx, nxt_node ))
        return dummy.next
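To try the linked-list version outside the online judge, you need a way to build lists and read them back. The sketch below is self-contained: build_list and to_pylist are hypothetical helpers, and merge_k_lists repeats the same heap logic as mergeKLists as a plain function so the block runs on its own.

```python
import heapq
from typing import List, Optional


class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next


def build_list(values: List[int]) -> Optional[ListNode]:
    """Hypothetical helper: turn a Python list into a singly linked list."""
    dummy = ListNode()
    tail = dummy
    for v in values:
        tail.next = ListNode(v)
        tail = tail.next
    return dummy.next


def to_pylist(head: Optional[ListNode]) -> List[int]:
    """Hypothetical helper: read a linked list back into a Python list."""
    out = []
    while head:
        out.append(head.val)
        head = head.next
    return out


def merge_k_lists(lists: List[Optional[ListNode]]) -> Optional[ListNode]:
    # (value, list index, node): the index breaks ties so nodes are never compared.
    pq = [(head.val, idx, head) for idx, head in enumerate(lists) if head]
    heapq.heapify(pq)
    dummy = ListNode()
    tail = dummy
    while pq:
        _, idx, node = heapq.heappop(pq)
        tail.next = node  # splice the smallest node onto the result
        tail = node
        if node.next:
            heapq.heappush(pq, (node.next.val, idx, node.next))
    return dummy.next


if __name__ == "__main__":
    lists = [build_list([1, 4, 5]), build_list([1, 3, 4]), build_list([2, 6])]
    print(to_pylist(merge_k_lists(lists)))  # [1, 1, 2, 3, 4, 4, 5, 6]
```

No new nodes are allocated for the result: the existing nodes are relinked, so the extra space is just the heap.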

Practice

Kth Smallest Element in a Sorted Matrix
"""
Solution for https://leetcode.com/problems/kth-smallest-element-in-a-sorted-matrix/

This can be solved using k-way merge. Instead of writing each popped value to a merged result array,
put it in a max heap that holds at most k values.

Example :
[
[1,5,9],
       *
[10,11,13],
    *
[12,13,15]
 *
]

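Say k = 4. The * marks show each row's next candidate after 1, 5, 9 and 10 have been popped
(row 0 is exhausted).
Can we use a max heap to maintain the k smallest numbers seen so far?
[10,9,5,1]

When the size becomes greater than k, pop from the max heap:

>[11,10,9,5,1]
>size 5 is greater than k = 4
>pop the largest value (11) from the max heap
[10,9,5,1]

Once every value has been processed, the top of the max heap (10) is the 4th smallest number.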

Memory required: a max heap of size k to keep track of the k smallest numbers,
plus a min heap of size len(matrix) (the number of rows) to pick the next smallest number.

Algorithm :

1) Initialize a min heap to pick next smallest number from all rows of matrix
2) While min heap is not empty:
    -  Pop from min heap and put in a max heap of size k
    -  Put element next to popped element into min heap  
3) Return the top of max heap.

"""
from typing import List
from heapq import heapify, heappush, heappop
class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        pq = []
        for row in matrix : 
            e = (row[0], 0, row)  # tuple of Curr_Value, curr_index, row reference
            pq.append(e)
        heapify(pq)

        max_hp = []
        while pq :
            #pick next smallest number
            val, idx, row = heappop(pq)  # heappop, not list.pop(): list.pop() returns the last item, not the smallest
            #Simulate max heap by negating values
            heappush(max_hp,-val)
            if len(max_hp) > k :
                heappop(max_hp)
            if idx < len(row) -1 :
                heappush(pq,(row[idx+1],idx+1,row))
        # Return the top of the max heap (the k-th smallest value)
        return -heappop(max_hp) 
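Because the k-way merge already emits values in non-decreasing order, the max heap is optional: the k-th value popped from the min heap is the answer, so we can stop there. A sketch of that variant, assuming 1 <= k <= total number of elements; the function name kth_smallest is hypothetical. It pops about k values instead of pushing all n*n values through a second heap.

```python
import heapq
from typing import List


def kth_smallest(matrix: List[List[int]], k: int) -> int:
    """Early-stopping k-way merge (hypothetical name); assumes 1 <= k <= total elements."""
    # One heap entry per non-empty row: (value, row index, column index).
    # The row index is unique per entry, so ties never reach the column index.
    pq = [(row[0], r, 0) for r, row in enumerate(matrix) if row]
    heapq.heapify(pq)
    for _ in range(k - 1):
        # Discard the k-1 smallest values, advancing within the popped value's row.
        _, r, c = heapq.heappop(pq)
        if c + 1 < len(matrix[r]):
            heapq.heappush(pq, (matrix[r][c + 1], r, c + 1))
    return pq[0][0]  # the k-th smallest is now at the top of the min heap


if __name__ == "__main__":
    print(kth_smallest([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8))  # 13
```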
Find K Pairs with Smallest Sums
"""
Solution for : https://leetcode.com/problems/find-k-pairs-with-smallest-sums/

The problem statement is confusing. The key information is that all possible pairs are formed by 
taking one number from the first array and combining it with one number from the second array.
        [1,7,11]
        [2,4,6]
 [ (1,2), (1,4), (1,6), (7,2), (7,4), (7,6), (11,2), (11,4), (11,6) ]

The problem is asking: if we sort the above list of tuples by their sum, what are the first k tuples?

Thoughts :

Do we have to enumerate all pairs? Maybe. If we do, we can maintain a max heap of size k.
In the max heap, keep tuples (-(sum of two elements), (el1, el2)); negating the sum turns Python's min heap into a max heap.

We can do this :

from typing import List
from heapq import heappop, heappush, heapify
class Solution:
    def kSmallestPairs(self, nums1: List[int], nums2: List[int], k: int) -> List[List[int]]:
        res = []
        ## Quadratic time: every one of the len(nums1) * len(nums2) pairs is pushed.
        for n1 in nums1 :
            for n2 in nums2 :
                heappush(res,(-1*(n1+n2),(n1,n2)))
                if len(res) > k :
                    heappop(res)

        ans = []
        while res : 
            ans.append(heappop(res)[1])
        return ans

But this enumerates all len(nums1) * len(nums2) pairs, so it has a quadratic run time and will give TLE (Time Limit Exceeded).

Can we do better?
Yes, we can!


Consider :
nums1 = [1,1,2], nums2 = [1,2,3], k = 3

Isn't this how we are enumerating pairs?
1,1 -> 1,2 -> 1,3
*
1,1 -> 1,2 -> 1,3
*
2,1 -> 2,2 -> 2,3
* 

These look an awful lot like three linked lists, one per element of nums1. Because nums2 is sorted, the sums along each list are non-decreasing.
We know how to get the next element for each linked list.
To get the n smallest pairs, we just need to pop from the k-way merge n times.
(Let's call the number of pairs required n to avoid confusion with the concept of k-way merge.
In k-way merge, k refers to the number of linked lists we are merging.)

Let's implement the solution.
"""

from typing import List
from heapq import heappop, heappush

class Solution:
    def kSmallestPairs(self, nums1: List[int], nums2: List[int], k: int) -> List[List[int]]:
        n = k # rename k to n to avoid confusion with our mental model
        if not nums1 or not nums2 :
            return []
        pq = []
        #Init pq 
        #Each linked list node is a pair of two elements.
        #In each pair the first element is fixed and comes from the first array; the second comes from the second array.
        #The index of the first element acts as a tie breaker.
        #The index of the second element lets us get the next node in the same linked list.
        idx_2 = 0
        for idx_1 in range(len(nums1)) :
            # (sum of two elements, index of first element, index of second element)
            heappush(pq, (nums1[idx_1] + nums2[idx_2], idx_1, idx_2))

        res = []

        # Pick the next smallest pair n times
        while pq and n :
            n -= 1
            _, idx_1, idx_2 = heappop(pq) # idx_2 is the index of element2 in the pair
            res.append([nums1[idx_1], nums2[idx_2]])
            #Put the node next to the popped node (same idx_1, next idx_2) in the min heap
            if idx_2 < len(nums2) - 1 :
                heappush(pq, (nums1[idx_1] + nums2[idx_2 + 1], idx_1, idx_2 + 1))
        return res
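The "linked lists" in this problem never need to be materialized. A sketch assuming Python 3.5+ (where heapq.merge accepts a key argument): each element of nums1 becomes a generator of pairs, heapq.merge performs the k-way merge ordered by pair sum, and itertools.islice stops after k pairs. The names k_smallest_pairs and row are hypothetical; this is an alternative spelling of the same idea, not the solution above.

```python
import heapq
from itertools import islice
from typing import Iterator, List, Tuple


def k_smallest_pairs(nums1: List[int], nums2: List[int], k: int) -> List[List[int]]:
    def row(a: int) -> Iterator[Tuple[int, int]]:
        # One "linked list" per element of nums1: (a, nums2[0]), (a, nums2[1]), ...
        # A helper function binds a per row; a bare nested generator expression
        # would late-bind a and give every row the last value of nums1.
        for b in nums2:
            yield (a, b)

    # heapq.merge keeps one pending item per row in a heap and breaks
    # equal keys by row order, much like the explicit solution's idx_1.
    merged = heapq.merge(*(row(a) for a in nums1), key=lambda pair: pair[0] + pair[1])
    return [list(pair) for pair in islice(merged, k)]


if __name__ == "__main__":
    print(k_smallest_pairs([1, 1, 2], [1, 2, 3], 3))  # [[1, 1], [1, 1], [1, 2]]
```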