프로젝트 오일러 035

백만 이하의 197, 971, 719와 같은 원형 소수 찾기

프로젝트 오일러 035
Photo by Mike Gattorna / Unsplash

문제

35번 문제
백만 미만인 원형 소수 개수 구하기

숫자를 원형으로 순환하기

어떤 정수로부터 한 방향으로 한 숫자씩 밀어서 순환하는 수들의 목록을 만드는 것이 우선 가장 필요한 부분입니다. 여기서는 두 가지 방법이 있을 수 있습니다. 우선 첫 번째는 문자열을 이용하는 방법입니다.

def shift_1(n: int) -> set[int]:
  s = f'{n}{n}'
  l = len(s) // 2
  return set(int(s[i:i+l]) for i in range(l))

이 방법은 언뜻 생각하기에 문자열로 변환된 숫자를 다시 정수로 만드는 것에 있어서 오버헤드가 상당할 것으로 생각되지만, 파이썬의 최근 버전들은 문자열 처리 관련한 최적화에 상당히 공을 들였기 때문에 그렇게 느린 편은 아닙니다.

다른 방법으로는 1의 자리만 떼어내어, 순서를 바꿔 다시 더하는 방법입니다. 떼어낸 1의 자리에 10을 몇 번 곱해야 할지는 원래의 값이 몇 자리 수인지에 의해 결정됩니다. 이를 알아내려면 상용로그(log10)를 이용할 수 있습니다.

def f2(n: int) -> set[int]:
    l = int(log10(n))
    return set(
        (n // 10 ** (i + 1)) + (n % 10 ** (i + 1)) * (10 ** (l - i))
        for i in range(l + 1)
    )

둘 중의 어떤 것을 써도 상관은 없습니다.

범위

문제에서 아주 중요한 힌트를 하나 주고 있습니다. 바로 1백만 이하의 모든 원형 소수를 찾으라고 했습니다. 결국 100만 이하의 수가 소수일 때, 그 수를 원형으로 순환시킨 모든 수가 소수인지를 검사하면 되는 문제 같습니다. 이미 우리는 꽤 쓸만한 소수 검사 함수도 작성해 보았고, 이 문제에서는 원형으로 순환하는 목록을 만드는 함수도 작성했습니다.

하지만 이 문제의 가장 중요한 힌트는 '백만 이하'입니다. 범위가 정해져 있다는 이야기입니다. 정해진 범위 내의 모든 소수를 찾는 가장 빠른 방법은 에라토스테네스의 체입니다. 이 소수들을 하나의 set에 넣어둔다면, 백만 이하의 정수에 대해 멤버십 테스트를 하는 것만으로, 최소한의 시간으로 소수인지 여부를 검사할 수 있습니다.

위에서 순환 그룹을 만들 때 리스트가 아닌 세트를 생성한 것도 그러한 이유입니다. set는 차집합, 교집합, 합집합을 만드는 연산자가 정의되어 있고, <, > 연산자를 사용해서 포함관계를 표현할 수도 있습니다. 특정한 수의 순환 집합이 모두 소수라면, 소수 set의 부분집합인지를 검사하여 모두 소수인지를 쉽게 알아낼 수 있습니다.

가드

사실 에라토스테네스의 체를 이용하는 것만으로도 충분히 시간은 단축됩니다. 그 외에 범위를 축소할 수 있는 아이디어로는 어떤 것이 있을까요? '2'를 제외하고, 두 자리 이상의 수에서 2가 어딘가에 있다면 이 수의 순환 숫자 중에는 2로 끝나서 소수가 될 수 없는 수가 포함됩니다. 이런식으로 끝자리에 도달했을 때 소수가 아니게 만드는 숫자로는 '0, 2, 4, 5, 6, 8'이 있습니다. 따라서 두 자리 이상의 소수 중에서 '1, 6, 7, 9'로만 구성된 소수들만이 순환 소수가 될 수 있습니다.

하지만 set의 멤버십 검사는 매우 빠르기 때문에, 가드 적용을 위한 처리가 반대로 오버헤드로 작용할 수 있습니다. 따라서 이 아이디어는 제외하고 구현합니다.

from math import log10


def shift(n: int) -> set[int]:
    """숫자를 순환시킨 집합"""
    l = int(log10(n)) + 1
    return set( (n // (10 ** i)) + (n % (10 ** i) * (10 ** (l -i))) for i in range(l))


def sieve(n: int) -> list[int]:
    """에라토스테네스의 체"""
    s = [True] * (n + 1)
    s[:2] = [False, False]
    for i in range(2, int(n ** 0.5) + 1):
        if s[i]:
            s[i*2::i] = [False] * ((n - i) // i)
    return [i for (i, x) in enumerate(s) if x]


def main():
    s = sieve(100_0000)
    x = set(s)
    res = 0
    for p in s:
        ps = shift(p)
        if x > ps:
            res += 1

    return res


print(main())