정화 코딩

[C++] 물류창고 (백준 28296번) 본문

PS

[C++] 물류창고 (백준 28296번)

jungh150c 2025. 5. 20. 02:26

https://www.acmicpc.net/problem/28296

 

다음과 같은 순서로 생각하면 이해가 그나마 좀 쉬운 것 같다.

 

일단 두 창고를 연결하는 경로상의 도로 중 가장 작은 가중치가 그 경로의 상한선이 되고, 두 창고를 연결하는 여러 경로가 있을 때 상한선이 가장 큰 경로를 택한다.

즉, 두 창고를 연결하는 경로가 여러 개 있을 때, 가중치가 작은 도로가 포함되지 않도록 하는 것이 최적이다. 이를 통해 가중치가 작은 간선들은 필요없고 최대 스패닝 트리를 그렸을 때 거기에 포함된 도로들만 사용하면 된다는 것을 알 수 있다. 

 

최대 스패닝 트리를 그려야 하므로 가중치가 큰 도로부터 보면 된다. 

가중치가 큰 도로부터 보면 현재 보고 있는 도로가 지금까지 봐온 도로 중 가중치가 가장 작은 도로라는 것이니까, 현재 보고 있는 도로를 기준으로 양끝 창고들은 이 도로를 타야 하고 이 도로의 가중치가 상한선이 된다.

이게 무슨 말인지 그림으로 다시 보자.

 

따라서, 가중치가 큰 도로부터 보면서 추가를 하는데, 각 그룹 별로 각 회사에 속한 창고가 몇개씩 있는지 저장해두고 계산에 사용하면 된다. 해당 도로를 추가하므로써 추가되는 상한선그 도로의 가중치 * 그 도로 왼쪽의 창고 개수 * 그 도로 오른쪽의 창고 개수이고, 이걸 회사 별로 계산해주면 된다. 

 

#include <iostream>
#include <vector>
#include <algorithm>
#include <unordered_map>
using namespace std;

int n, m, k;
vector<int> c;
vector<vector<int>> e;
vector<int> parent;
vector<unordered_map<int, int>> setsize;
vector<long long> ans;

int find(int a) {
    if (parent[a] == a) return a;
    else return parent[a] = find(parent[a]);
}

bool unite(int w, int a, int b) {
    a = find(a);
    b = find(b);
    if (a == b) return false;
    if (a > b) swap(a, b);
    parent[b] = a;
    for (int i = 1; i < k + 1; i++) { // 각 회사 별로
        if (setsize[b].count(i)) {
            if (setsize[a].count(i)) {
                ans[i] += (long long) w * setsize[a][i] * setsize[b][i];
            }
            setsize[a][i] += setsize[b][i];
        }
    }
    return true;
}

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> k >> m;

    parent = vector<int>(n + 1);
    setsize = vector<unordered_map<int, int>>(n + 1);
    for (int i = 0; i < n + 1; i++) parent[i] = i;

    c = vector<int>(n + 1);
    for (int i = 1; i < n + 1; i++) {
        cin >> c[i];
        setsize[i][c[i]] = 1;
    }

    e = vector<vector<int>>(m, vector<int>(3));
    for (int i = 0; i < m; i++) {
        cin >> e[i][1] >> e[i][2] >> e[i][0];
    }

    sort(e.begin(), e.end(), greater<>());

    ans = vector<long long>(k + 1, 0);

    int cnt = 0;
    for (int i = 0; i < m; i++) {
        if (unite(e[i][0], e[i][1], e[i][2])) cnt++;
        if (cnt >= n - 1) break;
    }

    for (int i = 1; i < k + 1; i++) {
        cout << ans[i] << '\n';
    }
}

처음에 이렇게 제출해서 시간초과를 받았다. (TLE)

#include <iostream>
#include <vector>
#include <algorithm>
#include <unordered_map>
using namespace std;

int n, m, k;
vector<int> c;
vector<vector<int>> e;
vector<int> parent;
vector<unordered_map<int, int>> setsize;
vector<long long> ans;

int find(int a) {
    if (parent[a] == a) return a;
    else return parent[a] = find(parent[a]);
}

bool unite(int w, int a, int b) {
    a = find(a);
    b = find(b);
    if (a == b) return false;

    // 작은 쪽(b)에서 큰 쪽(a)으로 합치기 (small to large)
    if (setsize[a].size() < setsize[b].size()) swap(a, b);
    parent[b] = a;

    for (auto [cp, v]: setsize[b]) {
        if (setsize[a].count(cp)) {
            ans[cp] += (long long) w * setsize[a][cp] * v;
        }
        setsize[a][cp] += setsize[b][cp]; // a에 회사 cp에 속하는 창고가 없어도 해주어야 함
    }

    return true;
}

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n >> k >> m;

    parent = vector<int>(n + 1);
    setsize = vector<unordered_map<int, int>>(n + 1);
    for (int i = 0; i < n + 1; i++) parent[i] = i;

    c = vector<int>(n + 1);
    for (int i = 1; i < n + 1; i++) {
        cin >> c[i];
        setsize[i][c[i]] = 1;
    }

    e = vector<vector<int>>(m, vector<int>(3));
    for (int i = 0; i < m; i++) {
        cin >> e[i][1] >> e[i][2] >> e[i][0];
    }

    sort(e.begin(), e.end(), greater<>());

    ans = vector<long long>(k + 1, 0);

    int cnt = 0;
    for (int i = 0; i < m; i++) {
        if (unite(e[i][0], e[i][1], e[i][2])) cnt++;
        if (cnt >= n - 1) break;
    }

    for (int i = 1; i < k + 1; i++) {
        cout << ans[i] << '\n';
    }
}

이렇게 수정하여 정답을 받을 수 있었다. (AC)

 

수정한 부분 1. 모든 회사에 대해서 보는 것이 아니라 그룹 b에 있는 회사만 보기

// 수정 전
for (int i = 1; i < k + 1; i++) { // 각 회사 별로
        if (setsize[b].count(i)) {
            if (setsize[a].count(i)) {
                ans[i] += (long long) w * setsize[a][i] * setsize[b][i];
            }
            setsize[a][i] += setsize[b][i];
        }
    }

// 수정 후
    for (auto [cp, v]: setsize[b]) {
        if (setsize[a].count(cp)) {
            ans[cp] += (long long) w * setsize[a][cp] * v;
        }
        setsize[a][cp] += setsize[b][cp]; // a에 회사 cp에 속하는 창고가 없어도 해주어야 함
    }

수정한 부분 2. 더 작은 그룹에서 더 큰 그룹으로 합치기

// 수정 전
    if (a > b) swap(a, b);
    parent[b] = a;

// 수정 후
    if (setsize[a].size() < setsize[b].size()) swap(a, b);
    parent[b] = a;

 

이 문제 풀면서 unordered_map에 대해서 새롭게 알게된 점!!

- m.count(key) : m에 key가 존재하면 1 반환, 아니면 0 반환

- m[key]를 사용하면 m에 key가 없어도 자동으로 m[key]를 선언하고 0으로 초기화한 후 연산을 수행한다.

 

Comments