알고리즘

백준 16136 준하의 정수론 과제 (Divmaster)

gubshig 2023. 4. 11. 10:50

https://www.acmicpc.net/problem/16136

 

16136번: 준하의 정수론 과제 (Divmaster)

준하는 3학년 2학기 때 들으려고 했던 정수론을 수강신청을 잘못하는 바람에 2학년 1학기 때 신청하고 말았다! 사악한 정수론 선생님은 자연수의 약수의 개수를 구하는 문제를 던지고, 이 문제

www.acmicpc.net

일반적인 세그먼트 트리나 레이지 세그먼트 트리로는 해결할 수 없음을 알 수 있다. 

관찰을 하나 하면 문제가 쉽게 풀린다. 

약수의 개수 함수를 d(n)이라고 하자. d(n)을 반복하다보면 그 수가 1이 아닌 이상 2에 수렴할 것인데, 그 속도가 상당히 빠름을 알 수 있다. 어떤 수의 약수의 개수의 상한은 2 * sqrt(n) 정도 되기 때문이다. 

그렇다면 어떤 수가 2에 수렴했을 때 봐주지 않게 처리를 해줘야 하는데, 이는 왼쪽을 관리하는 분리집합과 오른쪽을 관리하는 분리집합을 사용하면 쉽게 해결할 수 있다. 어떤 수가 2 or 1에 수렴했을 때, 오른쪽과 합쳐주면 다음에 그 구간을 안볼 수 있게 된다.

d(n)은 O(n log n) 정도에 전처리를 해줄 수 있다.

import sys

input = sys.stdin.readline
mis = lambda: map(int, input().split())
MX = int(1e6) + 1

class Segtree:
    def __init__(self, a, SZ):
        self.SZ = SZ
        self.st = [0] * SZ + a
        for i in range(self.SZ - 1, 0, -1):
            self.st[i] = self.st[i * 2] + self.st[i * 2 + 1]

    def update(self, x, v):
        x += self.SZ
        self.st[x] = v
        while x > 1:
            self.st[x // 2] = self.st[x] + self.st[x ^ 1]
            x //= 2

    def query(self, l, r):
        r += 1
        l += self.SZ
        r += self.SZ
        ret = 0
        while l < r:
            if l & 1:
                ret += self.st[l]
                l += 1
            if r & 1:
                r -= 1
                ret += self.st[r]
            l //= 2
            r //= 2
        return ret


class DisjointSet:
    def __init__(self, sz):
        self.p = [i for i in range(sz)]

    def find(self, x):
        while self.p[x] != x:
            self.p[x] = self.p[self.p[x]]
            x = self.p[x]
        return x

    def union(self, x, y):
        x, y = self.find(x), self.find(y)
        if x == y: return
        self.p[y] = x

n, q = mis()
a = [0] + list(mis()) + [0]
seg = Segtree(a, n + 2)
d = [0] * MX
for i in range(1, MX):
    for j in range(i, MX, i): d[j] += 1

lset = DisjointSet(n + 2)
rset = DisjointSet(n + 2)

for _ in range(q):
    Q, l, r = mis()
    if Q == 1:
        l = lset.find(l)
        r = rset.find(r)
        while l <= r:
            l = lset.find(l)
            seg.update(l, d[a[l]])
            a[l] = d[a[l]]
            if a[l] == 2 or a[l] == 1:
                lset.union(l + 1, l)
                rset.union(l - 1, l)
            l += 1
    else:
        print(seg.query(l, r))

시간복잡도는 O(max(a) log max(a)) + O(n log log max(a)) + O(q log n) = O(max(a) log max(a) + n log log max(a) + q log n) 가 된다. (d(n) 전처리 + a가 2 or 1에 수렴하기 까지 걸리는 시간 + 세그먼트 트리로 구간합을 처리하는 시간)