코딩테스트/그래프(MST(쿠루스칼,프림),위상정렬)

[그래프] MST(쿠루스칼, 프림) 이론

영최 2024. 9. 4. 20:00
728x90

0️⃣  MST( 최소신장트리, Minimum Spanning Tree)

정의: 무방향 그래프에서 최소의 비용으로 모든 노드를 연결하면서 사이클이 발생하지 않는 트리

사용되는 알고리즘: 크루스칼, 프림 

간적쿠간만프: 간선 적으면 크루스칼 간선이 많으면 프림 사용

1️⃣  크루스칼

 언제 사용? 간선이 을때

✅ 시간복잡도(개선된) :  O(ElogE) (100만 정도) - 간선 정렬

✅ 원리 :

1) 간선 데이터를 비용에 따라 오름차순으로 정렬 

2) 사이클이 발생하지 않는 경우 최소신장 트리에 포함시킴 (반복)

✅ 서로소 집합 부모 찾기 최적화:

parent[x] = find_parent(parent, parent[x]): 처음에는 parent[x] != x에서 parent[x] == x 까지 재귀를 돌면서
return parent[x]를 통해 업데이트한다. 그 후에는 parent[x] != x인 경우 바로 parent[x]로 넘어가므로 재귀가 없다.(이미 업데이트 된 상태여서)

# 특정 원소가 속한 집합을 찾기
def find_parent(parent, x):
    # 루트 노드가 아니라면, 루트 노드를 찾을 때까지 재귀적으로 호출
    if parent[x] != x:
        parent[x] = find_parent(parent, parent[x])
    return parent[x]

✅ 사이클 판별이 필요한 경우(변형)

if find_parent(parent, a) == find_parent(parent, b):
	cycle = True
    break

 기본 코드:

# 특정 원소가 속한 집합을 찾기
def find_parent(parent, x):
    # 루트 노드가 아니라면, 루트 노드를 찾을 때까지 재귀적으로 호출
    if parent[x] != x:
        parent[x] = find_parent(parent, parent[x])
    return parent[x]

# 두 원소가 속한 집합을 합치기
def union_parent(parent, a, b):
    a = find_parent(parent, a)
    b = find_parent(parent, b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b

# 노드의 개수와 간선(Union 연산)의 개수 입력 받기
v, e = map(int, input().split())
parent = [0] * (v + 1) # 부모 테이블 초기화하기

# 모든 간선을 담을 리스트와, 최종 비용을 담을 변수
edges = []
result = 0

# 부모 테이블상에서, 부모를 자기 자신으로 초기화
for i in range(1, v + 1):
    parent[i] = i

# 모든 간선에 대한 정보를 입력 받기
for _ in range(e):
    a, b, cost = map(int, input().split())
    # 비용순으로 정렬하기 위해서 튜플의 첫 번째 원소를 비용으로 설정
    edges.append((cost, a, b))

# 간선을 비용순으로 정렬
edges.sort()

# 간선을 하나씩 확인하며
for edge in edges:
    cost, a, b = edge
    # 사이클이 발생하지 않는 경우에만 집합에 포함
    if find_parent(parent, a) != find_parent(parent, b):
        union_parent(parent, a, b)
        result += cost

print(result)

2️⃣  프림

 언제 사용? 간선이 을때

✅ 시간복잡도(개선된) :  O(ElogV) (100만 정도) 

✅ 원리 : (다익스트라와 유사)

1) 임의의 정점을 시작으로 함
2) 해당 정점과 인접한 정점을 힙에 넣음

3) 가장 작은 비용을 가진 정점 빼서 연결 

4) 2)->3)을 전체 정점 연결을 다할때까지 반복(V-1)

 기본 코드:

from heapq

V, E = map(int, input().split())
G = [[] for _ in range(V + 1)]
for _ in range(E):
    u, v, w = map(int, input().split())
    G[u].append((v, w)) #무향 그래프
    G[v].append((u, w))
    
def prim(start, weight):
    visit = [0] * (V + 1) # 정점 방문 처리
    q = [] # 힙 구조를 사용하기 위해 가중치를 앞에 둠
    heapq.heappush(q,(weight, start))
    ans = 0 # 가중치 합
    cnt = 0 # 간선의 개수
    while cnt<V: # 간선의 개수 최대는 V-1
        w, v = heapq.heappop(q)
        if visit[v]: continue # 이미 방문한 정점이면 지나감
        visit[v] = 1 # 방문안했으면 방문처리
        ans += w # 해당 정점까지의 가중치를 더해줌
        cnt += 1 # 간선의 갯수 더해줌
        for next_v, next_w in G[v]: # 해당 정점의 간선정보를 불러옴
            heapq.heappush(q, (next_w, next_v)) # 힙에 넣어줌
    return ans



print(prim(1, 0))

 

 

728x90