Skip to main content

priv/zog/src/mst.zig

const std = @import("std");

/// A weighted edge as collected by `kruskal` before sorting.
const Edge = struct {
    /// Id of the node whose successor list the edge was read from.
    from: u32,
    /// Id of the successor node the edge points to.
    to: u32,
    /// Edge weight; `kruskal` copies it from the graph edge's `data` field.
    weight: f64,
};

/// "Less than" predicate handed to `std.mem.sort`: reports whether `a` is
/// strictly lighter than `b`. The `void` context argument carries no state.
fn compareEdges(_: void, a: Edge, b: Edge) bool {
    // Written from b's side; equivalent to `a.weight < b.weight`, including
    // for NaN weights, where both forms evaluate to false.
    return b.weight > a.weight;
}

/// Disjoint-set forest over the elements `0..size`, using union by rank and
/// path compression.
pub const UnionFind = struct {
    /// `parent[x]` is the parent of `x`; a root is its own parent.
    parent: []u32,
    /// Rank of each element; only the values stored at roots are consulted.
    rank: []u32,

    /// Allocates a forest of `size` singleton sets.
    /// Release it with `deinit`, passing the same allocator.
    /// `size` must fit in `u32`, because element ids are stored as `u32`.
    pub fn init(allocator: std.mem.Allocator, size: usize) !UnionFind {
        const parent = try allocator.alloc(u32, size);
        // If the second allocation fails, do not leak the first.
        errdefer allocator.free(parent);
        const rank = try allocator.alloc(u32, size);

        // Every element starts out as the root of its own set, with rank 0.
        for (parent, 0..) |*slot, idx| {
            slot.* = @intCast(idx);
        }
        @memset(rank, 0);

        return .{ .parent = parent, .rank = rank };
    }

    /// Frees both backing slices. `allocator` must be the one given to `init`.
    pub fn deinit(self: *UnionFind, allocator: std.mem.Allocator) void {
        allocator.free(self.parent);
        allocator.free(self.rank);
    }

    /// Returns the representative (root) of the set containing `i` and
    /// re-points every node visited on the way directly at that root.
    pub fn find(self: *UnionFind, i: u32) u32 {
        // Pass 1: climb parent links until reaching a self-parented node.
        var root = i;
        while (true) {
            const up = self.parent[root];
            if (up == root) break;
            root = up;
        }

        // Pass 2: walk the same path again, attaching each node to the root.
        var node = i;
        while (node != root) {
            const up = self.parent[node];
            self.parent[node] = root;
            node = up;
        }
        return root;
    }

    /// Merges the sets containing `i` and `j`.
    /// Returns `false` when they already share a set, `true` when a merge
    /// took place.
    pub fn unionSets(self: *UnionFind, i: u32, j: u32) bool {
        const a = self.find(i);
        const b = self.find(j);
        if (a == b) return false;

        // Hang the lower-ranked root under the higher-ranked one; on a tie,
        // `a` becomes the root and its rank grows by one.
        switch (std.math.order(self.rank[a], self.rank[b])) {
            .lt => self.parent[a] = b,
            .gt => self.parent[b] = a,
            .eq => {
                self.parent[b] = a;
                self.rank[a] += 1;
            },
        }
        return true;
    }
};

/// Edges selected by `kruskal`, stored as three parallel slices: index `k`
/// of `from`, `to` and `weight` together describes one edge.
/// The test in this file releases each slice with the allocator it passed to
/// `kruskal`.
pub const MstResult = struct {
    /// `from` node id of each selected edge.
    from: []u32,
    /// `to` node id of each selected edge.
    to: []u32,
    /// Weight of each selected edge.
    weight: []f64,
};

