프로젝트 오일러 048
1~1000까지 밑과 지수가 같은 거듭제곱 수들의 합
문제
간단한 풀이
사실 아주 간단한 문제라서 one-liner로 푸는 것도 가능합니다.
print(str(sum(i**i for i in range(1, 1001)))[-10:])
큰 수의 거듭제곱
이미 20번 문제에서 큰 수의 덧셈을 직접 구현해 본 바가 있습니다. 조금 더 좋은 성능이 필요합니다. 999의 999 거듭제곱을 998번의 곱셈으로 계산하기 보다는 좀 더 나은 성능으로 튜닝할 필요가 있습니다. xk × xk = x2k 인 지수법칙을 이용하면 거듭제곱을 계산하는 과정은 다음의 알고리듬을 적용할 수 있습니다.
- x0 = 1
- x1 = x
- n 이 짝수인 경우, xn = x(n÷2) × x(n÷2)
- n 이 홀수(2k + 1)인 경우, xn = x(n - 1) × x = xk × xk × x
def s_pow(xs: list[int], n: int) -> list[int]:
if n == 0:
return 1
elif n == 1:
return xs
t = s_pow(xs, n // 2)
if n % 2 == 0:
return s_multi(t, t)
return s_multi(multi(t, t), xs)
성능개선
거듭제곱의 계산에 대해서 이렇게 최적화를 한다고는 했지만, 원래의 덧셈과 곱셈이 그리 효율적이지가 않습니다. 이 때 중요한 힌트가 하나 있는데, 바로 답에서 요구하는 것은 마지막 10자리만 필요하다는 점입니다. 덧셈이나 곱셈에서는 큰 자리의 숫자가 값에서 그 보다 작은 자리의 숫자에 대해서는 영향을 끼치지 못합니다. 따라서 각각의 곱셈과 덧셈에서 10자리까지만 계산하도록 하는 것입니다. 거듭제곱 계산에서는 수십~수백자리 숫자들의 계산이 반복되기 때문에, 불필요한 자릿수의 계산을 제거하는 것만으로도 이 문제에서는 상당한 성능을 개선할 수 있습니다.
def parse(s: str) -> list[int]:
return [int(x) for x in s[::-1]]
def dump(ns: list[int]) -> str:
return "".join(f"{x}" for x in ns[::-1])
def s_add_10(*xs: list[int]) -> list[int]:
w = min(10, max(map(len, xs)))
grid = [0] * (w * len(xs))
res = []
for i, row in enumerate(xs):
grid[i * w : i * w + min(len(row), 10)] = row[: min(len(row), 10)]
f = 0
for i in range(w):
f, e = divmod(sum(grid[i::w]) + f, 10)
res.append(e)
if f > 0:
res.append(f)
return res[:10]
def s_multi_10(a: list[int], b: list[int]) -> list[int]:
temp: list[list[int]] = []
for i, x in enumerate(a[:10]):
j = [0] * i
f = 0
for y in b[:10]:
f, e = divmod(x * y + f, 10)
j.append(e)
if f > 0:
j.append(f)
temp.append(j[:10])
return s_add_10(*temp)
def s_pow_10(xs: list[int], n: int) -> list[int]:
if n == 0:
return [1]
if n == 1:
return xs[:]
t = s_pow_10(xs, n // 2)
p = s_multi_10(t, t)
return p if n % 2 == 0 else s_multi_10(p, xs)
def multi_10(a: str, b: str) -> str:
x, y = map(parse, (a, b))
return dump(s_multi_10(x, y))
def pow_10(a: str, b: int) -> str:
return dump(s_pow_10(parse(a), b))
def add_10(*xs: str) -> str:
ys = [parse(x) for x in xs]
return dump(s_add_10(*ys))
def main(L: int = 1001):
xs = [pow_10(f"{i}", i) for i in range(1, L)]
return add_10(*xs)
if __name__ == "__main__":
print(main())
print(str(sum(i**i for i in range(1, 1001)))[-10:])