lib/exagon/zeroconf/mdns/server.ex

# Copyright 2022 Exagon team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

defmodule Exagon.Zeroconf.Mdns.Server do
  @moduledoc """
  MDNS server implementation.
  See https://www.rfc-editor.org/rfc/rfc6762.html for specifications.

  Reference documentation:
   - https://www.ques10.com/p/10908/explain-dns-message-format-with-neat-diagram-1/

  """
  use GenServer
  alias Phoenix.PubSub
  require Logger

  defmodule Record do
    @moduledoc false
    # A resource record learned from the network, together with the timers armed
    # for it: a refresh query at 80% of its TTL and its expiry at 100%.
    defstruct dns_resource: nil, ttl_timers: []
    @type t :: %__MODULE__{dns_resource: DNS.Resource.t(), ttl_timers: [reference()]}
  end

  defmodule State do
    @moduledoc false
    defstruct udp4: nil,
              udp6: nil,
              hostname: nil,
              domain_name: nil,
              fqdn: nil,
              host_ipv4: nil,
              host_ipv6: nil,
              probe: %{},
              records: %{},
              authoritative_records: %{}

    # {socket, multicast group, port}. The socket is nil when the address family
    # is disabled or its socket could not be opened.
    @type endpoint :: {:inet.socket() | nil, :inet.ip_address(), :inet.port_number()}

    @type t :: %__MODULE__{
            udp4: endpoint(),
            udp6: endpoint(),
            hostname: String.t(),
            domain_name: String.t(),
            fqdn: String.t(),
            # Addresses published in this host's A / AAAA records (nil: not published).
            host_ipv4: :inet.ip4_address() | nil,
            host_ipv6: :inet.ip6_address() | nil,
            # %{fqdn: name being probed, timer: pending probe timer}; nil once
            # probing has ended (announced or cancelled).
            probe: map() | nil,
            # Records learned from the network, keyed by resource key.
            records: %{optional(String.t()) => Exagon.Zeroconf.Mdns.Server.Record.t()},
            # Records this server answers for, keyed by resource key.
            authoritative_records: %{optional(String.t()) => DNS.Resource.t()}
          }
  end

  @query_packet %DNS.Record{
    header: %DNS.Header{
      aa: false,
      qr: false,
      opcode: 0,
      rcode: 0
    },
    qdlist: []
  }
  @response_packet %DNS.Record{
    header: %DNS.Header{
      aa: true,
      qr: true,
      opcode: 0,
      rcode: 0
    },
    arlist: [],
    anlist: []
  }
  @ttl 120

  @topic_name "Exagon.Zeroconf.Mdns"

  # MDNS port and multicast groups (RFC 6762 §3).
  @mdns_port 5353
  @ipv4_group {224, 0, 0, 251}
  @ipv6_group {0xFF02, 0, 0, 0, 0, 0, 0, 0xFB}
  @ipv4_any {0, 0, 0, 0}
  @ipv6_any {0, 0, 0, 0, 0, 0, 0, 0}

  # Socket options shared by both address families.
  @udp_options [:binary, active: true, reuseaddr: true, multicast_loop: true, multicast_ttl: 1]

  # Probing (RFC 6762 §8.1): three probe queries sent 250 ms apart.
  @probe_count 3
  @probe_interval 250

  # Largest delay, in milliseconds, accepted by Process.send_after/3.
  @max_timer_ms 4_294_967_295

  # ---------------------------------------------------------------------------
  # Client API
  # ---------------------------------------------------------------------------

  @doc """
  Start a new MDNS server.

  The server is started with a set of socket parameters compatible with the MDNS specification.
  Up to two sockets are opened, on IPv4 (`224.0.0.251`) and IPv6 (`ff02::fb`), with port `5353`.

  Options given in `args` take precedence over the `:mdns` environment of the
  `:exagon_zeroconf` application. Read options are `:hostname`, `:domain_name`
  (default `".local"`), `:port` (default `5353`), `:use_ipv4` / `:use_ipv6` (a socket
  is only opened when the flag is `true`) and `:ipv4_interface` / `:ipv6_interface`,
  which are passed as `multicast_if` / `add_membership` to ``:gen_udp.open/2``
  (https://www.erlang.org/doc/man/gen_udp.html#open-2).
  """
  @spec start_link(keyword()) :: GenServer.on_start()
  def start_link(args) do
    GenServer.start_link(__MODULE__, args, name: __MODULE__)
  end

  @doc """
  Register a process for receiving notifications on added, removed or changed DNS records.

  Notifications use [Phoenix PubSub](https://hexdocs.pm/phoenix_pubsub/Phoenix.PubSub.html).

  Listeners will receive:
   - `{:record_removed, old_record}`, when a DNS record is removed after timeout. `old_record` contains the removed record data
   - `{:record_changed, old_resource, resource}`, when a DNS record data are updated (before timeout refresh for example)
   - `{:record_added, resource}`, when a DNS record is added
  """
  @spec subscribe() :: :ok | {:error, {:already_registered, pid}}
  def subscribe(), do: PubSub.subscribe(:zeroconf_pubsub, @topic_name)

  @doc """
  Add a list of `DNS.Resource` to the list of this server authoritative records.

  Authoritative records are broadcasted on the MDNS network. A resource with the same
  class, type and domain as an already registered one replaces it.
  """
  @spec add_resources([DNS.Resource.t()]) :: :ok
  def add_resources(resources) do
    GenServer.cast(__MODULE__, {:add_resource, resources})
  end

  @doc """
  Dump list of `DNS.Resource` received by the server. Each instance contains DNS record information.
  """
  @spec dump() :: [DNS.Resource.t()]
  def dump() do
    Enum.map(GenServer.call(__MODULE__, :dump), & &1.dns_resource)
  end

  @doc """
  Pretty print of DNS records known by the server.
  """
  @spec pretty_dump() :: :ok
  def pretty_dump() do
    IO.puts(
      "#{String.pad_trailing("CLASS", 3)}\t#{String.pad_trailing("TYPE", 5)}\t#{String.pad_trailing("TTL", 5)}\t#{String.pad_trailing("DOMAIN", 100)}"
    )

    resources = dump()

    for resource <- resources do
      IO.puts(
        "#{String.pad_trailing(to_string(resource.class), 3)}\t#{String.pad_trailing(to_string(resource.type), 5)}\t#{String.pad_trailing(to_string(resource.ttl), 5)}\t#{String.pad_trailing(to_string(resource.domain), 100)}"
      )
    end

    IO.puts("#{length(resources)} records")

    :ok
  end

  # ---------------------------------------------------------------------------
  # GenServer callbacks
  # ---------------------------------------------------------------------------

  @impl true
  def init(args) do
    # A missing (or nil) application environment must not turn `args ++ ...`
    # into an improper list.
    mdns_conf = Application.get_env(:exagon_zeroconf, :mdns) || []
    # Keyword lookups return the first match, so start_link/1 options override
    # the application environment.
    init_args = args ++ mdns_conf
    Logger.debug("Init args: #{inspect(init_args)}")

    hostname = init_args[:hostname] || :inet.gethostname() |> elem(1) |> to_string()
    domain_name = init_args[:domain_name] || ".local"
    port = init_args[:port] || @mdns_port
    ipv4_interface = init_args[:ipv4_interface] || @ipv4_any
    ipv6_interface = init_args[:ipv6_interface] || @ipv6_any

    udp4 =
      open_socket(
        init_args[:use_ipv4],
        port,
        @ipv4_group,
        [:inet, multicast_if: ipv4_interface, add_membership: {@ipv4_group, ipv4_interface}],
        fn -> "#{:inet.ntoa(ipv4_interface)}:#{port}" end
      )

    # IPv6 UDP is currently (OTP 25.2) not working. See https://github.com/erlang/otp/issues/5789 for an update
    udp6 =
      open_socket(
        init_args[:use_ipv6],
        port,
        @ipv6_group,
        [:inet6, multicast_if: ipv6_interface, add_membership: {@ipv6_group, ipv6_interface}],
        fn -> "#{:inet.ntoa(ipv6_interface)} on port #{port}" end
      )

    state =
      startup(%State{
        udp4: udp4,
        udp6: udp6,
        hostname: hostname,
        domain_name: domain_name,
        fqdn: hostname <> domain_name,
        host_ipv4: host_address(udp4, ipv4_interface, @ipv4_any),
        host_ipv6: host_address(udp6, ipv6_interface, @ipv6_any)
      })

    {:ok, state}
  end

  @impl true
  def handle_cast({:add_resource, resources}, state) do
    # Our records are always class IN, use the default TTL and carry the
    # cache-flush bit (`func`).
    records =
      Enum.map(resources, fn resource ->
        %DNS.Resource{resource | ttl: @ttl, class: :in, func: true}
      end)

    send_answers(records, nil, true, state)

    # New entries must replace previously registered ones with the same key.
    authoritative =
      Map.merge(state.authoritative_records, Map.new(records, &{resource_key(&1), &1}))

    {:noreply, %State{state | authoritative_records: authoritative}}
  end

  @impl true
  def handle_info({:do_probe, _count}, %State{probe: nil} = state) do
    # Probing was cancelled (or already finished) after this message was
    # scheduled: there is nothing left to do.
    {:noreply, state}
  end

  def handle_info({:do_probe, count}, state) when count > @probe_count do
    {:noreply, announce(state)}
  end

  def handle_info({:do_probe, count}, state) do
    case Map.get(state.probe, :fqdn) do
      nil ->
        Logger.warning("Invalid state for probing (probe.hostname == nil)")
        {:noreply, state}

      fqdn ->
        Logger.debug("Probe ##{count} for #{fqdn}")
        send_query(%DNS.Query{class: 1, type: :any, domain: to_charlist(fqdn)}, state)

        # Keep the pending timer so cancel_probe/1 can stop the sequence.
        timer = Process.send_after(self(), {:do_probe, count + 1}, @probe_interval)
        {:noreply, %State{state | probe: Map.put(state.probe, :timer, timer)}}
    end
  end

  def handle_info({:udp, _socket, src_ip, src_port, packet}, state) do
    {:noreply, handle_packet(src_ip, src_port, packet, state)}
  end

  def handle_info({:ttl_timeout1, record_key}, state) do
    case state.records do
      %{^record_key => %Record{dns_resource: %DNS.Resource{} = resource}} ->
        Logger.debug("TTL timeout 1 received for #{record_key}")
        send_query(%DNS.Query{class: :in, type: resource.type, domain: resource.domain}, state)

      _ ->
        :ok
    end

    {:noreply, state}
  end

  def handle_info({:ttl_timeout2, record_key}, state) do
    case Map.pop(state.records, record_key) do
      {nil, _records} ->
        {:noreply, state}

      {old_record, records} ->
        Logger.debug("TTL timeout 2 received for #{record_key}")
        Logger.debug("Record #{record_key} dropped")

        Enum.each(old_record.ttl_timers, &Process.cancel_timer/1)

        PubSub.broadcast!(
          :zeroconf_pubsub,
          @topic_name,
          {:record_removed, old_record}
        )

        {:noreply, %State{state | records: records}}
    end
  end

  def handle_info(message, state) do
    # Unknown messages must not crash the server.
    Logger.debug("Ignoring unexpected message: #{inspect(message)}")
    {:noreply, state}
  end

  @impl true
  def handle_call(:dump, _from, state) do
    {:reply, Map.values(state.records), state}
  end

  # ---------------------------------------------------------------------------
  # Initialization helpers
  # ---------------------------------------------------------------------------

  # Opens one multicast socket when `enabled` is exactly `true`. Always returns a
  # `{socket | nil, group, port}` triple so that senders can skip closed sockets.
  # `describe` lazily builds the endpoint text used in log messages.
  defp open_socket(true, port, group, family_options, describe) do
    case :gen_udp.open(port, family_options ++ @udp_options) do
      {:ok, socket} ->
        Logger.debug("Start listening at #{describe.()}")
        {socket, group, port}

      {:error, reason} ->
        Logger.warning("Failed listening at #{describe.()}: #{inspect(reason)}")
        {nil, group, port}
    end
  end

  defp open_socket(_enabled, port, group, _family_options, _describe), do: {nil, group, port}

  # Address published in this host's A / AAAA record. A configured, non-wildcard
  # interface address of the right family is used as-is; otherwise the first
  # address of that family found on an up, non-loopback interface. Returns nil
  # when the family's socket is not open or no such address exists.
  defp host_address({nil, _group, _port}, _interface, _wildcard), do: nil

  defp host_address(_endpoint, interface, wildcard)
       when is_tuple(interface) and tuple_size(interface) == tuple_size(wildcard) and
              interface != wildcard,
       do: interface

  defp host_address(_endpoint, _interface, wildcard) do
    size = tuple_size(wildcard)

    case :inet.getifaddrs() do
      {:ok, ifaddrs} ->
        Enum.find_value(ifaddrs, fn {_name, opts} ->
          flags = Keyword.get(opts, :flags, [])

          if :up in flags and :loopback not in flags do
            Enum.find_value(opts, fn
              {:addr, addr} when tuple_size(addr) == size -> addr
              _ -> nil
            end)
          end
        end)

      {:error, _reason} ->
        nil
    end
  end

  # ---------------------------------------------------------------------------
  # Probing and announcing (RFC 6762 §8)
  # ---------------------------------------------------------------------------

  defp startup(state), do: probe(state.hostname, state)

  # Starts probing for `hostname`: the first probe is sent after a random delay
  # of 1 to 250 ms. The pending timer is kept so probing can be cancelled.
  defp probe(hostname, state) do
    timer = Process.send_after(self(), {:do_probe, 1}, :rand.uniform(250))
    %State{state | probe: %{fqdn: hostname <> state.domain_name, timer: timer}}
  end

  # Stops probing: cancels the pending probe timer (if any) and clears `probe`.
  # A probe message already delivered is ignored by handle_info/2.
  defp cancel_probe(state) do
    case state.probe do
      %{timer: timer} when is_reference(timer) -> Process.cancel_timer(timer)
      _ -> :ok
    end

    Logger.debug("Cancelled probing timer")

    %State{state | probe: nil}
  end

  # Publishes this host's address records once probing succeeded: they go in the
  # answer section of an authoritative response (RFC 6762 §8.3) and are kept as
  # authoritative records so later queries for the host name get answered.
  defp announce(state) do
    domain = to_charlist(state.fqdn)

    records =
      for {type, address} when not is_nil(address) <-
            [a: state.host_ipv4, aaaa: state.host_ipv6] do
        %DNS.Resource{
          class: :in,
          type: type,
          domain: domain,
          ttl: @ttl,
          data: address,
          func: true
        }
      end

    send_answers(records, nil, true, state)

    authoritative =
      Map.merge(state.authoritative_records, Map.new(records, &{resource_key(&1), &1}))

    cancel_probe(%State{state | authoritative_records: authoritative})
  end

  # True when, while probing, `received` (resource key => resource) contains a
  # record for the probed name that is not one of our own authoritative records
  # (our own responses are looped back because multicast_loop is enabled).
  # DNS names compare case-insensitively; decoded domains are charlists.
  defp probe_conflict?(%State{probe: %{fqdn: fqdn}} = state, received) do
    wanted = String.downcase(fqdn)

    Enum.any?(received, fn {key, resource} ->
      String.downcase(to_string(resource.domain)) == wanted and
        not own_record?(state, key, resource)
    end)
  end

  defp probe_conflict?(_state, _received), do: false

  defp own_record?(state, key, resource) do
    case Map.fetch(state.authoritative_records, key) do
      {:ok, %DNS.Resource{data: data}} -> data == resource.data
      :error -> false
    end
  end

  # ---------------------------------------------------------------------------
  # Incoming packets
  # ---------------------------------------------------------------------------

  # Decodes an incoming datagram and dispatches it as a response (QR bit set) or
  # as a query. Packets come from the network and may be malformed: a decoding
  # failure is logged and the packet dropped instead of crashing the server.
  defp handle_packet(src_ip, src_port, packet, state) do
    case decode_packet(packet) do
      {:ok, %DNS.Record{header: %DNS.Header{qr: true}} = record} ->
        handle_response(record, state)

      {:ok, record} ->
        handle_query(record, state)

      {:error, error} ->
        Logger.debug(
          "Dropping undecodable packet from #{inspect(src_ip)}:#{src_port}: #{Exception.message(error)}"
        )

        state
    end
  end

  defp decode_packet(packet) do
    {:ok, DNS.Record.decode(packet)}
  rescue
    # NOTE(review): DNS.Record.decode/1 is assumed to raise (e.g. MatchError) on
    # malformed input; any exception here is treated as an undecodable packet.
    error -> {:error, error}
  end

  # Caches the resource records carried by a response, (re)arming their TTL
  # timers, and cancels probing when a conflicting record for the probed name
  # is seen.
  defp handle_response(record, state) do
    # When several records share a key in one packet the last one wins, so each
    # key gets exactly one pair of timers (no orphaned timers).
    received =
      for %DNS.Resource{} = resource <- record.anlist ++ record.arlist,
          into: %{},
          do: {resource_key(resource), resource}

    records =
      Map.new(received, fn {key, resource} ->
        resource = fetch_record_diff(key, resource, state)
        {key, %Record{dns_resource: resource, ttl_timers: arm_ttl_timers(key, resource.ttl)}}
      end)

    # If we are currently probing and the response contains the probed name,
    # cancel probing.
    state =
      if probe_conflict?(state, received) do
        Logger.debug(
          "Received an answer containing FQDN while probing #{state.probe.fqdn}. Cancel probing"
        )

        cancel_probe(state)
      else
        state
      end

    %State{state | records: Map.merge(state.records, records)}
  end

  # Arms the refresh query (80% of the TTL) and the expiry (100%) for `key`.
  defp arm_ttl_timers(key, ttl) do
    refresh = timer_delay(ttl, 800)
    expiry = timer_delay(ttl, 1000)
    Logger.debug("TTL timeout set to #{expiry} milliseconds for #{key}")

    [
      Process.send_after(self(), {:ttl_timeout1, key}, refresh),
      Process.send_after(self(), {:ttl_timeout2, key}, expiry)
    ]
  end

  # Converts a TTL in seconds into a delay accepted by Process.send_after/3;
  # TTLs come from the network and may exceed its upper bound.
  defp timer_delay(ttl, ms_per_second) do
    (ttl * ms_per_second) |> max(0) |> min(@max_timer_ms)
  end

  # Answers the questions of a query for which we hold an authoritative record.
  defp handle_query(record, state) do
    answers =
      record.qdlist
      |> Enum.map(&Map.get(state.authoritative_records, resource_key(&1)))
      |> Enum.reject(&is_nil/1)

    send_answers(answers, nil, true, state)

    state
  end

  # Cache key of a resource or a question: "class.type.domain".
  defp resource_key(%mod{class: class, type: type, domain: domain})
       when mod in [DNS.Resource, DNS.Query] do
    "#{class}.#{type}.#{domain}"
  end

  # Normalizes a received resource (a TTL of 0 becomes 1 so the record expires
  # after one second) and notifies subscribers: `:record_changed` when the key
  # is already cached (its old TTL timers are cancelled; the caller arms new
  # ones) or `:record_added` otherwise. Returns the normalized resource.
  defp fetch_record_diff(key, %DNS.Resource{ttl: ttl} = resource, state) do
    ttl = if ttl == 0, do: 1, else: ttl
    resource = %DNS.Resource{resource | ttl: ttl}

    case Map.fetch(state.records, key) do
      {:ok, %Record{dns_resource: old_resource, ttl_timers: timers}} ->
        Logger.debug("Record update: #{key}")

        PubSub.broadcast!(
          :zeroconf_pubsub,
          @topic_name,
          {:record_changed, old_resource, resource}
        )

        Enum.each(timers, &Process.cancel_timer/1)

        resource

      :error ->
        Logger.debug("New record added: #{key}")
        PubSub.broadcast!(:zeroconf_pubsub, @topic_name, {:record_added, resource})
        resource
    end
  end

  # ---------------------------------------------------------------------------
  # Outgoing packets
  # ---------------------------------------------------------------------------

  defp send_query(%DNS.Query{} = query, state) do
    send_queries([query], state)
  end

  defp send_queries(queries, state) do
    header = %DNS.Header{@query_packet.header | id: 0}
    multicast(%DNS.Record{@query_packet | header: header, qdlist: queries}, state)
  end

  # Sends a response with `an_list` as answers and `ar_list` as additional
  # records (nil means empty); `authority` sets the AA bit. A response without
  # any record carries no information and is not sent.
  defp send_answers(an_list, ar_list, authority, state) do
    an_list = an_list || []
    ar_list = ar_list || []

    if an_list != [] or ar_list != [] do
      header = %DNS.Header{@response_packet.header | id: 0, aa: authority}

      multicast(
        %DNS.Record{@response_packet | header: header, anlist: an_list, arlist: ar_list},
        state
      )
    end
  end

  # Encodes `packet` once and sends it to the multicast group of every open socket.
  defp multicast(packet, state) do
    payload = DNS.Record.encode(packet)

    for {socket, group, port} when not is_nil(socket) <- [state.udp4, state.udp6] do
      :gen_udp.send(socket, group, port, payload)
    end
  end
end