Site icon Wireframe

프로젝트 오일러 92

어떤 자연수의 각 자리 숫자를 제곱하여 그 합을 구하는 계산을 반복하면 1이 되는 경우가 많다. 이런 숫자들은 happy number 라고 부른다. 행복하지 못한 숫자들은 계산을 반복하는 과정에서 89를 만나게 되고, 89는 다시 몇 단계의 과정을 거쳐 89가 되어 영원히 1이 될 수 없다.

각 자리 숫자의 제곱의 합은 간단히 구할 수 있기 때문에 간단한 코드로 어떤 숫자가 행복한지 여부를 알아내는 것은 쉽지만, 문제는 그 범위가 1천만 개나 된다는 점이다. 간단한 작업이지만 백만 번을 넘게 반복한다면 제법 오랜 시간이 걸릴 것이다. 이 문제에 대한 해법으로 전수검사를 통해 카운트하는 코드는 조금만 검색해보면 여러 블로그를 통해 쉽게 찾을 수 있지만, 더 빠른 시간 내에 답을 구하기 위한 방법을 찾아보도록하자.

먼저 1천만 이하의 수를 각 한 번씩만 자리 수 제곱의 합을 구한다고 할 때의 최대값이 어떻게 될지 생각해보자. 이 값이 최대가 되는 경우는 9,999,999일 때가 됨은 자명하고 그 때의 자리수 제곱의 합은 9*9*7 = 567이다. 즉 1~567에 해당하는 수 중에서 어떤 수가 불행한지를 알고 있다면, 1천만 이하의 모든 수에 대해 한 번씩만 자리 수 제곱의 합을 구하면 그 결말이 행복한지는 쉽게 알 수 있을 것이다.

def is_unhappy(n: int) -> bool:
  while n > 1 and n != 4 and n != 89:
    n = sum(int(c)**2 for c in str(n))
  return n > 1

unhappy = set(x for x in range(1, 568) if is_unhappy(x))

이렇게 필터를 set로 만들어두면 단 한 번만 변환하여 멤버심 검사를 통해서 쉽게 불행한 수인지 알 수 있다. 89 혹은 1을 만나는지 여부를 계속 추적하지 않기 때문에 이 정도도 기존 코드들에 비해서는 매우 빠른 편이다.

%time print(sum(1 for x in range(1, 1000_0000) if sum(int(c) ** 2 for c in str(x)))

Wall time: 9.29s

순열인 수들

여기서 좀 더 시간을 단축할 수 있는 방법이 있을까? 예를 들어 불행한 수 중의 하나인 1234를 한 번 변환하면 1 + 4 + 9 + 16 = 30 이 된다. 그런데 1324, 4123, 3412 등 1234와 같은 숫자로 이루어진 수들의 변환 값도 모두 30이다. 즉 1234가 불행한 수인지 알 수 있다면, 그 순열들도 모두 불행한 수인지를 계산해보지 않고서도 알 수 있다.

천만 이하의 자연수들에는 1223 과 같이 같은 숫자가 2개 이상 들어가는 값도 있으므로, 같은 순열로 된 수들의 개수는 중복을 포함하는 순열의 경우의 수에 해당한다. 이 값은 학교에서 배우는 공식으로 간단히 계산할 수 있다. 그렇다면 같은 순열을 제외한 중복을 포함하는 조합들을 만들 수 있다면, 한 번의 계산으로 불행한 수인지를 파악하고, 같은 순열의 개수를 모두 더해나가면 최종 답을 구할 수 있을 것이다.

중복 조합 생성 및 중복을 포함하는 순열의 개수를 구하는 함수를 준비하자.

from collections import Count

# 중복조합생성
def dup_combinations(xs, n=0):
  def helper(head, tail, k):
    if k == 0:
      yield head
      return
    for (i, t) in enumerated(tail):
      yield from helper((*head, t), tail[i:], k-1)


# 중복을 포함하는 순열의 개수

def factorial(n: int) -> int:
  if n > 1:
    s = n
    for i in range(1, n):
      s *= i
    return s
  return 1
  
def count_dup_perm(xs):
  cs = Counter(xs)
  n = factorial(len(xs))
  for v in cs.values():
    n //= factorial(v)
  return n

이상의 내용들을 사용한 최종 계산은 다음과 같이 할 수 있다.

def main():
  unhappy = set(x for x in range(1, 1000_0000) if is_unhappy(x))
  res = 0
  for ns in dup_combinations(range(10), n=7):
    if sum(n*n for n in ns) in unhappy:
      res += count_dup_perm(ns)
  print(res)



if __name__ == "__main__":
  main()

최적화 후 실행 시간은 0.01초 내외로 빠르게 계산된다.

Exit mobile version