정화 코딩

EDOC 2023-W 4회차 과제 (플로이드-워셜 / 최소 신장 트리) 본문

Group/EDOC

EDOC 2023-W 4회차 과제 (플로이드-워셜 / 최소 신장 트리)

jungh150c 2024. 2. 1. 03:18

08-6. 플로이드-워셜

 

061. 플로이드 (백준 11404번)

 

https://www.acmicpc.net/problem/11404

 

from sys import stdin
import sys

n = int(stdin.readline())
m = int(stdin.readline())
g = [[sys.maxsize for _ in range(n+1)] for _ in range(n+1)]

for i in range(n+1):
    g[i][i] = 0

for _ in range(m):
    a, b, c = map(int, stdin.readline().split())
    if c < g[a][b]:
        g[a][b] = c

for i in range(1, n+1): # 경유지 i에 관해
    for j in range(1, n+1): # 출발 노드 j 에 관해
        for k in range(1, n+1): # 도착 노드 k에 관해
            if g[j][i] + g[i][k] < g[j][k]:
                g[j][k] = g[j][i] + g[i][k]

for i in range(1, n+1):
    for j in range(1, n+1):
        if g[i][j] == sys.maxsize:
            print(0, end=' ')
        else:
            print(g[i][j], end=' ')
    print()

처음에는 시작 도시와 도착 도시를 연결하는 노선은 하나가 아닐 수 있다는 조건을 신경쓰지 않은채 코드를 짜서 if c < g[a][b]: 이 부분 없이 돌려봤다. 그런데 예제도 틀리길래 생각 후 이 부분을 추가하여 제출했다. (정답)

 


 

062. 경로 찾기 (백준 11403번)

 

https://www.acmicpc.net/problem/11403

 

from sys import stdin

n = int(stdin.readline())
g = []

for _ in range(n):
    g.append(list(map(int, stdin.readline().split())))

for i in range(n): # 경유지 i에 관해
    for j in range(n): # 출발 노드 j 에 관해
        for k in range(n): # 도착 노드 k에 관해
            if g[j][i] == 1 and g[i][k] == 1:
                g[j][k] = 1

for i in range(n):
    for j in range(n):
        print(g[i][j], end=' ')
    print()

플로이드-워셜 알고리즘을 살짝만 변형시키면 되는 문제였다. (정답)

 


 

063. 케빈 베이컨의 6단계 법칙 (백준 1389번)

 

https://www.acmicpc.net/problem/1389

 

from sys import stdin
import sys

n, m = map(int, stdin.readline().split())
g = [[sys.maxsize for _ in range(n+1)] for _ in range(n+1)]

for i in range(1, n+1):
    g[i][i] = 0

for _ in range(m):
    a, b = map(int, stdin.readline().split())
    g[a][b] = 1
    g[b][a] = 1

for i in range(1, n+1):
    for j in range(1, n+1):
        for k in range(1, n+1):
            if g[j][i] + g[i][k] < g[j][k]:
                g[j][k] = g[j][i] + g[i][k]

ans = 1
min = sum(g[1])
for i in range(2, n+1):
    if sum(g[i]) < min:
        min = sum(g[i])
        ans = i

print(ans)

이 문제는 63번에서 사용한 플로이드-워셜 알고리즘의 기본 형태를 그대로 사용하면 되는 문제였다. (정답)

 


08-7. 최소 신장 트리

 

064. 최소 스패닝 트리 (백준 1197번)

 

https://www.acmicpc.net/problem/1197

 

from sys import stdin

v, e = map(int, stdin.readline().split())
edge = []
parent = [i for i in range(v+1)]

def find(a):
    if parent[a] == a:
        return a
    else:
        parent[a] = find(parent[a])
        return parent[a]

def union(a, b):
    a = find(a)
    b = find(b)
    if a != b:
        parent[b] = a

for _ in range(e):
    a, b, c = map(int, stdin.readline().split())
    edge.append((a, b, c))

edge.sort(key = lambda x : x[2]) # 세번째 값을 기준으로 정렬

usedEdge = 0
res = 0

for x in edge:
    if find(x[0]) != find(x[1]):
        union(x[0], x[1])
        usedEdge += 1
        res += x[2]
    if usedEdge >= v-1:
        break

print(res)

교재에서는 에지 배열을 우선순위 큐로 만들어 정렬을 하고 거기서 하나씩 빼가면서 MST 알고리즘을 진행하였다. 나는 살짝 다르게 에지 배열을 리스트로 만들어 세번째 값인 가중치를 기준으로 정렬을 하고 원소를 하나씩 보면서 MST 알고리즘을 진행하는 방식으로 코드를 짰다. (정답)

 


 

065. 다리 만들기 2 (백준 17472번)

 

https://www.acmicpc.net/problem/17472

 

from sys import stdin
from collections import deque

dr = [0, 1, 0, -1]
dc = [1, 0, -1, 0]
    # 상 우 하 좌

