In computer science, a disjoint-set data structure … is a data structure that stores a collection of disjoint (non-overlapping) sets. … It provides operations for adding new sets, merging sets (replacing them by their union), and finding a representative member of a set.
Source: https://en.wikipedia.org/wiki/Disjoint-set_data_structure
A disjoint set groups distinct elements into a collection of non-overlapping sets. It has two main operations: finding the set that a given element belongs to, and merging two sets into one (Cormen et al., Introduction to Algorithms). This post introduces implementations of the two functions union(u,v) and find(p), and then works through LeetCode 200, Number of Islands, as an example.
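The quoted definition names three operations: adding a new set, merging two sets, and finding a representative. The minimal sketch below shows all three with no optimizations. It assumes hashable items stored in a dict. The class name `DisjointSet` and the method `make_set` are hypothetical and do not come from the post.

```python
class DisjointSet:
    """Minimal disjoint-set over hashable items, with no optimizations yet."""

    def __init__(self):
        self.parent = {}  # parent[x] is x's parent; a root is its own parent

    def make_set(self, x):
        # Add x as a new singleton set (no-op if x already exists).
        if x not in self.parent:
            self.parent[x] = x

    def find(self, x):
        # Follow parent links until reaching the representative (root).
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, u, v):
        # Merge the two sets by attaching one root under the other.
        ru, rv = self.find(u), self.find(v)
        if ru != rv:
            self.parent[rv] = ru


ds = DisjointSet()
for item in "abcde":
    ds.make_set(item)
ds.union("a", "b")
ds.union("c", "d")
ds.union("b", "d")
print(ds.find("c") == ds.find("a"))  # True: a, b, c, d are now one set
print(ds.find("e") == ds.find("a"))  # False: e is still on its own
```

The array-based versions in the rest of the post do the same thing with integer indices. There, every element is created up front, so `make_set` is not needed.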
find(p), union(u,v), and optimization
There are two optimizations, one for each function: path compression in find(p) and merge by rank in union(u,v).
find(p) and path compression
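Given an element p, find(p) returns the representative (root) of the set that p belongs to. We keep an array root where root[i] is the parent of element i. An element whose parent is itself is the root of its set. Starting from p, we can therefore follow parent links recursively or iteratively until we reach that root. In the example below, 0 and 4 are roots, so the array encodes the sets {0, 1, 2, 3} and {4, 5, 6, 7, 8}.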
```python
root = [0, 0, 0, 0, 4, 4, 5, 5, 7]

# recursively
def find(p):
    if root[p] != p:
        return find(root[p])
    return p

# iteratively
def find(p):
    while root[p] != p:
        p = root[p]
    return p
```
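To check which sets the example root array encodes, we can group every element by the root that find returns. This is a small sketch that uses the iterative find without path compression. The helper name `groups` is hypothetical.

```python
from collections import defaultdict

root = [0, 0, 0, 0, 4, 4, 5, 5, 7]

def find(p):
    # Walk up parent links (no path compression) until p is its own parent.
    while root[p] != p:
        p = root[p]
    return p

def groups(n):
    # Hypothetical helper: map each representative to the members of its set.
    sets = defaultdict(list)
    for x in range(n):
        sets[find(x)].append(x)
    return dict(sets)

print(groups(len(root)))  # {0: [0, 1, 2, 3], 4: [4, 5, 6, 7, 8]}
```

Note that find(8) has to follow three links (8 → 7 → 5 → 4). This repeated walking is the cost that path compression removes.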
We can add path compression as an optimization. While searching for the root of p, we point every node along the path directly at the root, so later calls on those nodes reach the root in one step. Again, there are two ways to implement this.
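```python
# recursively
def find(p):
    if root[p] != p:
        root[p] = find(root[p])  # re-point p directly at the root
    return root[p]

# iteratively
def find(p):
    node = p
    while node != root[node]:  # first pass: locate the root
        node = root[node]
    while p != node:           # second pass: point each node on the path at the root
        par = root[p]
        root[p] = node
        p = par
    return p
```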
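Running the recursive version once on the example array shows what path compression changes. This is a short sketch that reuses the same `root` array as the example above.

```python
root = [0, 0, 0, 0, 4, 4, 5, 5, 7]

def find(p):
    # Recursive find with path compression: every node on the path
    # from p to the root is re-pointed directly at the root.
    if root[p] != p:
        root[p] = find(root[p])
    return root[p]

print(find(8))  # 4
print(root)     # [0, 0, 0, 0, 4, 4, 5, 4, 4]
```

Nodes 8 and 7 now point straight at 4. Node 6 still points at 5 because it was not on the path from 8. A later find(6) would compress it too.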
union(u,v) and merge by rank
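Given two elements u and v, union(u,v) merges the sets that u and v belong to into one. If we attach one root under the other without any rule, the trees can degenerate into long chains, which makes find(p) slow. To avoid this, we can add merge by rank as an optimization.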
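We keep an array rank that records the height of the tree rooted at each node. Once path compression is also used, rank becomes an upper bound on that height. When merging two sets, we always attach the root with the lower rank under the root with the higher rank. Only when both ranks are equal does the tree get taller, so in that case the new root's rank is incremented by 1.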
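```python
# assumes root[i] = i and rank[i] = 0 initially for every element i
def union(u, v):
    u_root = find(u)
    v_root = find(v)
    if u_root == v_root:
        # already in the same set; without this check the rank would be bumped for nothing
        return
    if rank[u_root] > rank[v_root]:
        root[v_root] = u_root
    elif rank[u_root] < rank[v_root]:
        root[u_root] = v_root
    else:
        root[v_root] = u_root
        rank[u_root] += 1
```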
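To show why the rank rule matters, the sketch below links 0-1, 1-2, …, 6-7 twice. The first run uses a naive rule that always hangs u's root under v's root. The second run uses merge by rank. Path compression is left out on purpose, so the tree shape reflects linking alone. The function names `build` and `depth` are hypothetical, and this insertion order is simply one that is bad for the naive rule.

```python
def build(n, by_rank):
    # Link 0-1, 1-2, ..., (n-2)-(n-1) and return the parent array.
    # No path compression here, so the tree shape reflects linking alone.
    root = list(range(n))
    rank = [0] * n

    def find(p):
        while root[p] != p:
            p = root[p]
        return p

    for u in range(n - 1):
        v = u + 1
        u_root, v_root = find(u), find(v)
        if u_root == v_root:
            continue
        if not by_rank:
            root[u_root] = v_root  # naive: always hang u's tree under v's
        elif rank[u_root] > rank[v_root]:
            root[v_root] = u_root
        elif rank[u_root] < rank[v_root]:
            root[u_root] = v_root
        else:
            root[v_root] = u_root
            rank[u_root] += 1
    return root

def depth(root, p):
    # Number of parent links from p up to its root.
    d = 0
    while root[p] != p:
        p = root[p]
        d += 1
    return d

n = 8
naive, ranked = build(n, False), build(n, True)
print(max(depth(naive, p) for p in range(n)))   # 7: a single chain
print(max(depth(ranked, p) for p in range(n)))  # 1: every node hangs off the root
```

With the naive rule, find(0) must walk the whole chain. With merge by rank, the first merge makes 0 the higher-rank root, and every later single-element set is attached directly under it.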
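Without path compression and merge by rank, a single find(p) can take O(n) time in the worst case, because a tree can degenerate into a chain of length n. With both optimizations, any sequence of m operations on n elements runs in O(m α(n)) time. Here α is the inverse Ackermann function, so each operation costs O(α(n)) amortized, which is effectively constant for any practical n.

| Operation | With path compression and merge by rank | No optimization |
|---|---|---|
| find(p) | O(α(n)) amortized | O(n) |
| union(u,v) | O(α(n)) amortized | O(n) |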
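As a rough illustration of the worst case, we can count how many parent links are followed when find(0) is called repeatedly on a chain 0 → 1 → … → n-1. The sketch counts only the links followed while locating the root. It shows a single adversarial case and does not prove the amortized bound. The function name `count_hops` and its parameters are hypothetical.

```python
def count_hops(n, compress, repeats=3):
    # Build a chain 0 -> 1 -> ... -> n-1 (n-1 is the root) and call
    # find(0) `repeats` times, counting parent links followed to locate the root.
    root = [i + 1 for i in range(n - 1)] + [n - 1]
    hops = 0

    def find(p):
        nonlocal hops
        node = p
        while node != root[node]:  # locate the root
            node = root[node]
            hops += 1
        if compress:               # optional second pass: path compression
            while p != node:
                parent = root[p]
                root[p] = node
                p = parent
        return node

    for _ in range(repeats):
        find(0)
    return hops

print(count_hops(1000, compress=False))  # 2997 = 3 * 999
print(count_hops(1000, compress=True))   # 1001 = 999 + 1 + 1
```

Without compression, every call pays the full 999 links. With compression, only the first call does, and each later call follows a single link.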
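For LeetCode 200, Number of Islands, we initially count every '1' cell as an isolated island, mapping cell (r, c) to index r*col + c in the root array. We then scan the grid from top to bottom and from left to right. Whenever a '1' cell's right neighbour or the neighbour below it is also '1', we call union(u,v) on the two cells. Checking only those two directions is enough, because the left and upper neighbours were already handled when we visited those cells. Remember to subtract 1 from the island count each time union(u,v) actually merges two different sets.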
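The scan described above can be isolated into a small generator that lists which pairs of flattened indices would be passed to union(u,v). This is a sketch that assumes the same r*col + c indexing. The name `adjacent_land_pairs` is hypothetical.

```python
def adjacent_land_pairs(grid):
    # Hypothetical helper: scan top-to-bottom, left-to-right and yield the
    # flattened indices (r*col + c) of each '1' cell paired with its lower
    # or right neighbour when that neighbour is also '1'.
    row, col = len(grid), len(grid[0])
    for r in range(row):
        for c in range(col):
            if grid[r][c] != "1":
                continue
            if r + 1 < row and grid[r + 1][c] == "1":
                yield r * col + c, (r + 1) * col + c
            if c + 1 < col and grid[r][c + 1] == "1":
                yield r * col + c, r * col + c + 1

grid = [
    ["1", "1", "0"],
    ["0", "1", "0"],
    ["0", "0", "1"],
]
print(list(adjacent_land_pairs(grid)))  # [(0, 1), (1, 4)]
```

This grid has 4 land cells and both pairs join different sets, so the count drops to 4 - 2 = 2 islands. The number of pairs is not always the number of merges. A 2×2 block of '1's yields 4 pairs but needs only 3 merges. This is why union(u,v) must skip pairs that already share a root instead of decrementing the count.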
```python
from typing import List


class Solution:
    def numIslands(self, grid: List[List[str]]) -> int:
        if not grid or not grid[0]:
            return 0
        row, col = len(grid), len(grid[0])
        # cell (r, c) is stored at index r*col + c
        root = [i for i in range(row * col)]
        ranks = [0] * (row * col)
        cnt = 0
        for r in range(row):
            for c in range(col):
                # count each '1' as an isolated island
                if grid[r][c] == '1':
                    cnt += 1

        def find(p):
            # add path compression
            if root[p] != p:
                root[p] = find(root[p])
            return root[p]

        def union(u, v):
            # add merge by rank
            nonlocal cnt
            u_root = find(u)
            v_root = find(v)
            if u_root == v_root:
                return
            if ranks[u_root] > ranks[v_root]:
                root[v_root] = u_root
            elif ranks[u_root] < ranks[v_root]:
                root[u_root] = v_root
            else:
                root[v_root] = u_root
                ranks[u_root] += 1
            # remember to deduct 1 from the total number of islands
            cnt -= 1

        for r in range(row):
            for c in range(col):
                if grid[r][c] == '0':
                    continue
                # union connected '1's
                if r + 1 < row and grid[r + 1][c] == '1':
                    union(r * col + c, (r + 1) * col + c)
                if c + 1 < col and grid[r][c + 1] == '1':
                    union(r * col + c, r * col + c + 1)
        return cnt
```
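One way to sanity-check a union-find answer is to compare it against a different technique. The sketch below uses a BFS flood fill, which is not the approach this post describes. Each unvisited '1' starts a new island, and BFS marks the rest of that island as seen. The function name `num_islands_bfs` and the two test grids, which are modeled on the problem's sample inputs, are assumptions for illustration.

```python
from collections import deque
from typing import List

def num_islands_bfs(grid: List[List[str]]) -> int:
    # Cross-check with a different technique: BFS flood fill.
    # Each unvisited '1' starts a new island; BFS marks its whole component.
    if not grid or not grid[0]:
        return 0
    row, col = len(grid), len(grid[0])
    seen = [[False] * col for _ in range(row)]
    count = 0
    for r in range(row):
        for c in range(col):
            if grid[r][c] != "1" or seen[r][c]:
                continue
            count += 1
            seen[r][c] = True
            queue = deque([(r, c)])
            while queue:
                x, y = queue.popleft()
                for nx, ny in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
                    if 0 <= nx < row and 0 <= ny < col and grid[nx][ny] == "1" and not seen[nx][ny]:
                        seen[nx][ny] = True
                        queue.append((nx, ny))
    return count

example1 = [
    ["1", "1", "1", "1", "0"],
    ["1", "1", "0", "1", "0"],
    ["1", "1", "0", "0", "0"],
    ["0", "0", "0", "0", "0"],
]
example2 = [
    ["1", "1", "0", "0", "0"],
    ["1", "1", "0", "0", "0"],
    ["0", "0", "1", "0", "0"],
    ["0", "0", "0", "1", "1"],
]
print(num_islands_bfs(example1), num_islands_bfs(example2))  # 1 3
```

Both approaches visit each cell a constant number of times, apart from the near-constant α(n) factor in union-find. Union-find is convenient when land cells arrive incrementally, while BFS is simpler for a single static grid.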