[Baekjoon,11049] 행렬 곱셈 순서

1 분 소요

모든 구간 [0,N-1]에 대해 최소 비용을 얻기 위해 문제를 잘게 쪼개야 한다. 이를 위한 방법은 각 구간에 대해 재귀적으로 중간을 나누어 쪼개는 것이다. 이는 아래와 같은 부분 문제를 정의하면 쉽게 구현 할 수 있다.

dp(lo,hi) = 구간 [lo,hi] 에서 행렬 곱셈을 최소화 하는 비용 

이때 중간으로 쪼개 진 두 구간 [lo,mid], [mid+1,hi] 가 있을 때 이 구간을 합치는 비용을 더해야 구간에 대한 총 비용이 나온다. 행렬은 연결된 부분의 너비가 같으므로 이를 이용하여 두 구간의 끝과 겹쳐지는 가운데 부분을 곱하면 비용을 구할 수 있다.

예를 들어 아래와 같은 네 개의 행렬이 있다고 해보자.

pos  :   0     1     2     3
(r,c): (4,5) (5,3) (3,2) (2,6)

mid를 1로 설정하여 [0,mid]을 합치고 [mid+1,3]을 합치면 아래와 같은 비용이 발생한다.

pos  :   01     23
(r,c):  (4,3)  (3,6)

cost : 4*5*3 + 3*2*6

이 다음 다시 합쳐진 두 행렬을 합쳐보자.

pos  :  0123
(r,c):  (4,6)

cost : 4*5*3 + 3*2*6 + 4*3*6

여기서 볼 수 있듯이 mid의 c 값과 양 끝 값을 곱하면 이미 곱한 행렬을 쉽게 곱할 수 있다는 것을 알 수 있다.

구현

#include <iostream>
#include <vector>
#include <cstring>
using namespace std;

const int INF=987654321;
const int MAX=500;

int N;
vector<pair<int,int> > mat;
int cache[MAX][MAX];

int dp(int lo, int hi)
{
    // 기저 사례: 행렬이 하나인 경우 
    if(lo==hi)
        return 0;
    
    // 메모이제이션
    int& ret=cache[lo][hi];
    if(ret!=-1)
        return ret;
    
    ret=INF;
    // 구간 별로 나누어 모든 곱셈의 순서를 고려한다. 
    for(int mid=lo;mid<hi;++mid)
    {
        int left=dp(lo,mid);
        int right=dp(mid+1,hi);
        ret=min(ret,left+right+mat[lo].first*mat[mid].second*mat[hi].second);
    }
    return ret;
}

int main()
{
    cin>>N;
    for(int i=0;i<N;++i)
    {
        int a,b;
        cin>>a>>b;
        mat.push_back(make_pair(a,b));
    }
    memset(cache,-1,sizeof(cache));
    
    cout<<dp(0,N-1)<<endl;
    
    return 0;
}

댓글남기기