/// Computes the Minimum Spanning Tree (MST) using Kruskal's algorithm natively.
///
/// `graph` is duck-typed and must provide:
///   - `nodeCount()` and `edgeCount()`,
///   - `nodeIds()`, returning an iterator whose `next()` yields node ids,
///   - `successors(id)`, returning an iterator whose `next()` yields edges
///     with a `to` node id and a `data` weight.
///
/// Only edges with `from < to` are considered, so self-loops are ignored and
/// an undirected edge is picked up once when it is stored in both directions.
/// An edge stored only from the higher id to the lower id is skipped.
/// Node ids are used directly as union-find indices, so they are assumed to
/// lie in `0..nodeCount()` (NOTE(review): confirm against the graph model).
///
/// Returns the chosen edges as parallel slices. When the graph is not
/// connected the result is a spanning forest with fewer than `V - 1` edges.
/// The caller owns the returned slices and frees each with `allocator`; the
/// empty result consists of zero-length slices.
///
/// Errors: propagates allocation failure. No memory stays allocated on any
/// error path.
pub fn kruskal(allocator: std.mem.Allocator, graph: anytype) !MstResult {
    const V = graph.nodeCount();
    const E = graph.edgeCount();

    // Nothing to connect: return an empty result without allocating.
    if (V == 0 or E == 0) {
        return .{
            .from = &[_]u32{},
            .to = &[_]u32{},
            .weight = &[_]f64{},
        };
    }

    // Scratch list of candidate edges; always released before returning.
    var edges_list = std.ArrayList(Edge).empty;
    defer edges_list.deinit(allocator);

    var node_it = graph.nodeIds();
    while (node_it.next()) |u| {
        var succ_it = graph.successors(u);
        while (succ_it.next()) |edge| {
            const v = edge.to;
            // Keep one orientation per node pair; this also drops self-loops.
            if (u < v) {
                try edges_list.append(allocator, .{
                    .from = u,
                    .to = v,
                    .weight = edge.data,
                });
            }
        }
    }

    // Kruskal processes edges from lightest to heaviest.
    std.mem.sort(Edge, edges_list.items, {}, compareEdges);

    var uf = try UnionFind.init(allocator, V);
    defer uf.deinit(allocator);

    var mst_from = std.ArrayList(u32).empty;
    errdefer mst_from.deinit(allocator);
    var mst_to = std.ArrayList(u32).empty;
    errdefer mst_to.deinit(allocator);
    var mst_weight = std.ArrayList(f64).empty;
    errdefer mst_weight.deinit(allocator);

    for (edges_list.items) |edge| {
        // An edge is accepted only when it joins two different components.
        if (uf.unionSets(edge.from, edge.to)) {
            try mst_from.append(allocator, edge.from);
            try mst_to.append(allocator, edge.to);
            try mst_weight.append(allocator, edge.weight);
            // A spanning tree over V nodes has exactly V - 1 edges.
            if (mst_from.items.len == V - 1) break;
        }
    }

    // Detach the slices one at a time. `toOwnedSlice` empties its list, so
    // the list errdefers above no longer cover a slice once it is detached;
    // each detached slice needs its own guard in case a later call fails.
    const from_slice = try mst_from.toOwnedSlice(allocator);
    errdefer allocator.free(from_slice);
    const to_slice = try mst_to.toOwnedSlice(allocator);
    errdefer allocator.free(to_slice);
    const weight_slice = try mst_weight.toOwnedSlice(allocator);

    return .{
        .from = from_slice,
        .to = to_slice,
        .weight = weight_slice,
    };
}

test "kruskal: simple MST" {
    const gpa = std.testing.allocator;
    const AG = @import("models/array_graph.zig").ArrayGraph;

    var graph = AG(void, f64).init(gpa);
    defer graph.deinit();

    // Three nodes without payload.
    _ = try graph.addNode({});
    _ = try graph.addNode({});
    _ = try graph.addNode({});

    // A triangle; every undirected edge is inserted in both directions.
    _ = try graph.addEdge(0, 1, 10.0);
    _ = try graph.addEdge(1, 0, 10.0);
    _ = try graph.addEdge(1, 2, 5.0);
    _ = try graph.addEdge(2, 1, 5.0);
    _ = try graph.addEdge(0, 2, 20.0);
    _ = try graph.addEdge(2, 0, 20.0);

    const mst = try kruskal(gpa, graph);
    defer {
        gpa.free(mst.from);
        gpa.free(mst.to);
        gpa.free(mst.weight);
    }

    // A spanning tree over three nodes has two edges.
    try std.testing.expectEqual(@as(usize, 2), mst.from.len);

    // The two lightest edges, 1-2 (5) and 0-1 (10), form the tree.
    var weight_sum: f64 = 0.0;
    for (mst.weight) |w| weight_sum += w;
    try std.testing.expectEqual(@as(f64, 15.0), weight_sum);
}