Given a positive integer n, find the least number of perfect square numbers (for example, 1, 4, 9, 16, ...) which sum to n.
Example 1:
Input: n = 12
Output: 3
Explanation: 12 = 4 + 4 + 4.
This problem can be solved in three ways: Dynamic Programming, BFS, and Math. Dynamic Programming is less efficient than BFS here, since it also evaluates decompositions that use more squares than necessary, whereas BFS stops at the first level that reaches a square. Moreover, according to Lagrange’s four-square theorem, every positive integer is the sum of at most four squares, so BFS finds the answer within at most four levels. Let’s take a look at each approach now.
Dynamic programming
Bottom up
Initialise an array dp[] of length n+1 and set dp[0] = 0. Iterating i from 1 to n, dp[i] is the least number of perfect square numbers that sum to i. So we can have the transition function as:
dp[i] = min(dp[i-square] + 1) over all square numbers with i-square >= 0
class Solution:
    def numSquares(self, n: int) -> int:
        # dp[i] = least number of perfect squares summing to i
        dp = [float('inf')] * (n+1)
        dp[0] = 0
        for i in range(1,n+1):
            # try every square j**2 that fits into i
            for j in range(1,int(i**0.5)+1):
                if i-j**2 >= 0:
                    dp[i] = min(dp[i],dp[i-j**2]+1)
                else:
                    break
        return dp[-1]
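To make the transition concrete, here is a small standalone sketch that returns the whole dp table instead of only the last entry. The function name num_squares_table is made up for illustration and is not part of the solutions in this post.

```python
def num_squares_table(n: int) -> list:
    """Return dp[0..n], where dp[i] is the least number of squares summing to i."""
    dp = [0] + [float("inf")] * n
    for i in range(1, n + 1):
        j = 1
        # try every square j*j that does not exceed i
        while j * j <= i:
            dp[i] = min(dp[i], dp[i - j * j] + 1)
            j += 1
    return dp


print(num_squares_table(12))
# [0, 1, 2, 3, 1, 2, 3, 4, 2, 1, 2, 3, 3]
```

Reading the last entry gives dp[12] = 3, which matches Example 1: dp[12] is reached from dp[8] = 2 by adding one more 4.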
Top down
Starting from n, we define a recursive function helper(node) which returns the least number of square numbers that sum to node. The exit condition is when node itself is a square number, in which case we simply return 1.
class Solution:
    def numSquares(self, n: int) -> int:
        # all square numbers not exceeding n
        opts = set([])
        for i in range(1,n+1):
            if i ** 2 > n:
                break
            opts.add(i**2)
        # memo: node -> least number of squares summing to node
        dic = {}
        def helper(node=n):
            if node in opts:
                return 1
            else:
                if node not in dic:
                    ret = float('inf')
                    for opt in opts:
                        if node - opt > 0:
                            ret = min(ret,helper(node-opt)+1)
                    dic[node] = ret
                return dic[node]
        return helper()
Or, in Python, we could use @lru_cache(None) from the functools module to handle the memoisation.
from functools import lru_cache

class Solution:
    def numSquares(self, n: int) -> int:
        # all square numbers not exceeding n
        opts = set([])
        for i in range(1,n+1):
            if i ** 2 > n:
                break
            opts.add(i**2)
        # lru_cache memoises helper, replacing the explicit dictionary
        @lru_cache(None)
        def helper(node=n):
            if node in opts:
                return 1
            else:
                ret = float('inf')
                for opt in opts:
                    if node - opt > 0:
                        ret = min(ret,helper(node-opt)+1)
                return ret
        return helper()
Breadth first search
From n to square numbers

Treating each number as a node, two nodes are connected when the difference between them is a square number. Therefore we can apply BFS to this tree, rooted at n. The depth at which we first reach a square number, plus one, is the answer.
Notably, to avoid repeated computations, we can use a hash set to record the visited numbers.
import collections

class Solution:
    def numSquares(self, n: int) -> int:
        # square numbers not exceeding n, in ascending order
        i = 1
        opts = []
        while i**2 <= n:
            opts.append(i**2)
            i += 1
        queue = collections.deque()
        queue.append((n,0))
        # avoid repeated computations for visited node
        visited = set([])
        while queue:
            num,step = queue.popleft()
            if num in visited:
                continue
            visited.add(num)
            for opt in opts:
                if num == opt:
                    return step + 1
                # opts is ascending, so larger squares cannot fit either
                if num-opt < 0:
                    break
                queue.append((num-opt,step+1))
Or, another way of writing the BFS, processing one level at a time with a set as the frontier:
class Solution:
    def numSquares(self, n: int) -> int:
        i = 1
        opts = []
        while i**2 <= n:
            opts.append(i**2)
            i += 1
        # initialised as a hashset
        # avoid repeated computations for visited node
        q = {n}
        step = 0
        while q:
            step += 1
            new_q = set([])
            for num in q:
                for opt in opts:
                    if num == opt:
                        return step
                    if num-opt < 0:
                        break
                    new_q.add(num-opt)
            q = new_q
        return step
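To see why the search stops so early, the following sketch records the frontier at every BFS level until one of them contains a square number. It is only an illustration of the level-by-level idea above; the name bfs_levels is hypothetical and does not appear in the solutions.

```python
from math import isqrt


def bfs_levels(n: int) -> list:
    """Return the sorted BFS frontiers, stopping at the first one containing a square."""
    squares = [i * i for i in range(1, isqrt(n) + 1)]
    square_set = set(squares)
    levels = []
    frontier = {n}
    while frontier:
        levels.append(sorted(frontier))
        # a square in the frontier means one more square completes the sum
        if frontier & square_set:
            break
        # subtract every square that still leaves a positive remainder
        frontier = {num - sq for num in frontier for sq in squares if sq < num}
    return levels


print(bfs_levels(12))
# [[12], [3, 8, 11], [2, 4, 7, 10]]
```

The number of levels is the answer: for n = 12 the third frontier contains 4, so three squares are needed (12 -> 8 -> 4 -> 0).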
From square numbers to n

Similarly, we would treat this as a tree, but start from the square numbers this time. The exit condition is when we reach the target n.
import collections

class Solution:
    def numSquares(self, n: int) -> int:
        # deque gives O(1) pops from the left
        q = collections.deque()
        opts = []
        for i in range(1,n+1):
            if i ** 2 > n:
                break
            opts.append(i**2)
            q.append(i**2)
        # n itself is a square number
        if opts[-1] == n:
            return 1
        step = 1
        visited = set([])
        while q:
            size = len(q)
            step += 1
            # process exactly one level per iteration of the outer loop
            for _ in range(size):
                num = q.popleft()
                for opt in opts:
                    if num + opt == n:
                        return step
                    elif num + opt < n and num+opt not in visited:
                        visited.add(num+opt)
                        q.append(num+opt)
                    elif num + opt > n:
                        break
Math
If you are interested in solving it mathematically, feel free to check this LeetCode Discuss post. This approach is, as mentioned, based on Lagrange’s four-square theorem.
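For reference, here is a minimal sketch of one common way to turn the theorem into code. It is not necessarily what the linked post does. Besides Lagrange’s four-square theorem, it assumes Legendre’s three-square theorem: a positive integer needs four squares exactly when it has the form 4^a * (8b + 7). The names is_square and num_squares_math are made up for this example.

```python
from math import isqrt


def is_square(x: int) -> bool:
    """True if x is a perfect square (x >= 0)."""
    r = isqrt(x)
    return r * r == x


def num_squares_math(n: int) -> int:
    # answer 1: n is itself a perfect square
    if is_square(n):
        return 1
    # answer 4: Legendre's three-square theorem, n = 4^a * (8b + 7)
    m = n
    while m % 4 == 0:
        m //= 4
    if m % 8 == 7:
        return 4
    # answer 2: n = a*a + b*b for some a, b >= 1
    for a in range(1, isqrt(n) + 1):
        if is_square(n - a * a):
            return 2
    # otherwise three squares suffice
    return 3


print(num_squares_math(12))
# 3
```

Each check runs in at most O(sqrt(n)) steps, so no table or queue is needed.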