Disjoint Set – Union & Find

In computer science, a disjoint-set data structure … is a data structure that stores a collection of disjoint (non-overlapping) sets. … It provides operations for adding new sets, merging sets (replacing them by their union), and finding a representative member of a set.

Source: https://en.wikipedia.org/wiki/Disjoint-set_data_structure

Disjoint Set groups distinct elements into a collection of disjoint sets. It supports two main operations: finding the set that a given element belongs to, and merging two sets into one (Cormen, Thomas H., et al. Introduction to Algorithms). This post introduces the implementations of the two functions find(p) and union(u,v), then walks through Leetcode 200. Number of Islands as an example.

find(p), union(u,v), and optimization

Each function has one standard optimization: path compression for find(p) and merge by rank for union(u,v).

find(p) and path compression

Two disjoint sets.

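Given an element p, find(p) returns the representative of the set that p belongs to. We keep an array root in which root[p] stores the parent of p; a representative is its own parent, so root[p] == p. Initially every element is its own representative. To find the representative of p, we follow root from parent to parent, either recursively or iteratively, until we reach an element that points to itself.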

root = [0, 0, 0, 0, 4, 4, 5, 5, 7]
# recursively
def find(p):
    if root[p] != p:
        return find(root[p])
    return p
# iteratively
def find(p):
    while root[p] != p:
        p = root[p]
    return p
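To make the lookup concrete, here is a small sketch that records the chain of parents find(p) walks through on the example array. The helper name path_to_root is hypothetical and only used for illustration.

```python
root = [0, 0, 0, 0, 4, 4, 5, 5, 7]

def path_to_root(p):
    """Return every element visited while walking from p up to its root."""
    path = [p]
    while root[p] != p:
        p = root[p]
        path.append(p)
    return path

print(path_to_root(8))  # [8, 7, 5, 4]
print(path_to_root(3))  # [3, 0]
```

Element 8 needs three hops to reach its representative 4, which is the cost path compression is meant to reduce.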

We can add path compression as an optimization. While searching for the root of p, we make every element along the path point directly to the root, so later lookups on those elements take a single step. This can also be implemented recursively or iteratively.

Path compression.
# recursively
def find(p):
    if root[p]!=p:
        root[p] = find(root[p])
    return root[p]
# iteratively
def find(p):
    node = p
    while node!=root[node]:
        node = root[node]
    while p!=node:
        par = root[p]
        root[p] = node
        p = par
    return p
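As a quick check of what path compression does to the example array, the sketch below runs the recursive version once and prints root before and after. The variable names before and rep are assumptions made for this illustration.

```python
root = [0, 0, 0, 0, 4, 4, 5, 5, 7]

def find(p):
    # Recursive find with path compression.
    if root[p] != p:
        root[p] = find(root[p])
    return root[p]

before = list(root)
rep = find(8)
print(rep)     # 4
print(before)  # [0, 0, 0, 0, 4, 4, 5, 5, 7]
print(root)    # [0, 0, 0, 0, 4, 4, 5, 4, 4]
```

Elements 8 and 7 now point straight at 4. Element 6 still points at 5 because it was not on the path from 8 to the root.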

union(u,v) and merge by rank

Given two elements u and v, union(u,v) merges the set containing u and the set containing v into one. To avoid the case shown below, we can add merge by rank as an optimization.

We prefer (a) over (b) because (b) would lead to a very deep tree, consuming more time to find the root.

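We keep an array rank in which rank[r] is an upper bound on the height of the tree rooted at r (once path compression is used, it is no longer the exact height). When we merge two sets, we always put the root with the lower rank under the root with the higher rank. The rank only grows, by 1, when the two roots have equal rank.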

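def union(u,v):
    u_root = find(u)
    v_root = find(v)
    # already in the same set: nothing to merge, and rank must not change
    if u_root == v_root:
        return
    if rank[u_root]>rank[v_root]:
        root[v_root] = u_root
    elif rank[u_root]<rank[v_root]:
        root[u_root] = v_root
    else:
        # equal ranks: pick u_root as the new root and bump its rank
        root[v_root] = u_root
        rank[u_root] += 1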
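The following is a minimal, self-contained sketch of how root and rank might be initialized and how a few merges play out. The size n = 6 and the boolean return value of union are assumptions added for illustration; they are not part of the functions above.

```python
n = 6
root = list(range(n))  # every element starts as its own root
rank = [0] * n         # a single-node tree has rank 0

def find(p):
    if root[p] != p:
        root[p] = find(root[p])
    return root[p]

def union(u, v):
    u_root, v_root = find(u), find(v)
    if u_root == v_root:
        return False  # already in the same set, nothing to merge
    if rank[u_root] < rank[v_root]:
        u_root, v_root = v_root, u_root  # make u_root the higher-ranked root
    root[v_root] = u_root
    if rank[u_root] == rank[v_root]:
        rank[u_root] += 1
    return True

union(0, 1)
union(2, 3)
union(0, 2)
union(4, 0)  # the single node 4 is placed under the taller tree rooted at 0

print(find(3) == find(4))  # True
print(find(5) == find(0))  # False
print(rank[0])             # 2
```

In the last merge, the single node 4 goes under root 0 rather than the other way round, so the tree does not get any deeper.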

Complexities

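Without path compression and merge by rank, find(p) can take O(n) time in the worst case, and so can union(u,v), since it calls find(p) twice. With both optimizations, the amortized cost per operation is O(α(n)), where α is the inverse Ackermann function, which is effectively constant for any practical n.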

              With path compression and merge by rank    No optimization
union(u,v)    Nearly O(1)                                O(n)
find(p)       Nearly O(1)                                O(n)

Time complexities.
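To see where the O(n) comes from, the sketch below merges n elements in a deliberately bad order and measures the deepest node, with and without merge by rank. The merge order, the helper names build and depth, and n = 1000 are all assumptions chosen for illustration. Path compression is left out on purpose so the tree shape stays visible.

```python
def build(n, use_rank):
    """Union elements 0..n-1 into one set and return the parent array."""
    root = list(range(n))
    rank = [0] * n

    def find(p):  # no path compression, so the tree shape is preserved
        while root[p] != p:
            p = root[p]
        return p

    for i in range(1, n):
        a, b = find(0), find(i)
        if use_rank and rank[a] > rank[b]:
            root[b] = a  # lower rank goes under higher rank
        else:
            root[a] = b  # naive choice: old root goes under the new element
            if use_rank and rank[a] == rank[b]:
                rank[b] += 1
    return root

def depth(root, p):
    """Number of hops from p to its root."""
    hops = 0
    while root[p] != p:
        p = root[p]
        hops += 1
    return hops

n = 1000
naive = build(n, use_rank=False)
ranked = build(n, use_rank=True)
print(max(depth(naive, p) for p in range(n)))   # 999
print(max(depth(ranked, p) for p in range(n)))  # 1
```

With the naive merge, the tree degenerates into a chain of length n - 1. With merge by rank, the same sequence of merges produces a tree of depth 1.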

Leetcode 200. Number of Islands

Initially, we treat every '1' cell as an isolated island. We then scan the grid from top to bottom and from left to right; whenever the right neighbour or the neighbour below a '1' cell is also '1', we call union(u,v) on the two cells, where the cell at row r and column c is mapped to the index r*col+c. Each time two different sets are merged, we deduct 1 from the total number of islands.

from typing import List

class Solution:
    def numIslands(self, grid: List[List[str]]) -> int:
        if not grid or not grid[0]: return 0
        row,col = len(grid),len(grid[0])
        root = [i for i in range(row*col)]
        ranks = [0]*(row*col)
        cnt = 0
        for r in range(row):
            for c in range(col):
                # count each '1' as an isolated island
                if grid[r][c] == '1':
                    cnt += 1
        def find(p):
            # add path compression
            if root[p]!=p:
                root[p] = find(root[p])
            return root[p]
        
        def union(u,v):
            # add merge by rank
            nonlocal cnt
            u_root = find(u)
            v_root = find(v)
            if u_root == v_root: return
            if ranks[u_root] > ranks[v_root]: root[v_root] = u_root
            elif ranks[u_root] < ranks[v_root]: root[u_root] = v_root
            else:
                root[v_root] = u_root
                ranks[u_root] += 1
            # remember to deduct 1 from the total number of islands
            cnt -= 1
            
        for r in range(row):
            for c in range(col):
                if grid[r][c] == '0': continue
                # union connected '1's (only look down and right)
                if r+1<row and grid[r+1][c] == '1': union(r*col+c,(r+1)*col+c)
                if c+1<col and grid[r][c+1] == '1': union(r*col+c,r*col+c+1)
        return cnt
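For a quick local run outside Leetcode, here is a compact standalone sketch of the same idea. It is a variant, not the solution above: the function name num_islands is hypothetical, find uses path halving (an iterative relative of path compression), and merge by rank is omitted for brevity. The sample grid is an assumption chosen for illustration.

```python
from typing import List

def num_islands(grid: List[List[str]]) -> int:
    if not grid or not grid[0]:
        return 0
    rows, cols = len(grid), len(grid[0])
    root = list(range(rows * cols))
    # Start with one island per land cell; each successful merge removes one.
    count = sum(cell == '1' for line in grid for cell in line)

    def find(p):
        # Iterative find with path halving: each visited node is re-pointed
        # to its grandparent.
        while root[p] != p:
            root[p] = root[root[p]]
            p = root[p]
        return p

    for r in range(rows):
        for c in range(cols):
            if grid[r][c] != '1':
                continue
            # Only look down and right, so each adjacent pair is visited once.
            for nr, nc in ((r + 1, c), (r, c + 1)):
                if nr < rows and nc < cols and grid[nr][nc] == '1':
                    a, b = find(r * cols + c), find(nr * cols + nc)
                    if a != b:
                        root[a] = b
                        count -= 1
    return count

grid = [
    list("11000"),
    list("11000"),
    list("00100"),
    list("00011"),
]
print(num_islands(grid))  # 3
```

The grid has seven land cells and four successful merges (three inside the 2x2 block, one in the bottom row), which leaves three islands.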

Reference

  • https://www.youtube.com/watch?v=VJnUwsE4fWA
  • https://en.wikipedia.org/wiki/Disjoint-set_data_structure