# Disjoint Set – Union & Find

In computer science, a disjoint-set data structure … is a data structure that stores a collection of disjoint (non-overlapping) sets. … It provides operations for adding new sets, merging sets (replacing them by their union), and finding a representative member of a set.

Source: https://en.wikipedia.org/wiki/Disjoint-set_data_structure

A disjoint set groups distinct elements into a collection of non-overlapping sets. It supports two main operations: finding the set that a given element belongs to, and merging two sets into one (Cormen et al., Introduction to Algorithms). This post introduces implementations of these two functions, find(p) and union(u,v), and then walks through LeetCode 200. Number of Islands as an example.

## find(p), union(u,v), and optimization

There are two optimizations, one for each function: path compression for find(p) and merge by rank for union(u,v).

#### find(p) and path compression

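Given an element p, find(p) returns the representative (root) of the set that p belongs to. We keep an array root, where root[i] is the parent of element i; an element that points to itself (root[i] == i) is the root of its set. Initially every element is its own root, i.e. root[i] = i. To find the root of p, we follow these parent links, recursively or iteratively, until we reach an element that points to itself. The example array below encodes two sets: {0, 1, 2, 3} with root 0 and {4, 5, 6, 7, 8} with root 4.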

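```python
# root[i] holds the parent of element i; an element that is its own parent is a root
root = [0, 0, 0, 0, 4, 4, 5, 5, 7]

# recursively
def find(p):
    if root[p] != p:
        return find(root[p])
    return p

# iteratively
def find(p):
    while root[p] != p:
        p = root[p]
    return p
```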
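
To see what find(p) does on the example array, here is a small sketch that also records the elements visited on the way up. The helper name find_with_path is hypothetical and only added for illustration; it is the iterative find from above with a list appended to.

```python
# Example parent array from above: root[i] is the parent of element i.
root = [0, 0, 0, 0, 4, 4, 5, 5, 7]

def find_with_path(p):
    """Hypothetical helper: iterative find that also returns the visited path."""
    path = [p]
    while root[p] != p:
        p = root[p]
        path.append(p)
    return p, path

rep, path = find_with_path(8)
print(rep, path)  # 4 [8, 7, 5, 4]
```

Element 8 sits three links away from its representative 4, and without any optimization every call to find(8) repeats that same walk.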


We can add path compression as an optimization: while searching for the root of p, we also point every element along the path directly at that root, so later calls on those elements reach the root in a single step. Again, there are two ways of implementing this.

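```python
# recursively
def find(p):
    if root[p] != p:
        root[p] = find(root[p])  # point p directly at the root
    return root[p]

# iteratively
def find(p):
    node = p
    # first pass: walk up to the root
    while node != root[node]:
        node = root[node]
    # second pass: re-point every element on the path at the root
    while p != node:
        par = root[p]
        root[p] = node
        p = par
    return node
```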
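
To see the effect on the example array, the sketch below resets root to the values used earlier and runs the recursive version once. After find(8), the elements 8, 7 and 5, which lay on the path to the root, all point directly at 4, while element 6, which was not on the path, still points at 5 until it is looked up itself.

```python
root = [0, 0, 0, 0, 4, 4, 5, 5, 7]

def find(p):
    # Recursive find with path compression (same as above).
    if root[p] != p:
        root[p] = find(root[p])
    return root[p]

print(find(8))  # 4
print(root)     # [0, 0, 0, 0, 4, 4, 5, 4, 4]
```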


#### union(u,v) and merge by rank

Given two elements u and v, union(u,v) merges the set containing u and the set containing v into one. If we attach one root under the other without any rule, repeated unions can stack the trees into a long chain. To avoid the case shown below, we add merge by rank as an optimization. We prefer (a) over (b) because (b) leads to a much deeper tree, so finding the root takes more time.

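We keep an array rank, where rank[i] is an upper bound on the height of the tree rooted at i (it equals the exact height when path compression is not used). When merging two sets, we always put the root with the lower rank under the root with the higher rank, so the rank of the merged tree stays the same. Only when the two ranks are equal does the tree grow, and then the new root's rank is increased by one.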

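```python
def union(u, v):
    u_root = find(u)
    v_root = find(v)
    if u_root == v_root:
        return  # already in the same set; merging would wrongly bump the rank
    if rank[u_root] > rank[v_root]:
        root[v_root] = u_root
    elif rank[u_root] < rank[v_root]:
        root[u_root] = v_root
    else:
        # equal ranks: either root can go on top, and the merged tree grows by one
        root[v_root] = u_root
        rank[u_root] += 1
```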
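
Putting the pieces together, here is a minimal sketch that bundles both optimizations into one class. The class name DisjointSet and the boolean return value of union are choices made here for illustration, not part of the code above; returning whether a merge actually happened is handy when, as in the island problem below, successful merges need to be counted.

```python
class DisjointSet:
    """Sketch of a disjoint set with path compression and merge by rank."""

    def __init__(self, n):
        self.root = list(range(n))  # every element starts as its own root
        self.rank = [0] * n         # upper bound on each tree's height

    def find(self, p):
        # Recursive find with path compression.
        if self.root[p] != p:
            self.root[p] = self.find(self.root[p])
        return self.root[p]

    def union(self, u, v):
        """Merge the sets of u and v; return False if they were already merged."""
        u_root, v_root = self.find(u), self.find(v)
        if u_root == v_root:
            return False
        if self.rank[u_root] < self.rank[v_root]:
            u_root, v_root = v_root, u_root  # make u_root the higher-rank root
        self.root[v_root] = u_root
        if self.rank[u_root] == self.rank[v_root]:
            self.rank[u_root] += 1
        return True


ds = DisjointSet(5)
ds.union(0, 1)
ds.union(3, 4)
print(ds.find(1) == ds.find(0))  # True
print(ds.find(2) == ds.find(3))  # False
print(ds.union(1, 0))            # False, 0 and 1 are already in one set
```

Swapping the two roots so that u_root always has the higher rank folds the three cases of union into one attachment step, which behaves the same as the if/elif/else version above.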


## Complexities

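Without path compression and merge by rank, the time complexity for find(p) could be O(n), because the tree may degenerate into a chain, and union(u,v), which calls find twice, inherits the same O(n) bound. Merge by rank alone keeps every tree's height at most ⌊log₂ n⌋, so both operations drop to O(log n). With both optimizations, any sequence of m operations on n elements runs in O(m α(n)) time, where α is the inverse Ackermann function; α(n) is at most 4 for any input size that occurs in practice, so each operation is effectively constant time in amortized terms.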
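
A quick way to see the O(n) worst case is to measure how deep the trees get. The sketch below uses a hypothetical helper, max_depth_after_unions, written only for illustration: it performs the unions 0-1, 1-2, ..., (n-2)-(n-1) with no path compression, either with naive linking that always hangs the first root under the second, or with merge by rank, and then reports the depth of the deepest element.

```python
def max_depth_after_unions(n, use_rank):
    """Hypothetical helper: chain unions without path compression, return max depth."""
    root = list(range(n))
    rank = [0] * n

    def find(p):
        # Plain find with no path compression, so tree depths stay visible.
        while root[p] != p:
            p = root[p]
        return p

    def depth(p):
        # Number of parent links from p up to its root.
        d = 0
        while root[p] != p:
            p = root[p]
            d += 1
        return d

    for i in range(n - 1):
        u_root, v_root = find(i), find(i + 1)
        if u_root == v_root:
            continue
        if not use_rank:
            root[u_root] = v_root  # naive: always hang u's tree under v's root
        elif rank[u_root] > rank[v_root]:
            root[v_root] = u_root
        elif rank[u_root] < rank[v_root]:
            root[u_root] = v_root
        else:
            root[v_root] = u_root
            rank[u_root] += 1
    return max(depth(p) for p in range(n))

print(max_depth_after_unions(8, use_rank=False))  # 7
print(max_depth_after_unions(8, use_rank=True))   # 1
```

This particular order is the best case for merge by rank, since every new element is attached directly to the same root; other orders can produce taller trees under merge by rank, but their height still grows only logarithmically in n, while naive linking here builds a chain of length n - 1.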

## LeetCode 200. Number of Islands

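Initially, we treat every '1' cell as an isolated island. We then iterate over the grid from top to bottom and from left to right; whenever a '1' cell's right neighbour or the neighbour below it is also '1', we call union(u,v) on the two cells. Checking only these two directions is enough, because every pair of adjacent cells is then examined exactly once. Remember to deduct 1 from the total number of islands each time union actually merges two different sets; if the two cells already share a root, the count stays the same.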
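
The solution flattens each cell (r, c) into the single index r*col + c, so that root and ranks can be plain lists. As a quick check of the scanning order, the hypothetical helper union_pairs below (not part of the solution) lists the pairs of flattened indices that the scan would pass to union, without merging anything. In a 2x2 grid of '1's, each of the four adjacent pairs appears exactly once.

```python
def union_pairs(grid):
    """Hypothetical helper: list the (u, v) index pairs the scan passes to union."""
    row, col = len(grid), len(grid[0])
    pairs = []
    for r in range(row):
        for c in range(col):
            if grid[r][c] == '0':
                continue
            u = r * col + c  # flatten (r, c) into a single index
            if r + 1 < row and grid[r + 1][c] == '1':
                pairs.append((u, (r + 1) * col + c))  # neighbour below
            if c + 1 < col and grid[r][c + 1] == '1':
                pairs.append((u, r * col + c + 1))    # right neighbour
    return pairs

print(union_pairs([["1", "1"],
                   ["1", "1"]]))  # [(0, 2), (0, 1), (1, 3), (2, 3)]
```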

```python
from typing import List

class Solution:
    def numIslands(self, grid: List[List[str]]) -> int:
        if not grid or not grid[0]:
            return 0
        row, col = len(grid), len(grid[0])
        root = [i for i in range(row * col)]  # cell (r, c) is stored at index r*col + c
        ranks = [0] * (row * col)
        cnt = 0
        for r in range(row):
            for c in range(col):
                # count each '1' as an isolated island
                if grid[r][c] == '1':
                    cnt += 1

        def find(p):
            if root[p] != p:
                root[p] = find(root[p])  # path compression
            return root[p]

        def union(u, v):
            nonlocal cnt
            u_root = find(u)
            v_root = find(v)
            if u_root == v_root:
                return  # already part of the same island
            # merge by rank
            if ranks[u_root] > ranks[v_root]:
                root[v_root] = u_root
            elif ranks[u_root] < ranks[v_root]:
                root[u_root] = v_root
            else:
                root[v_root] = u_root
                ranks[u_root] += 1
            # remember to deduct 1 from the total number of islands
            cnt -= 1

        for r in range(row):
            for c in range(col):
                if grid[r][c] == '0':
                    continue
                # union connected '1's: the neighbour below and the right neighbour
                if r + 1 < row and grid[r + 1][c] == '1':
                    union(r * col + c, (r + 1) * col + c)
                if c + 1 < col and grid[r][c + 1] == '1':
                    union(r * col + c, r * col + c + 1)
        return cnt
```


## Reference 