[Baekjoon,2887] 행성 터널

1 분 소요

자칫 잘못 생각하면 모든 행성을 서로 이어 간선을 형성하고 이 간선들로 최소 스패닝 트리를 찾아 해결할 것이다. 하지만 이러한 방법은 간선이 너무 많아 제한 시간 내에 해결할 수 없다.

// 너무 많은 간선들로 인해 모든 정점을 탐색하다 시간초과가 난다. 
int V;
int p[MAX_V][3];
long long dist(int a, int b);
long long solve()
{
    vector<bool> added(V,false);
    vector<long long> minDist(V,INF);
    long long ret=0;
    minDist[0]=0;
    for(int it=0;it<V;++it)
    {
        int u=-1;
        for(int v=0;v<V;++v)
            if(!added[v] && (u==-1 || minDist[u]>minDist[v]))
                u=v;
        
        ret+=minDist[u];
        added[u]=true;
        for(int v=0;v<V;++v)
        {
            if(added[v] || u==v)
                continue;
            minDist[v]=min(minDist[v],dist(u,v));
        }
    }
    return ret;
}

해결

여기서 주목할 것은 두 행성을 연결하는 비용을 구하는 방법이다. 좌표상의 거리가 아닌 $min(|x_A-x_B|, |y_A-y_B|, |z_A-z_B|)$를 비용으로 두고 있다. 이는 $x,y,z$ 좌표 값들 기준으로 정렬하여 어느 한 행성이 좌표값으로 인접한 행성만을 탐색하여 최소 스패닝 트리를 형성할 수 있다는 것이다. 즉, $x,y,z$ 좌표 값들을 각각 따로 정렬하여 인접한 정점만 접근하면 된다.

구현

이를 구현하면 아래와 같다. 크루스칼 알고리즘을 이용하여 해결하였다. 이때 탐색할 간선 목록에 모든 간선을 추가하지 않고 인접한 정점을 잇는 간선만 추가하는 것을 볼 수 있다.

#include <iostream>
#include <vector>
#include <climits>
#include <cmath>
#include <algorithm>
using namespace std;

const int MAX_V=100000;
const long long INF=LONG_MAX;

struct DisjointSet
{
    vector<int> parent,rank;
    DisjointSet(int v) : parent(v), rank(v,1)
    {
        for(int i=0;i<v;++i)
            parent[i]=i;
    }
    int find(int u)
    {
        if(parent[u]==u)
            return u;
        return parent[u]=find(parent[u]);
    }
    bool merge(int u, int v)
    {
        u=find(u); v=find(v);
        if(u==v) return false;
        // u의 트리가 rank가 더 깊게 한다. 
        if(rank[u]<rank[v]) swap(u,v);
        parent[u]=v;
        // 두 트리의 rank가 같은 경우 
        if(rank[u]==rank[v])
            ++rank[u];
        return true;
    }
};

int V;
// [x or y or z] <행성 위치, 행성 번호>
vector<pair<int,int> > p[3];

int solve()
{
    // <간선의 가중치, <u,v> >
    vector<pair<int,pair<int,int> > > edges;
    // 가장 가까운 정점을 잇는 간선만 추가
    for(int u=0;u<V-1;++u)
    {
        int v=u+1;
        for(int pos=0;pos<3;++pos)
        {
            int here=p[pos][u].second;
            int there=p[pos][v].second;
            int dist=abs(p[pos][u].first-p[pos][v].first);
            edges.push_back(make_pair(dist,make_pair(here,there)));
        }
    }
    sort(edges.begin(),edges.end());
    
    DisjointSet set(V);
    int ret=0;
    for(int i=0;i<edges.size();++i)
    {
        int cost=edges[i].first;
        int u=edges[i].second.first;
        int v=edges[i].second.second;
        if(set.merge(u,v))
            ret+=cost;
    }
    return ret;
}

int main()
{
    cin>>V;
    for(int i=0;i<V;++i)
    {
        int a,b,c;
        scanf("%d %d %d",&a,&b,&c);
        p[0].push_back(make_pair(a,i));
        p[1].push_back(make_pair(b,i));
        p[2].push_back(make_pair(c,i));
    }
    // x,y,z 값 별로 따로 정렬한다. 
    for(int i=0;i<3;++i)
        sort(p[i].begin(),p[i].end());
    
    cout<<solve()<<endl;
    
    return 0;
}

댓글남기기