문제
풀이
문제 이름에도 나와 있듯 MST(최소 신장 트리)를 사용하는 문제이다.
문제에서 제시한 MST의 조건은,
- 스패닝 트리를 구성하는 간선의 개수는 N-1개이다.
- 그래프의 임의의 두 정점을 골랐을 때, 스패닝 트리에 속하는 간선만 이용해서 두 정점을 연결하는 경로를 구성할 수 있어야 한다.
이다.
주어진 그래프에서 MST 값을 구하는 문제인데, 주의할 점은 턴마다 최소 간선의 값을 제거하면서 MST의 값을 출력하는 문제이다.
Disjoint Set
MST를 구하는 알고리즘은 대표적으로 크루스칼 알고리즘이 있다.
크루스칼 알고리즘을 구현하기 위해선 Disjoint Set이라는 자료구조가 필요하다. Disjoint Set, 서로소 집합 자료구조로 불리는 이 자료구조는 요소가 같은 집합인지 확인하는 역할을 한다.
find()는 요소가 어떤 집합에 속해있는지 반환하는 메소드이고, union()은 요소를 해당 집합으로 합치는 역할을 한다.
static int find(int n) {
if (nodes[n] != n) {
nodes[n] = find(nodes[n]);
}
return nodes[n];
}
static void union(int a, int b) {
a = find(a);
b = find(b);
if (a > b) nodes[a] = b;
else nodes[b] = a;
}
크루스칼 알고리즘
크루스칼 알고리즘으로 MST를 구현한다.
그리디 알고리즘 중 하나로 그래프에서 최소 비용을 가진 트리를 찾는 알고리즘이다. 서로소 집합 자료구조를 활용해 구현하는데, 구현하는 과정은 다음과 같다.
- 간선을 기준으로 오름차순 정렬한다.
- 정렬된 기준으로 서로소 집합 자료구조를 활용하여 사이클 발생 여부를 확인한다.
- find()를 사용해 사이클 발생 여부를 확인할 수 있다. 이미 같은 집합이라면 해당 노드는 사이클이 발생하므로 집합에 포함시키면 안된다.
- 사이클이 발생하지 않는다면 포함시킨 노드의 비용을 누적한다.
- 간선의 수만큼 이 과정을 반복한다.
while (!edge.isEmpty()) {
Node node = edge.poll();
if (find(node.a) != find(node.b)) {
cnt++;
sum += node.c;
union(node.a, node.b);
}
}
위 코드는 서로소 집합 알고리즘을 사용해 같은 집합인지 확인하고, 같은 집합이 아니라면 비용을 누적하고 같은 집합으로 포함시킨다.
코드
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
public class Main {
static int[] nodes;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int N = Integer.parseInt(st.nextToken());
int M = Integer.parseInt(st.nextToken());
int K = Integer.parseInt(st.nextToken());
PriorityQueue<Node> pq = new PriorityQueue<>();
for (int i = 1; i <= M; i++) {
st = new StringTokenizer(br.readLine());
int a = Integer.parseInt(st.nextToken());
int b = Integer.parseInt(st.nextToken());
pq.offer(new Node(a, b, i));
}
int idx = 0;
for (int i = 0; i < K; i++) {
nodes = new int[N+1];
for (int j = 1; j <= N; j++) {
nodes[j] = j;
}
int cnt = 0, sum = 0;
PriorityQueue<Node> edge = new PriorityQueue<>(pq);
for (int j = 0; j < idx; j++) {
edge.poll();
}
while (!edge.isEmpty()) {
Node node = edge.poll();
if (find(node.a) != find(node.b)) {
cnt++;
sum += node.c;
union(node.a, node.b);
}
}
idx++;
if (cnt != N-1) System.out.print(0 + " ");
else System.out.print(sum + " ");
}
}
static int find(int n) {
if (nodes[n] != n) {
nodes[n] = find(nodes[n]);
}
return nodes[n];
}
static void union(int a, int b) {
a = find(a);
b = find(b);
if (a > b) nodes[a] = b;
else nodes[b] = a;
}
static class Node implements Comparable<Node> {
int a;
int b;
int c;
public Node(int a, int b, int c) {
this.a = a;
this.b = b;
this.c = c;
}
@Override
public int compareTo(Node o) {
return this.c - o.c;
}
}
}