카테고리 없음

boj 24477 Railway Trip 2

gubshig 2023. 7. 18. 22:20

 

https://www.acmicpc.net/problem/24477

재밌는 JOI 문제이다. 서브태스크마다 풀이를 생각해본 문제는 오랜만인데, 마침 메모장에 적어놓은게 몇가지 있어서 블로그에 올려볼까 한다.

 

subtask 1)
$A(i) \to [A(i) + 1, B(i)]$
$A(i) + 1 \to [A(i) + 2, B(i)]$
$A(i) + 2 \to [A(i) + 3, B(i)]$
$\min(A(i) + k - 1, B(i) - 1) \to [\,, B(i)]$
...
or
$A(i) \to [B(i), A(i) - 1]$
$A(i) - 1 \to [B(i), A(i) - 2]$
$A(i) - 2 \to [B(i), A(i) - 3]$
$\max(A(i) - k + 1, B(i) + 1) \to [B(i), \,]$
...
이런식으로 간선이 만들어진다.
정점 $O(N)$개, 간선 $O(MN^2)$개에서 플로이드를 돌리면 된다.
시간복잡도는 $O(MN^2 + N^3 + Q)$ 이다.

 

subtask 2, 3)

현재 그래프의 모델링은 너무 느리기에 다른 방법을 생각해보자.

각 정점이 이동할 수 있는 정점들을 생각했을 때, 그 정점들은 임의의 연속한 구간을 이룬다.

생각해보면 자명하다.

그렇다면 각 정점에 대해 이동할 수 있는 연속된 구간을 구해주자.

다양한 방법이 있을 것 같은데, 나는 우선순위큐를 들고 스위핑 했다.

 

만약 어떤 정점에서 $K$번 이동했을 때 갈 수 있는 오른쪽 최댓값, 왼쪽 최솟값을 안다고 하면,

각 쿼리는 파라매트릭 서치를 이용해 해결해줄 수 있다.

 

$dpl(i, j)$ = 각 정점에서 $j$번 이동했을 떄 갈 수 있는 왼쪽 최솟값, $dpr(i, j)$ = 오른쪽 최댓값 이라 한다면, $dpl(i, j)$의 값은 구간 $[dpl(i, j - 1), dpr(i, j - 1)]$ 에서의 $dpl(i, j - 1)$의 최솟값, $dpr(i, j)$의 값은 $dpr(i, j - 1)$의 최댓값이라 할 수 있다. 이는 세그먼트 트리 등을 이용해서 구해줄 수 있다.

 

full task)

위의 풀이를 최적화해보자. 비슷한 $dp$식의 정의에서, $j$번 이동했을때가 아닌, $2^j$번 이동했을 때라 하자.

상태 전이는 동일하게 되고, Q개의 쿼리도 각 j에 대한 세그먼트 트리를 만들어 줌으로써 S에서 X번 움직였을 때 T 이상/이하 로 갈 수 있나? 에 대한 결정문제를 쿼리당 $O(log^3N)$에 풀어줄 수 있다.

구간을 관리하는 희소배열 이라고 생각하면 이해하기 편하다.

시간복잡도 $O(N log ^2 N + Q log ^ 3 N)$

 

#include <bits/stdc++.h>
using namespace std;
using pii = pair<int, int>;
using vi = vector<int>;

int INF = 1e9;

struct Seg{
    int st[202020], sz, identity;
    function<int(int, int)> op;
    Seg(){}
    Seg(int _sz, int _identity, function<int(int, int)> _op){ 
        sz = _sz;
        op = _op;
        identity = _identity;
        fill(st, st + 202020, identity);
    }

    void init(){
        for(int i = sz - 1; i > 0; i--){
            st[i] = op(st[i << 1], st[i << 1 | 1]);
        }
    }

    void update(int x, int v){
        x += sz;
        for(st[x] = op(st[x], v); x > 1; x >>= 1) st[x >> 1] = op(st[x], st[x ^ 1]);
    }

    int query(int l, int r){
        int ret = identity;
        for(l += sz, r += sz + 1; l < r; l >>= 1, r >>= 1){
            if(l & 1) ret = op(ret, st[l++]);
            if(r & 1) ret = op(st[--r], ret);
        }
        return ret;
    }
};

class EraseMinPq{
    private:
    priority_queue<int, vector<int>, greater<int>> pq;
    priority_queue<int, vector<int>, greater<int>> bin;

    public:
    void push(int x){
        pq.push(x);
    }

    void del(int x){
        bin.push(x);
    }

    int top(){
        while(!bin.empty() && pq.top() == bin.top()) pq.pop(), bin.pop();
        if(pq.empty()) return INF;
        return pq.top();
    }
};

class EraseMaxPq{
    private:
    priority_queue<int> pq;
    priority_queue<int> bin;

    public:
    void push(int x){
        pq.push(x);
    }

    void del(int x){
        bin.push(x);
    }