n, m = map(int, stdin.readline().split())
data = [] # 지도
visited = [[False for _ in range(m)] for _ in range(n)]
edge = []
ans = 0
iNum = 1 # 섬 번호 넘버링

for _ in range(n):
    data.append(list(map(int, stdin.readline().split())))

def bfs(i, j):
    que = deque()
    que.append([i, j])
    visited[i][j] = True
    data[i][j] = iNum
    while que:
        r, c = que.popleft()
        for i in range(4):
            nr = r + dr[i]
            nc = c + dc[i]
            if nr >= 0 and nr < n and nc >= 0 and nc < m:
                if not visited[nr][nc] and data[nr][nc] != 0:
                    visited[nr][nc] = True
                    data[nr][nc] = iNum
                    que.append([nr, nc])

def find(a):
    if parent[a] == a:
        return a
    else:
        parent[a] = find(parent[a])
        return parent[a]

def union(a, b):
    a = find(a)
    b = find(b)
    if a != b:
        parent[b] = a
        return True
    else:
        return False

for i in range(n):
    for j in range(m):
        if not visited[i][j] and data[i][j] != 0:
            bfs(i, j)
            iNum += 1

parent = [i for i in range(iNum)]

for i in range(n):
    pre = -1 # 전 column 값
    for j in range(m):
        if data[i][j] != 0:
            if pre != -1 and data[i][pre] != data[i][j] and (j - pre - 1) >=2:
                edge.append((data[i][pre], data[i][j], (j - pre - 1)))
            pre = j

for j in range(m):
    pre = -1 # 전 row 값
    for i in range(n):
        if data[i][j] != 0:
            if pre != -1 and data[pre][j] != data[i][j] and (i - pre - 1) >=2:
                edge.append((data[pre][j], data[i][j], (i - pre - 1)))
            pre = i

edge.sort(key = lambda x : x[2])

for x in edge:
    if union(x[0], x[1]):
        ans += x[2]

isPossible = True
tmp = find(1)
for i in range(2, iNum):
    if find(i) != tmp:
        isPossible = False
        break

if isPossible:
    print(ans)
else:
    print(-1)

엄청 이것저것 할 게 많고 과정이 많아서 힘들었던 문제. 우선 섬을 그룹화해야하는데 이중 for문을 돌면서 0이 아닌 점을 만나면 bfs를 수행해서 같은 섬끼리 같은 번호를 갖도록 하였다. 그 다음에 이중 for문을 두번 돌면서 각각 가로 도로와 세로 도로를 체크했다. 첫번째로 섬을 만나면 pre값만 바꿔주고 두번째 이후로 섬을 다시 만나면 pre에 저장해뒀던 전 섬까지의 거리를 구해서 거리가 2 이상이면 에지 배열에 넣도록 하였다. 마지막으로 프림 알고리즘을 이용해서 최소 신장 트리를 구했다. 처음에는 모든 섬이 연결하는 것이 불가능하면 -1을 출력해야 한다는 조건을 고려하지 못했다. 그 후 isPossible 변수를 이용해 체크하고 출력하도록 했다. (정답)

 

여기서 팁 한 가지! 

if find(x[0]) != find(x[1]):
	union(x[0], x[1])
	ans += x[2]

위의 세 줄을

if union(x[0], x[1]):
	ans += x[2]

이렇게 두 줄로 압축시키려면

def union(a, b):
    a = find(a)
    b = find(b)
    if a != b:
        parent[b] = a
        return True ###
    else:
        return False ###

이렇게 union 함수가 불리언 값을 리턴하도록 하면 된다.

 


 

066. 불우이웃돕기 (백준 1414번)

 

https://www.acmicpc.net/problem/1414

 

from sys import stdin

n = int(stdin.readline())
edge = []
parent = [i for i in range(n+1)]
sum = 0

def find(a):
    if parent[a] == a:
        return a
    else:
        parent[a] = find(parent[a])
        return parent[a]

def union(a, b):
    a = find(a)
    b = find(b)
    if a != b:
        parent[b] = a
        return True
    else:
        return False

for i in range(n):
    input = stdin.readline()
    for j in range(n):
        num = ord(input[j])
        if num != 48: # 0이 아닐 때만 노드 추가
            if num >= 97: # 알파벳 소문자일 경우
                w = num - 96
                sum += w
            else: # 알파벳이 대문자일 경우
                w = num - 38
                sum += w
            edge.append([i+1, j+1, w])

edge.sort(key = lambda x : x[2])

for x in edge:
    if find(x[0]) != find(x[1]):
        union(x[0], x[1])
        sum -= x[2]

isPossible = True
tmp = find(1)
for i in range(2, n+1):
    if find(i) != tmp:
        isPossible = False
        break

if isPossible:
    print(sum)
else:
    print(-1)

비교적 간단했던 문제. 전형적인 최소 스패닝 트리 문제였다. (정답)

 

Comments