정화 코딩
[C++] 물류창고 (백준 28296번) 본문

https://www.acmicpc.net/problem/28296
다음과 같은 순서로 생각하면 이해가 그나마 좀 쉬운 것 같다.
일단 두 창고를 연결하는 경로상의 도로 중 가장 작은 가중치가 그 경로의 상한선이 되고, 두 창고를 연결하는 여러 경로가 있을 때 상한선이 가장 큰 경로를 택한다.
즉, 두 창고를 연결하는 경로가 여러 개 있을 때, 가중치가 작은 도로가 포함되지 않도록 하는 것이 최적이다. 이를 통해 가중치가 작은 간선들은 필요없고 최대 스패닝 트리를 그렸을 때 거기에 포함된 도로들만 사용하면 된다는 것을 알 수 있다.
최대 스패닝 트리를 그려야 하므로 가중치가 큰 도로부터 보면 된다.
가중치가 큰 도로부터 보면 현재 보고 있는 도로가 지금까지 봐온 도로 중 가중치가 가장 작은 도로라는 것이니까, 현재 보고 있는 도로를 기준으로 양끝 창고들은 이 도로를 타야 하고 이 도로의 가중치가 상한선이 된다.
이게 무슨 말인지 그림으로 다시 보자.

따라서, 가중치가 큰 도로부터 보면서 추가를 하는데, 각 그룹 별로 각 회사에 속한 창고가 몇개씩 있는지 저장해두고 계산에 사용하면 된다. 해당 도로를 추가하므로써 추가되는 상한선은 그 도로의 가중치 * 그 도로 왼쪽의 창고 개수 * 그 도로 오른쪽의 창고 개수이고, 이걸 회사 별로 계산해주면 된다.
#include <iostream>
#include <vector>
#include <algorithm>
#include <unordered_map>
using namespace std;
int n, m, k;
vector<int> c;
vector<vector<int>> e;
vector<int> parent;
vector<unordered_map<int, int>> setsize;
vector<long long> ans;
int find(int a) {
if (parent[a] == a) return a;
else return parent[a] = find(parent[a]);
}
bool unite(int w, int a, int b) {
a = find(a);
b = find(b);
if (a == b) return false;
if (a > b) swap(a, b);
parent[b] = a;
for (int i = 1; i < k + 1; i++) { // 각 회사 별로
if (setsize[b].count(i)) {
if (setsize[a].count(i)) {
ans[i] += (long long) w * setsize[a][i] * setsize[b][i];
}
setsize[a][i] += setsize[b][i];
}
}
return true;
}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin >> n >> k >> m;
parent = vector<int>(n + 1);
setsize = vector<unordered_map<int, int>>(n + 1);
for (int i = 0; i < n + 1; i++) parent[i] = i;
c = vector<int>(n + 1);
for (int i = 1; i < n + 1; i++) {
cin >> c[i];
setsize[i][c[i]] = 1;
}
e = vector<vector<int>>(m, vector<int>(3));
for (int i = 0; i < m; i++) {
cin >> e[i][1] >> e[i][2] >> e[i][0];
}
sort(e.begin(), e.end(), greater<>());
ans = vector<long long>(k + 1, 0);
int cnt = 0;
for (int i = 0; i < m; i++) {
if (unite(e[i][0], e[i][1], e[i][2])) cnt++;
if (cnt >= n - 1) break;
}
for (int i = 1; i < k + 1; i++) {
cout << ans[i] << '\n';
}
}
처음에 이렇게 제출해서 시간초과를 받았다. (TLE)
#include <iostream>
#include <vector>
#include <algorithm>
#include <unordered_map>
using namespace std;
int n, m, k;
vector<int> c;
vector<vector<int>> e;
vector<int> parent;
vector<unordered_map<int, int>> setsize;
vector<long long> ans;
int find(int a) {
if (parent[a] == a) return a;
else return parent[a] = find(parent[a]);
}
bool unite(int w, int a, int b) {
a = find(a);
b = find(b);
if (a == b) return false;
// 작은 쪽(b)에서 큰 쪽(a)으로 합치기 (small to large)
if (setsize[a].size() < setsize[b].size()) swap(a, b);
parent[b] = a;
for (auto [cp, v]: setsize[b]) {
if (setsize[a].count(cp)) {
ans[cp] += (long long) w * setsize[a][cp] * v;
}
setsize[a][cp] += setsize[b][cp]; // a에 회사 cp에 속하는 창고가 없어도 해주어야 함
}
return true;
}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin >> n >> k >> m;
parent = vector<int>(n + 1);
setsize = vector<unordered_map<int, int>>(n + 1);
for (int i = 0; i < n + 1; i++) parent[i] = i;
c = vector<int>(n + 1);
for (int i = 1; i < n + 1; i++) {
cin >> c[i];
setsize[i][c[i]] = 1;
}
e = vector<vector<int>>(m, vector<int>(3));
for (int i = 0; i < m; i++) {
cin >> e[i][1] >> e[i][2] >> e[i][0];
}
sort(e.begin(), e.end(), greater<>());
ans = vector<long long>(k + 1, 0);
int cnt = 0;
for (int i = 0; i < m; i++) {
if (unite(e[i][0], e[i][1], e[i][2])) cnt++;
if (cnt >= n - 1) break;
}
for (int i = 1; i < k + 1; i++) {
cout << ans[i] << '\n';
}
}
이렇게 수정하여 정답을 받을 수 있었다. (AC)
수정한 부분 1. 모든 회사에 대해서 보는 것이 아니라 그룹 b에 있는 회사만 보기
// 수정 전
for (int i = 1; i < k + 1; i++) { // 각 회사 별로
if (setsize[b].count(i)) {
if (setsize[a].count(i)) {
ans[i] += (long long) w * setsize[a][i] * setsize[b][i];
}
setsize[a][i] += setsize[b][i];
}
}
// 수정 후
for (auto [cp, v]: setsize[b]) {
if (setsize[a].count(cp)) {
ans[cp] += (long long) w * setsize[a][cp] * v;
}
setsize[a][cp] += setsize[b][cp]; // a에 회사 cp에 속하는 창고가 없어도 해주어야 함
}
수정한 부분 2. 더 작은 그룹에서 더 큰 그룹으로 합치기
// 수정 전
if (a > b) swap(a, b);
parent[b] = a;
// 수정 후
if (setsize[a].size() < setsize[b].size()) swap(a, b);
parent[b] = a;
이 문제 풀면서 unordered_map에 대해서 새롭게 알게된 점!!
- m.count(key) : m에 key가 존재하면 1 반환, 아니면 0 반환
- m[key]를 사용하면 m에 key가 없어도 자동으로 m[key]를 선언하고 0으로 초기화한 후 연산을 수행한다.
'PS' 카테고리의 다른 글
[C++] 육각형 우리 속의 개미 (백준 17370번) (0) | 2025.05.29 |
---|---|
[C++] 차량 모듈 제작 (백준 28297번) (0) | 2025.05.26 |
[C++] 사이클 게임 (백준 20040번) (1) | 2025.05.19 |
[C++] 응원단 (백준 28300번) (0) | 2025.05.19 |
[C++] ACM Craft (백준 1005번) (0) | 2025.05.06 |