프로젝트 오일러 056

100미만의 a, b의 a**b로 표현할 수 있는 자릿수의 합의 최대값

프로젝트 오일러 056
Photo by Susan Wilkinson / Unsplash

문제

56번 문제
<var>a<sup>b</sup></var> 형태의 자연수에 대해 자릿수 합의 최댓값 구하기

One-liner

1~99 사이의 a, b에 대해 각각 a의 b거듭제곱의 값을 구하고, 이를 각 자리 숫자로 쪼개어 그 합을 구합니다. 그 합과 다른 정보를 튜플로 만들면, 그 때의 a, b도 알 수 있습니다. (읽기 쉽게 여러 줄로 나눴지만, 한 줄의 명령입니다.

print(max((sum(map(int, str(a**b))), a, b) 
          for a in range(1, 100) 
          for b in range(1, 100)
          )
      )

다항식을 이용한 계산

이전 문제들을 풀어보면서 큰 정수의 계산 기능을 이용하는 것이 반칙처럼 느껴져서 add, multi, pow 같은 함수들을 작성해서 한자리 정수의 배열을 사용해서 큰 정수의 계산을 구현한 바 있습니다. 그런데 이 문제에서는 이 함수들을 사용하면 그닥 좋은 성능이 나지 않습니다. 물론 이러한 함수들이 별로 성능이 좋지 못한 것이 가장 큰 이유겠지요. 그래서 약간의 튜닝이 필요합니다.

이전에는 각각의 배열은 십진법을 사용한 계산이었습니다. 배열의 각 항은 한 자리 정수였고, 덧셈이나 곱셈에서는 항상 10으로 나눈 나머지를 그 자리의 수로 사용했습니다. 그렇지만 1만진법이라면 어떨까요? 한 자리 숫자를 표현할 만 개의 기호는 없지만, 그저 배열의 항 하나가 0~9999까지의 범위의 값을 갖는 것으로 보는거죠. 이렇게하면 배열의 크기가 최대 1/4까지 줄어들기에 전체 계산에 소요되는 루프의 수를 크게 줄일 수 있습니다.

다른 부분들의 로직은 완전히 동일하지만, 문자열을 뒤에서부터 4자리씩 '올바른 순서'대로 끊어야 하는 부분 때문에 정수의 리스트로 변환하는 부분이 조금 복잡해 보입니다.

from euler import elapsed

T_SIZE = 4
LV = 10**T_SIZE


def parse(s: str) -> list[int]:
    l = len(s)
    w = int(l / 4 + 0.99)
    return [int(s[max(0, l - (i + 1) * T_SIZE) : l - ((i) * T_SIZE)]) for i in range(w)]


def dump(ns: list[int]) -> str:
    x, *ys = ns[::-1]
    return str(x) + "".join(f"{y:04d}" for y in ys)


def __add(*args: list[int]) -> list[int]:
    w = max(map(len, args))
    grid = [0] * ((w) * len(args))
    res = []
    for i, row in enumerate(args):
        grid[i * w : i * w + len(row)] = row
    f = 0
    for i in range(w):
        f, e = divmod(sum(grid[i::w]) + f, LV)
        res.append(e)
    if f > 0:
        res.append(f)
    return res


def __multi(a: list[int], b: list[int]) -> list[int]:
    res = []
    for i, x in enumerate(a):
        j = [0] * i
        f = 0
        for y in b:
            f, e = divmod(x * y + f, LV)
            j.append(e)
        if f > 0:
            j.append(f)
        res.append(j)
    return __add(*res)


def __pow(ns: list[int], e: int) -> list[int]:
    if e == 0:
        return [1]
    if e == 1:
        return ns[:]
    ts = __pow(ns, e // 2)
    res = __multi(ts, ts)
    if e % 2 == 0:
        return res
    return __multi(res, ns)


def s_pow(a: str, b: int) -> str:
    return dump(__pow(parse(a), b))


@elapsed
def main():
    def foo(a: int, b: int) -> tuple[int, int, int]:
        x = s_pow(str(a), b)
        return (sum(map(int, x)), a, b)

    res = max(foo(a, b) for a in range(1, 100) for b in range(1, 100))
    print(res)


if __name__ == "__main__":
    main()