행렬의 곱셈

문제 링크

행렬의 곱셈

분석

행렬의 곱셈은 다음과 같습니다.

\[\begin{bmatrix} a & b \newline c & d \end{bmatrix} \times \begin{bmatrix} x \newline y \end{bmatrix} = \begin{bmatrix} (a \cdot x) + (b \cdot y) \newline (c \cdot x) + (d \cdot y) \end{bmatrix}\]

입출력 예의 첫번째 행렬의 곱셈을 계산해보면 다음과 같습니다.

첫 번째 행
$C_{11} = (1 \times 3) + (4 \times 3) = 3 + 12 = 15$
$C_{12} = (1 \times 3) + (4 \times 3) = 3 + 12 = 15$

두 번째 행
$C_{21} = (3 \times 3) + (2 \times 3) = 9 + 6 = 15$
$C_{22} = (3 \times 3) + (2 \times 3) = 9 + 6 = 15$

세 번째 행
$C_{31} = (4 \times 3) + (1 \times 3) = 12 + 3 = 15$
$C_{32} = (4 \times 3) + (1 \times 3) = 12 + 3 = 15$

입출력 예의 두번째 행렬의 곱셈을 계산해보면 다음과 같습니다.

첫 번째 행
$C_{11} = (2 \times 5) + (3 \times 2) + (2 \times 3) = 10 + 6 + 6 = 22$
$C_{12} = (2 \times 4) + (3 \times 4) + (2 \times 1) = 8 + 12 + 2 = 22$
$C_{13} = (2 \times 3) + (3 \times 1) + (2 \times 1) = 6 + 3 + 2 = 11$

두 번째 행
$C_{21} = (4 \times 5) + (2 \times 2) + (4 \times 3) = 20 + 4 + 12 = 36$ $C_{22} = (4 \times 4) + (2 \times 4) + (4 \times 1) = 16 + 8 + 4 = 28$
$C_{23} = (4 \times 3) + (2 \times 1) + (4 \times 1) = 12 + 2 + 4 = 18$

세 번째 행
$C_{31} = (3 \times 5) + (1 \times 2) + (4 \times 3) = 15 + 2 + 12 = 29$
$C_{32} = (3 \times 4) + (1 \times 4) + (4 \times 1) = 12 + 4 + 4 = 20$
$C_{33} = (3 \times 3) + (1 \times 1) + (4 \times 1) = 9 + 1 + 4 = 14$

풀이

#include <vector>

using namespace std;

vector<vector<int>> solution(vector<vector<int>> arr1, vector<vector<int>> arr2) {
    vector<vector<int>> answer(arr1.size(), vector<int>(arr2[0].size(), 0));
    
    // arr1 행렬의 행 순회
    for (int i = 0; i < arr1.size(); ++i)
    {
        // arr2 행렬의 열 순회
        for (int j = 0; j < arr2[0].size(); ++j)
        {
            // arr1 행렬의 열 순회
            for(int k = 0; k < arr1[0].size(); ++k)
            {
                answer[i][j] += arr1[i][k] * arr2[k][j];
            }
        }
    }
    
    return answer;
}

성능 요약

시간 복잡도는 $O(n \times p \times m)$입니다.

  • arr1의 행 순회 $O(n)$
  • arr2의 열 순회 $O(p)$
  • arr1의 열 순회 $O(m)$
  • $O(n \times p \times m)$

공간 복잡도는 $O(n \times m)$입니다.

  • 결과 값을 저장하는 2차원 배열 answer $O(n \times p)$

테스트 1 〉 통과 (0.13ms, 4.14MB)
테스트 2 〉 통과 (1.59ms, 5.36MB)
테스트 3 〉 통과 (3.28ms, 5.73MB)
테스트 4 〉 통과 (0.07ms, 4.17MB)
테스트 5 〉 통과 (1.44ms, 5.33MB)
테스트 6 〉 통과 (1.61ms, 5.02MB)
테스트 7 〉 통과 (0.11ms, 4.26MB)
테스트 8 〉 통과 (0.10ms, 4.15MB)
테스트 9 〉 통과 (0.09ms, 4.14MB)
테스트 10 〉 통과 (1.07ms, 4.84MB)
테스트 11 〉 통과 (0.39ms, 4.15MB)
테스트 12 〉 통과 (0.26ms, 4.17MB)
테스트 13 〉 통과 (1.26ms, 5MB)
테스트 14 〉 통과 (1.60ms, 5.04MB)
테스트 15 〉 통과 (0.63ms, 4.44MB)
테스트 16 〉 통과 (1.29ms, 4.64MB)

Programmers 카테고리 내 다른 글 보러가기

댓글남기기