반응형

설명

예시

  • 가중치를 갖는 간선으로 노드들을 연결하여 전체 그래프를 형성하는 것이 목표
  • 가중치 기준으로 간선을 오름차순 정렬
  • 각 간선에 연결된 두 노드가 같은 그룹인지 확인(Union-Find 알고리즘 활용)
  • 같은 그룹이 아니면 해당 간선을 유지시키고, 두 노드를 같은 그룹으로 포함
  • 같은 그룹이면 해당 간선을 버림

코드

public class Solution {
    public int solution(int[][] data) {
        PriorityQueue<int[]> minWeightQueue = new PriorityQueue<>(Comparator.comparing(values -> values[2]));
        minWeightQueue.addAll(Arrays.asList(data));

        int maxValue = Arrays.stream(data).map(values -> Math.max(values[0], values[1])).max(Comparator.comparing(value -> value)).get();
        int distance = 0;
        int[] root = new int[maxValue + 1];

        for (int i = 0; i <= maxValue; i++) {
            root[i] = i;
        }

        while (!minWeightQueue.isEmpty()) {
            int[] values = minWeightQueue.remove();

            if (!isSameUnion(root, values[0], values[1])) {
                union(root, values[0], values[1]);
                distance += values[2];
            }
        }

        return distance;
    }

    private void union(int[] root, int value1, int value2) {
        int x = find(root, value1);
        int y = find(root, value2);

        if (x < y) {
            root[y] = x;
        } else {
            root[x] = y;
        }
    }

    private boolean isSameUnion(int[] root, int value1, int value2) {
        return find(root, value1) == find(root, value2);
    }

    private int find(int[] root, int value) {
        if (root[value] == value) {
            return value;
        }

        return root[value] = find(root, root[value]);
    }
}
class SolutionTest extends Specification {
    @Unroll
    def "#data -> #result"() {
        expect:
        new Solution().solution(data as int[][]) == result as int

        where:
        data                                                                                                                                 | result
        [[1, 7, 12], [1, 4, 28], [1, 2, 67], [1, 5, 17], [2, 4, 24], [2, 5, 62], [3, 5, 20], [3, 6, 37], [4, 7, 13], [5, 6, 45], [5, 7, 73]] | 123
    }
}

Node 객체를 활용한 코드

public class Node<T extends Comparable<T>> {
    public T value;
    public Node<T> root;

    public Node(T value) {
        this.value = value;
        this.root = this;
    }
}
public class Edge<T extends Comparable<T>> implements Comparable<Edge<T>> {
    public Node<T> node1;
    public Node<T> node2;
    public int distance;

    public Edge(Node<T> node1, Node<T> node2, int distance) {
        this.node1 = node1;
        this.node2 = node2;
        this.distance = distance;
    }

    @Override
    public int compareTo(Edge<T> other) {
        return Integer.compare(this.distance, other.distance);
    }
}
public class Graph<T extends Comparable<T>> {
    private Map<T, Node<T>> nodeMap = new HashMap<>();
    private PriorityQueue<Edge<T>> edgeQueue = new PriorityQueue<>();

    public void join(T value1, T value2, int distance) {
        Node<T> node1 = nodeMap.getOrDefault(value1, new Node<T>(value1));
        Node<T> node2 = nodeMap.getOrDefault(value2, new Node<T>(value2));

        nodeMap.put(value1, node1);
        nodeMap.put(value2, node2);

        edgeQueue.add(new Edge<T>(node1, node2, distance));
    }

    public int getDistance() {
        int distance = 0;

        while (!edgeQueue.isEmpty()) {
            Edge<T> edge = edgeQueue.remove();

            if (!isSameUnion(edge.node1, edge.node2)) {
                union(edge.node1, edge.node2);
                distance += edge.distance;
            }
        }

        return distance;
    }

    private boolean isSameUnion(Node<T> node1, Node<T> node2) {
        return getRoot(node1) == getRoot(node2);
    }

    private void union(Node<T> node1, Node<T> node2) {
        Node<T> root1 = getRoot(node1);
        Node<T> root2 = getRoot(node2);

        // 두 노드의 부모 중 작은 값을 갖는 부모로 들어감
        if (root1.value.compareTo(root2.value) > 0) {
            root1.root = root2.root;
        } else {
            root2.root = root1.root;
        }
    }

    private Node<T> getRoot(Node<T> node) {
        if (node.root == node) {
            return node.root;
        }

        return node.root = getRoot(node.root);
    }
}
class SolutionTest extends Specification {
    @Unroll
    def "#graphData -> #result"() {
        expect:
        Graph<Integer> graph = new Graph<>()
        graphData.each { values -> graph.join(values.get(0), values.get(1), values.get(2)) }
        graph.getDistance() == result

        where:
        graphData                                                                                                                            | result
        [[1, 7, 12], [1, 4, 28], [1, 2, 67], [1, 5, 17], [2, 4, 24], [2, 5, 62], [3, 5, 20], [3, 6, 37], [4, 7, 13], [5, 6, 45], [5, 7, 73]] | 123
    }
}
반응형

+ Recent posts