프로젝트 오일러 048

1~1000까지 밑과 지수가 같은 거듭제곱 수들의 합

프로젝트 오일러 048
Photo by Soekarno Omar / Unsplash

문제

48번 문제
1<sup>1</sup> + 2<sup>2</sup> + 3<sup>3</sup> + ... + 1000<sup>1000</sup> 의 마지막 10자리

간단한 풀이

사실 아주 간단한 문제라서 one-liner로 푸는 것도 가능합니다.

print(str(sum(i**i for i in range(1, 1001)))[-10:])

큰 수의 거듭제곱

이미 20번 문제에서 큰 수의 덧셈을 직접 구현해 본 바가 있습니다. 조금 더 좋은 성능이 필요합니다. 999의 999 거듭제곱을 998번의 곱셈으로 계산하기 보다는 좀 더 나은 성능으로 튜닝할 필요가 있습니다. xk × xk = x2k 인 지수법칙을 이용하면 거듭제곱을 계산하는 과정은 다음의 알고리듬을 적용할 수 있습니다.

  1. x0 = 1
  2. x1 = x
  3. n 이 짝수인 경우, xn = x(n÷2) × x(n÷2)
  4. n 이 홀수(2k + 1)인 경우, xn = x(n - 1) × x = xk × xk × x
def s_pow(xs: list[int], n: int) -> list[int]:
  if n == 0:
    return 1
  elif n == 1:
    return xs
  t = s_pow(xs, n // 2)
  if n % 2 == 0:
    return s_multi(t, t)
  return s_multi(multi(t, t), xs)

성능개선

거듭제곱의 계산에 대해서 이렇게 최적화를 한다고는 했지만, 원래의 덧셈과 곱셈이 그리 효율적이지가 않습니다. 이 때 중요한 힌트가 하나 있는데, 바로 답에서 요구하는 것은 마지막 10자리만 필요하다는 점입니다. 덧셈이나 곱셈에서는 큰 자리의 숫자가 값에서 그 보다 작은 자리의 숫자에 대해서는 영향을 끼치지 못합니다. 따라서 각각의 곱셈과 덧셈에서 10자리까지만 계산하도록 하는 것입니다. 거듭제곱 계산에서는 수십~수백자리 숫자들의 계산이 반복되기 때문에, 불필요한 자릿수의 계산을 제거하는 것만으로도 이 문제에서는 상당한 성능을 개선할 수 있습니다.

def parse(s: str) -> list[int]:
    return [int(x) for x in s[::-1]]


def dump(ns: list[int]) -> str:
    return "".join(f"{x}" for x in ns[::-1])


def s_add_10(*xs: list[int]) -> list[int]:
    w = min(10, max(map(len, xs)))
    grid = [0] * (w * len(xs))
    res = []
    for i, row in enumerate(xs):
        grid[i * w : i * w + min(len(row), 10)] = row[: min(len(row), 10)]
    f = 0
    for i in range(w):
        f, e = divmod(sum(grid[i::w]) + f, 10)
        res.append(e)
    if f > 0:
        res.append(f)
    return res[:10]


def s_multi_10(a: list[int], b: list[int]) -> list[int]:
    temp: list[list[int]] = []
    for i, x in enumerate(a[:10]):
        j = [0] * i
        f = 0
        for y in b[:10]:
            f, e = divmod(x * y + f, 10)
            j.append(e)
        if f > 0:
            j.append(f)
        temp.append(j[:10])
    return s_add_10(*temp)


def s_pow_10(xs: list[int], n: int) -> list[int]:
    if n == 0:
        return [1]
    if n == 1:
        return xs[:]
    t = s_pow_10(xs, n // 2)
    p = s_multi_10(t, t)
    return p if n % 2 == 0 else s_multi_10(p, xs)


def multi_10(a: str, b: str) -> str:
    x, y = map(parse, (a, b))
    return dump(s_multi_10(x, y))


def pow_10(a: str, b: int) -> str:
    return dump(s_pow_10(parse(a), b))


def add_10(*xs: str) -> str:
    ys = [parse(x) for x in xs]
    return dump(s_add_10(*ys))


def main(L: int = 1001):
    xs = [pow_10(f"{i}", i) for i in range(1, L)]
    return add_10(*xs)


if __name__ == "__main__":
    print(main())
    print(str(sum(i**i for i in range(1, 1001)))[-10:])