lib/graph.ex

defmodule Adventurous.Graph do
  @moduledoc """
  Cheapest-path search over an implicit graph.

  The graph is never materialised: the caller supplies a function listing the
  neighbours of a vertex and a function giving the cost of moving onto a vertex.
  """

  defmodule NeighbourCost do
    @moduledoc """
    A vertex paired with the accumulated cost of reaching it. This is the element
    type stored in the open set.
    """
    @type t :: %__MODULE__{vertex: term(), cost: number()}
    defstruct vertex: nil, cost: 0
  end

  # Orders open-set entries by accumulated cost only; ties are left to the queue.
  defp compare_neighbours(%NeighbourCost{cost: c1}, %NeighbourCost{cost: c2}) do
    cond do
      c1 < c2 -> :lt
      c1 > c2 -> :gt
      true -> :eq
    end
  end

  @doc """
  Finds the cheapest path from `start` to `goal`.

  * `neighbour_func` - called with a vertex; returns an enumerable of its neighbours.
  * `vertex_cost` - called with a neighbour; returns the cost of moving onto it.
    The start vertex itself costs `0`. Costs are expected to be non-negative:
    the search stops as soon as `goal` is popped from the open set, which only
    guarantees the cheapest path under that assumption.

  No heuristic is applied, so vertices are expanded purely in order of
  accumulated cost (i.e. this behaves like Dijkstra's algorithm).

  Returns `{cost, path}`, where `path` is a list of `{cost_so_far, vertex}`
  tuples running from `goal` back to `start` (the last element is
  `{0, start}`). Returns `:error` when every reachable vertex has been
  explored without reaching `goal`.
  """
  @spec a_star(v, v, (v -> Enumerable.t()), (v -> number())) ::
          {number(), [{number(), v}]} | :error
        when v: term()
  def a_star(start, goal, neighbour_func, vertex_cost) do
    open_set = Prioqueue.new([%NeighbourCost{vertex: start}], cmp_fun: &compare_neighbours/2)
    ctx = %{goal: goal, neighbour_func: neighbour_func, vertex_cost: vertex_cost}

    case search(open_set, %{}, %{start => 0}, ctx) do
      {:ok, came_from, g_score} ->
        {Map.fetch!(g_score, goal), reconstruct_path(came_from, goal)}

      :error ->
        :error
    end
  end

  # Pops the cheapest open entry until `goal` is popped ({:ok, came_from, g_score})
  # or the open set runs dry (:error).
  defp search(open_set, came_from, g_score, %{goal: goal} = ctx) do
    case Prioqueue.extract_min(open_set) do
      {:error, _} ->
        :error

      {:ok, {%NeighbourCost{vertex: ^goal}, _rest}} ->
        {:ok, came_from, g_score}

      {:ok, {%NeighbourCost{vertex: vertex, cost: cost}, rest}} ->
        if cost > Map.fetch!(g_score, vertex) do
          # Stale entry: a cheaper route to `vertex` was queued after this one
          # (lazy deletion instead of rebuilding the queue on every improvement).
          search(rest, came_from, g_score, ctx)
        else
          neighbour_costs =
            get_neighbours_costs(vertex, g_score, ctx.neighbour_func, ctx.vertex_cost)

          {open_set, came_from, g_score} =
            process_neighbours(vertex, neighbour_costs, rest, came_from, g_score)

          search(open_set, came_from, g_score, ctx)
        end
    end
  end

  # Relaxes each neighbour: when it is reached more cheaply than before, queue it
  # and record its new g-score and predecessor.
  defp process_neighbours(current_vertex, neighbour_costs, open_set, came_from, g_score) do
    Enum.reduce(neighbour_costs, {open_set, came_from, g_score}, fn
      %NeighbourCost{vertex: vertex, cost: cost} = neighbour_cost,
      {open_set, came_from, g_score} = acc ->
        if improves?(g_score, vertex, cost) do
          {
            Prioqueue.insert(open_set, neighbour_cost),
            Map.put(came_from, vertex, {cost, current_vertex}),
            Map.put(g_score, vertex, cost)
          }
        else
          acc
        end
    end)
  end

  # True when `vertex` has no known cost yet, or `cost` beats the known one.
  defp improves?(g_score, vertex, cost) do
    case Map.fetch(g_score, vertex) do
      {:ok, known} -> cost < known
      :error -> true
    end
  end

  # Builds a NeighbourCost for every neighbour of `current_vertex`, costed as
  # the current vertex's g-score plus the cost of moving onto the neighbour.
  defp get_neighbours_costs(current_vertex, g_score, neighbour_func, cost_func) do
    # Every vertex popped from the open set has a g-score, so this cannot miss.
    base_cost = Map.fetch!(g_score, current_vertex)

    current_vertex
    |> neighbour_func.()
    |> Enum.map(fn neighbour ->
      %NeighbourCost{vertex: neighbour, cost: base_cost + cost_func.(neighbour)}
    end)
  end

  # Walks predecessor links from `current` back to the start vertex (the one
  # vertex with no entry in `came_from`), emitting `{cost_so_far, vertex}` tuples.
  defp reconstruct_path(came_from, current) do
    case Map.fetch(came_from, current) do
      :error -> [{0, current}]
      {:ok, {cost, prev}} -> [{cost, current} | reconstruct_path(came_from, prev)]
    end
  end
end