    int top(){
        while(!bin.empty() && pq.top() == bin.top()) pq.pop(), bin.pop();
        if(pq.empty()) return 0;
        return pq.top();
    }
};

int l[101010], r[101010], ll[20][101010], rl[20][101010], mx = 20;
int n, k;
vi lin[101010], lout[101010], rin[101010], rout[101010];
Seg lseg[20];
Seg rseg[20];

int main(){
    cin.tie(nullptr)->sync_with_stdio(0);
    
    cin >> n >> k;
    for(int i = 0; i < mx; i++){
        lseg[i] = Seg(n + 1, INF, [](int x, int y){ return min(x, y); });
        rseg[i] = Seg(n + 1, 0, [](int x, int y){ return max(x, y); });
    }

    for(int i = 1; i <= n; i++){
        l[i] = i;
        r[i] = i;
        lseg[0].st [i + n + 1] = i;
        rseg[0].st[i + n + 1] = i;
    }

    int m;
    cin >> m;
    for(int i = 0; i < m; i++){
        int a, b;
        cin >> a >> b;
        if(a < b){
            int e = min(a + k - 1, b - 1);
            rin[a].push_back(b);
            rout[e + 1].push_back(b);
        }
        else{
            int s = max(a - k + 1, b + 1);
            lin[a].push_back(b);
            lout[s - 1].push_back(b);
        }
    }

    EraseMaxPq maxpq;
    for(int i = 1; i <= n; i++){
        for(auto &j: rin[i]){
            maxpq.push(j);
        }
        for(auto &j: rout[i]){
            maxpq.del(j);
        }
        rin[i].clear();
        rout[i].clear();
        r[i] = max(r[i], maxpq.top());
        rseg[0].st[i + n + 1] = max(rseg[0].st[i + n + 1], r[i]);
    }

    EraseMinPq minpq;
    for(int i = n; i >= 1; i--){
        for(auto &j: lin[i]){
            minpq.push(j);
        }
        for(auto &j: lout[i]){
            minpq.del(j);
        }
        lin[i].clear();
        lout[i].clear();
        l[i] = min(l[i], minpq.top());
        lseg[0].st[i + n + 1] = min(lseg[0].st[i + n + 1], l[i]);
    }

    lseg[0].init(); rseg[0].init();

    for(int i = 1; i <= n; i++){
        ll[0][i] = l[i];
        rl[0][i] = r[i];
    }

    for(int i = 1; i < mx; i++){
        for(int j = 1; j <= n; j++){
            ll[i][j] = lseg[i - 1].query(ll[i - 1][j], rl[i - 1][j]);
            rl[i][j] = rseg[i - 1].query(ll[i - 1][j], rl[i - 1][j]);
        }

        for(int j = 1; j <= n; j++){
            lseg[i].st[j + n + 1] = ll[i][j];
            rseg[i].st[j + n + 1] = rl[i][j];
        }
        lseg[i].init();
        rseg[i].init();
    }

    int q;
    cin >> q;
    while(q--){
        int s, t;
        cin >> s >> t;

        if(s < t){

            int lo = -1, hi = n;
            while(lo + 1 < hi){
                int mi = (lo + hi) >> 1;
                int lq = s, rq = s;

                for(int i = 0; i < mx; i++){
                    if(mi & (1 << i)){
                        int tl = lseg[i].query(lq, rq);
                        int tr = rseg[i].query(lq, rq);
                        lq = tl, rq = tr;
                    }
                }

                if(rq >= t) hi = mi;
                else lo = mi;
            }

            int lq = s, rq = s;
            for(int i = 0; i < mx; i++){
                if(hi & (1 << i)){
                    int tl = lseg[i].query(lq, rq);
                    int tr = rseg[i].query(lq, rq);
                    lq = tl, rq = tr;
                }
            }

            if(rq < t) cout << -1 << '\n';
            else cout << hi << '\n';
        }
        else{

            int lo = -1, hi = n;
            while(lo + 1 < hi){
                int mi = (lo + hi) >> 1;
                int lq = s, rq = s;

                for(int i = 0; i < mx; i++){
                    if(mi & (1 << i)){
                        int tl = lseg[i].query(lq, rq);
                        int tr = rseg[i].query(lq, rq);
                        lq = tl, rq = tr;
                    }
                }

                if(lq <= t) hi = mi;
                else lo = mi;
            }

            int lq = s, rq = s;
            for(int i = 0; i < mx; i++){
                if(hi & (1 << i)){
                    int tl = lseg[i].query(lq, rq);
                    int tr = rseg[i].query(lq, rq);
                    lq = tl, rq = tr;
                }
            }
            
            if(lq > t) cout << -1 << '\n';
            else cout << hi << '\n';

        }
    }
}

 

코드가 상당히 긴데, 그냥 내가 더럽게 풀어서 그렇다(....)

 

후기) joi문제는 항상 참신하고 재밌다. joi 짱