콘텐츠로 건너뛰기
Home » 오일러 프로젝트 56

오일러 프로젝트 56

구골(googol)은 10100을 일컫는 말로, 1뒤에 0이 백개나 붙는 어마어마한 수입니다. 100100은 1뒤에 0이 2백개나 붙으니 상상을 초월할만큼 크다 하겠습니다. 하지만 이 숫자들이 얼마나 크건간에, 각 자리수를 모두 합하면 둘 다 겨우 1밖에 되지 않습니다. a, b가 100보다 작은 자연수일 때, ab에 대해서, 자릿수의 합이 최대인 경우, 그 값은 얼마입니까?

http://euler.synap.co.kr/prob_detail.php?id=55

접근

1~99 사이의 a, b를 사용하여 a**b를 계산한 다음, 이를 문자열로 만들어서 각 숫자를 쪼개어 합산하고 그 중에서 가장 큰 값을 구하는 것이니 다음과 같이 간단하게 계산할 수 있다.

print(max(sum(map(int, str(a**b))) for a in range(1, 100) for b in range(1,100)))

큰 정수를 사용하지 않는 풀이

큰 정수를 사용하지 않고, 문자열로 큰 정수를 표현하여 덧셈과 곱셈을 계산하는 함수를 이미 만들어 둔 것이 있고, 이를 활용해서 거듭제곱도 계산하는 함수를 만든 바 있다. 이를 사용하여 같은 방식으로 문제를 해결할 수 있을 것이다.

print(max(sum(map(int, s_pow(f'{a}', b)) for a in range(1, 100) for b in range(1,100)))

# Wall time: 17.5s

개선

그런데 s_pow() 함수는 내부적으로 여러 번 s_multi() 함수를 호출한다. 그리고 s_multi() 함수는 다시 s_add() 함수를 호출한다. 각각의 함수들은 문자열은 분해했다가 다시 재조립하는 과정을 반복해서 거칠 뿐만 아니라, 각 자리 숫자 1개씩을 정수로 분리하여 계산하기 때문에 자릿수가 크면 클수록 곱셈이 느려진다. 따라서 위 코드는 실제로 답을 내는데 수십초 가량이 걸릴 정도로 너무 느리다.

큰 정수를 사용하는 덧셈과 곱셈 함수의 성능을 개선해보자. 앞서 두 가지 느린 포인트를 지적했다. 우선 거듭제곱을 계산하기 위해서 반복적으로 문자열을 정수의 리스트로, 다시 정수의 리스트를 문자열로 반복하는 과정을 중간 단계에서 반복한다. 이 과정에서는 중간단계에서는 필요가 없다. 따라서 덧셈과 곱셈을 실제로 계산하는 함수는 문자열 변환을 수행하지 않고 처음과 최종단계에서만 변환하도록 하면 불필요한 연산을 줄일 수 있다.

그리고 리스트의 크기를 줄이는 것을 생각해 볼 수 있다. 기존 함수에서 100자리 수와 100자리 수를 곱하면 내부에서는 모두 10,000번의 곱셈을 수행한다. 이 횟수를 줄이려면 ‘자리수’를 줄일 필요가 있다. 예를 들어 1234 * 1234 를 생각해보자. 모두 16번의 곱셈을 수행한다. 그런데 십진수 1234를 16진수로 표현하면 ‘4D2’가 되고, 4D2 * 4D2 는 곱셈을 아홉번만 수행하면 된다. 전체적인 반복횟수를 줄이게 된다. 10진수를 16진수로 표현하는 것으론 자리 수 감소에 한계가 있다. 하지만 숫자 1자리가 10,000개의 값을 표현할 수 있다면 자리 수는 크게 줄일 수 있을 것이다. 10진수로 100자리 수 두 개를 곱하는 것은 1만진수 25자리가 될 것이기 때문에, 10,000 번에서 625번으로 곱셈 횟수를 획기적으로 줄인다. 물론 1만진수로 표현할 방법은 없지만, 리스트에서 0~9999 사이의 값으로 원래값을 쪼갠다면 리스트로는 표현할 수 있다.

왜 1만을 단위로 했냐면, 9,999의 제곱은 32비트 정수의 최대값보다 작기 때문이다. 99,999의 제곱은 이 범위를 벗어난다. 따라서 이 방법을 그대로 사용하면 32비트 정수를 사용해야 하는 언어에서도 활용할 수 있다.

문자열 변환

문자열을 1만진법 수와 서로 변환하는 함수를 작성해보자. 주의할 것은 뒤에서부터 4자리씩 끊어야 한다는 점이다.

def _parse(s: str) -> list[int]:
  l = len(s)
  return [int(s[max(l-4*i-4, 0):l-4*i]) 
          for i in range(int(l / 4) + 0.8))
  ]

def _comps(xs: list[int]) -> str:
  return ''.join(f'{c:04d}' for c in xs[::-1]).lstrip('0')

print(_parse('1234567'))
# -> [4567, 123]

print(_comps(_parse('1234567'))
# -> 1234567

버려지는 부분이 없도록 문자열의 길이가 4의 배수가 아닌 경우에는 1을 더해야 한다. 1/4 = 0.25 이므로 0.9를 더한 후 int() 함수로 소수점 이하를 잘라버리면 깔끔하게 해결할 수 있다.

덧셈 함수

덧셈함수를 다음과 같이 작성한다. 표면적으로 사용하는 함수와 내부적으로 계산하는 함수로 구분한다. 계산함수는 리스트로 표현된 1만진수를 취급하며, 표면적으로 사용하는 함수는 입력을 변환하고, 계산결과를 다시 거꾸로 변환하여 리턴한다. 계산시에 가장 긴 자리에 맞춰주기 위해서 이전에는 문자열의 zfill() 메소드를 사용했는데, 이번에는 직접 [0] 을 부족한 자리수만큼 이어주는 방식을 선택했다.

def s_add(*args):
    return _comp(_s_add(*map(_parse, args)))


def _s_add(*args):
    l = max(map(len, args))
    args = [x + [0] * (l - len(x)) for x in args]
    res = []
    z = 0
    for xs in zip(*args):
        z, w = divmod(sum(xs) + z, 10000)
        res.append(w)
    if z:
        res.append(z)
    return res

같은 방식으로 곱하기 함수와 거듭제곱함수도 아래와 같이 작성할 수 있다. 사실상 다른 함수로 감쌌다는 것과 10,000으로 나눈 나머지를 각 자리 수의 값으로 사용하는 것외에는 차이가 없다.

def _s_multi(a: list[int], b: list[int]) -> list[int]:
  res = []
  for i, y in enumerate(b):
    z, temp = 0, [0] * i
    for x in a:
      z, w = divmod(x * y + z, 10000)
      temp.append(w)
    if z:
      temp.append(z)
    res.append(temp)
  return _s_add(res)

def s_multi(a: str, b: str) -> str:
  x, y = map(_parse, (a, b))
  return _comps(_s_multi(x, y))

def _s_pow(a: list[int], b: int) -> list[int]:
  if b == 0:
    return [1]
  if b == 1:
    return a
  c = _s_pow(a, b // 2)
  if b % 2 == 0:
    return _s_multi(c, c)
  return s_multi(_s_multi(c, c), a)

def s_pow(a:str, b:int) -> str:
  return _comps(_s_pow(_parse(a), b))

그리고 이렇게 수정한 버전을 사용하여 아까와 똑같은 코드로 실행해보면… 30배 정도로 빨라진 것을 확인할 수 있다.

print(max(sum(map(int, s_pow(f'{a}', b)) for a in range(1, 100) for b in range(1,100)))

# Wall time: 0.579s