# lib/snapcast/server.ex

defmodule Snapcast.Server do
  @moduledoc """
  The snapcast server: a TCP listener that accepts snapclients, plus the playback
  coordinator. Each connection becomes a `Snapcast.Session` (self-registers on
  `Hello`). `play/3` (re)starts a `Snapcast.Stream` feeding the targeted clients,
  and the transport ops (pause/resume/seek/stop/volume) map onto it.

  One active stream at a time. Client connect/disconnect and playback
  progress/ended are reported to the configured `Snapcast.Listener`, so a host app
  can drive a UI position, advance tracks, and refresh endpoint lists.
  """

  use GenServer

  alias Snapcast.{Clock, Session, Stream}

  require Logger

  @sup Snapcast.SessionSupervisor
  @progress_ms 1000
  # Pause before retrying after an accept error (e.g. `:emfile`) so a persistent
  # failure doesn't turn the acceptor into a busy loop.
  @accept_retry_ms 100

  # --- client API ------------------------------------------------------------

  @doc """
  Starts the server, registered locally as `Snapcast.Server`.

  opts: `:port`, `:bind_ip`, `:listener` — each falls back to the matching
  `Snapcast` config function. Returns `:ignore` when the port cannot be opened.
  """
  @spec start_link(keyword()) :: GenServer.on_start()
  def start_link(opts), do: GenServer.start_link(__MODULE__, opts, name: __MODULE__)

  @doc """
  Plays `source` to the clients in `client_ids`, replacing any active stream.

  opts:
    * `:position_ms` — start offset, clamped to `>= 0` (default `0`)
    * `:endpoint` — passed back to the listener with progress/ended reports
    * `:duration_ms` — when a positive integer, `ended` is reported once the
      estimated position reaches it
    * `:transport_codec` — normalized; defaults to `:pcm`
    * `:format` — normalized and capped; defaults to `Snapcast.format()`
  """
  @spec play(term(), Enumerable.t(), keyword()) :: :ok
  def play(source, client_ids, opts \\ []),
    do: GenServer.call(__MODULE__, {:play, source, client_ids, opts})

  @doc "Pauses the current session, remembering its position. No-op without a session."
  @spec pause() :: :ok
  def pause, do: GenServer.call(__MODULE__, :pause)

  @doc "Resumes a paused session from its saved position. No-op otherwise."
  @spec resume() :: :ok
  def resume, do: GenServer.call(__MODULE__, :resume)

  @doc """
  Restarts the current session at `position_ms` (clamped to `>= 0`); a paused
  session starts playing. `playback_gen` is accepted but currently ignored.
  """
  @spec seek(integer(), term()) :: :ok
  def seek(position_ms, playback_gen \\ nil),
    do: GenServer.call(__MODULE__, {:seek, position_ms, playback_gen})

  @doc "Stops the stream and forgets the session."
  @spec stop_playback() :: :ok
  def stop_playback, do: GenServer.call(__MODULE__, :stop_playback)

  @doc "Sets `volume` on every connected session whose client id is `client_id`."
  @spec set_volume(term(), term()) :: :ok
  def set_volume(client_id, volume),
    do: GenServer.call(__MODULE__, {:set_volume, client_id, volume})

  @doc "Lists connected sessions as `%{pid:, client_id:, name:}` maps."
  @spec clients() :: [%{pid: pid(), client_id: term(), name: term()}]
  def clients, do: GenServer.call(__MODULE__, :clients)

  @doc false
  def register(pid, client_id, name),
    do: GenServer.cast(__MODULE__, {:register, pid, client_id, name})

  @doc false
  def unregister(pid), do: GenServer.cast(__MODULE__, {:unregister, pid})

  # --- init ------------------------------------------------------------------

  @impl true
  def init(opts) do
    Clock.init()
    port = Keyword.get(opts, :port, Snapcast.port())
    bind_ip = Keyword.get(opts, :bind_ip, Snapcast.bind_ip())
    listener = Keyword.get(opts, :listener, Snapcast.listener())

    case :gen_tcp.listen(port, [
           :binary,
           packet: :raw,
           active: false,
           ip: bind_ip,
           reuseaddr: true,
           nodelay: true,
           backlog: 16
         ]) do
      {:ok, listen} ->
        Logger.info("Snapcast.Server listening on #{format_ip(bind_ip)}:#{port}")
        # The async accept notification goes to the listen socket's controlling
        # process, so the acceptor must own it (else accept/0 blocks forever).
        acceptor = spawn_link(fn -> receive(do: (:go -> accept_loop(listen))) end)
        :ok = :gen_tcp.controlling_process(listen, acceptor)
        send(acceptor, :go)

        {:ok,
         %{
           listen: listen,
           acceptor: acceptor,
           listener: listener,
           sessions: %{},
           stream: nil,
           session: nil,
           timeline_us: 0,
           # Ref of the single pending progress tick (nil when none is armed).
           progress_timer: nil
         }}

      {:error, reason} ->
        Logger.error("Snapcast.Server could not listen on #{port}: #{inspect(reason)}")
        :ignore
    end
  end

  # --- playback --------------------------------------------------------------

  @impl true
  def handle_call({:play, source, client_ids, opts}, _from, state) do
    state = reconcile_sessions(state)
    position = opts |> Keyword.get(:position_ms, 0) |> max(0)

    session = %{
      source: source,
      client_ids: client_ids,
      endpoint: Keyword.get(opts, :endpoint),
      duration_ms: Keyword.get(opts, :duration_ms),
      transport_codec:
        Snapcast.normalize_transport_codec(Keyword.get(opts, :transport_codec)) || :pcm,
      format:
        Snapcast.cap_format(
          Snapcast.normalize_format(Keyword.get(opts, :format)) || Snapcast.format()
        ),
      status: :playing,
      started_at: now_ms(),
      started_position_ms: position
    }

    state = state |> start_stream(session, position) |> schedule_progress()
    {:reply, :ok, state}
  end

  def handle_call(:pause, _from, %{session: %{} = session} = state) do
    position = estimated_position(session)
    state = state |> stop_stream() |> cancel_progress()
    paused = %{session | status: :paused, started_position_ms: position, started_at: now_ms()}
    {:reply, :ok, %{state | session: paused}}
  end

  def handle_call(:resume, _from, %{session: %{status: :paused} = session} = state) do
    resumed = %{session | status: :playing, started_at: now_ms()}
    state = state |> start_stream(resumed, session.started_position_ms) |> schedule_progress()
    {:reply, :ok, state}
  end

  def handle_call({:seek, position_ms, _gen}, _from, %{session: %{} = session} = state) do
    position = max(position_ms, 0)
    session = %{session | started_position_ms: position, started_at: now_ms(), status: :playing}
    # Seeking also resumes a paused session, so the progress tick must be (re)armed.
    state = state |> start_stream(session, position) |> schedule_progress()
    {:reply, :ok, state}
  end

  def handle_call(:stop_playback, _from, state), do: {:reply, :ok, end_playback(state)}

  def handle_call({:set_volume, client_id, volume}, _from, state) do
    state = reconcile_sessions(state)
    for {pid, %{client_id: ^client_id}} <- state.sessions, do: Session.set_volume(pid, volume)
    {:reply, :ok, state}
  end

  def handle_call(:clients, _from, state) do
    state = reconcile_sessions(state)
    {:reply, Map.values(state.sessions), state}
  end

  # No active (or no paused) session for pause/resume/seek — ignore.
  def handle_call(op, _from, state)
      when op in [:pause, :resume] or (is_tuple(op) and elem(op, 0) == :seek),
      do: {:reply, :ok, state}

  # --- sessions --------------------------------------------------------------

  @impl true
  def handle_cast({:register, pid, client_id, name}, state) do
    # reconcile_sessions/1 may already have found (and monitored) this pid;
    # don't stack a second monitor on it.
    if not Map.has_key?(state.sessions, pid), do: Process.monitor(pid)
    state = put_in(state.sessions[pid], %{pid: pid, client_id: client_id, name: name})
    # A target client that (re)connected mid-stream: attach it to the live stream.
    attach_if_target(state, pid, client_id)
    notify(state, :clients_changed, [])
    {:noreply, state}
  end

  def handle_cast({:unregister, pid}, state) do
    state = %{state | sessions: Map.delete(state.sessions, pid)}
    notify(state, :clients_changed, [])
    {:noreply, state}
  end

  @impl true
  # Only the currently armed tick is honoured; a tick cancelled too late to stop
  # delivery carries a different ref and falls through to the catch-all clause.
  def handle_info({:timeout, timer, :progress}, %{progress_timer: timer} = state) do
    state = %{state | progress_timer: nil}

    case state.session do
      %{status: :playing} = session -> progress_tick(state, session)
      _not_playing -> {:noreply, state}
    end
  end

  def handle_info({:snap_stream_ended, stream}, %{stream: stream, session: session} = state) do
    emit_ended(state, session)
    # Tear down like the position-based end so the timeline high-water is kept.
    {:noreply, end_playback(state)}
  end

  def handle_info({:DOWN, _ref, :process, pid, _reason}, state) do
    if Map.has_key?(state.sessions, pid), do: notify(state, :clients_changed, [])
    {:noreply, %{state | sessions: Map.delete(state.sessions, pid)}}
  end

  def handle_info(_msg, state), do: {:noreply, state}

  # --- helpers ---------------------------------------------------------------

  # Reports progress for a playing session, or ends playback once the estimated
  # position reaches a known duration.
  defp progress_tick(state, session) do
    position = estimated_position(session)

    if ended?(session, position) do
      emit_ended(state, session)
      {:noreply, end_playback(state)}
    else
      emit_progress(state, session, position)
      {:noreply, schedule_progress(state)}
    end
  end

  defp start_stream(state, session, position) do
    state = stop_stream(state)
    pids = session_pids(state, session.client_ids)
    # Continue the timeline past the previous stream's high-water so the new audio
    # never overlaps what the client still has buffered (otherwise a seek plays
    # garbage); a long pause makes `now` win, so resume just starts fresh.
    start_us = max(Clock.now_us(), state.timeline_us)

    Enum.each(pids, &Session.set_format(&1, session.format, session.transport_codec))

    {:ok, stream} =
      Stream.start_link(
        source: session.source,
        sessions: pids,
        position_ms: position,
        start_us: start_us,
        owner: self(),
        transport_codec: session.transport_codec,
        format: session.format,
        # Max-wins across the group: the slowest endpoint's per-endpoint floor sets the
        # shared bufferMs so the whole synchronized group stays locked together.
        buffer_ms: Snapcast.effective_buffer_ms(session.format, session.client_ids)
      )

    %{state | stream: stream, session: session}
  end

  defp stop_stream(%{stream: nil} = state), do: state

  defp stop_stream(%{stream: stream} = state) do
    # Capture where the timeline reached before tearing it down, so the next stream
    # can continue from there.
    timeline = high_water(stream, state.timeline_us)
    if Process.alive?(stream), do: Stream.stop(stream)
    %{state | stream: nil, timeline_us: max(timeline, state.timeline_us)}
  rescue
    _ -> %{state | stream: nil}
  catch
    :exit, _ -> %{state | stream: nil}
  end

  defp high_water(stream, fallback) do
    Stream.high_water(stream)
  rescue
    _ -> fallback
  catch
    :exit, _ -> fallback
  end

  # Stops the stream, disarms the progress tick and forgets the session.
  defp end_playback(state) do
    state = state |> stop_stream() |> cancel_progress()
    %{state | session: nil}
  end

  # Arms the single progress tick, cancelling any pending one first so repeated
  # play/resume/seek calls never leave parallel tick chains running.
  defp schedule_progress(state) do
    state = cancel_progress(state)
    %{state | progress_timer: :erlang.start_timer(@progress_ms, self(), :progress)}
  end

  defp cancel_progress(%{progress_timer: nil} = state), do: state

  defp cancel_progress(%{progress_timer: timer} = state) do
    Process.cancel_timer(timer)
    %{state | progress_timer: nil}
  end

  defp session_pids(state, client_ids) do
    for {pid, %{client_id: id}} <- state.sessions, id in client_ids, do: pid
  end

  # Joins `pid` to the live stream when its client is a target of the current
  # session, configuring the session's format/codec before attaching.
  defp attach_if_target(%{stream: stream, session: %{} = session}, pid, client_id)
       when not is_nil(stream) do
    if client_id in session.client_ids do
      Session.set_format(pid, session.format, session.transport_codec)
      Stream.attach(stream, pid)
    end

    :ok
  end

  defp attach_if_target(_state, _pid, _client_id), do: :ok

  # Rebuilds the session table from the supervisor's live children. Pids seen for
  # the first time are monitored and, if targeted, attached to the live stream;
  # already-known pids were attached when the stream started or on register.
  defp reconcile_sessions(state) do
    case session_children() do
      {:ok, children} ->
        sessions = children |> Enum.flat_map(&live_session/1) |> Map.new()

        for {pid, %{client_id: client_id}} <- sessions,
            not Map.has_key?(state.sessions, pid) do
          Process.monitor(pid)
          attach_if_target(state, pid, client_id)
        end

        %{state | sessions: sessions}

      :error ->
        state
    end
  end

  # Turns one supervisor child into a `{pid, info}` entry, skipping children that
  # aren't live Sessions or whose client info can't be read.
  defp live_session({_id, pid, :worker, [Session]}) when is_pid(pid) do
    case session_info(pid) do
      {:ok, %{client_id: client_id, name: name}} ->
        [{pid, %{pid: pid, client_id: client_id, name: name}}]

      _other ->
        []
    end
  end

  defp live_session(_child), do: []

  defp session_children do
    {:ok, DynamicSupervisor.which_children(@sup)}
  rescue
    _ -> :error
  catch
    :exit, _ -> :error
  end

  defp session_info(pid) do
    Session.client_info(pid)
  rescue
    _ -> :error
  catch
    :exit, _ -> :error
  end

  defp estimated_position(%{status: :playing, started_at: at, started_position_ms: pos}),
    do: pos + max(now_ms() - at, 0)

  defp estimated_position(%{started_position_ms: pos}), do: pos

  defp ended?(%{duration_ms: d}, position) when is_integer(d) and d > 0, do: position >= d
  defp ended?(_session, _position), do: false

  defp emit_progress(state, %{endpoint: endpoint}, position) when not is_nil(endpoint),
    do: notify(state, :progress, [endpoint, position])

  defp emit_progress(_state, _session, _position), do: :ok

  defp emit_ended(state, %{endpoint: endpoint}) when not is_nil(endpoint),
    do: notify(state, :ended, [endpoint])

  defp emit_ended(_state, _session), do: :ok

  # Invoke the configured listener callback, catching (and logging) any failure —
  # raise, throw or exit — so a bad listener can never take the server down.
  defp notify(%{listener: listener}, fun, args) when is_atom(listener) and not is_nil(listener) do
    # function_exported?/3 does not load modules: without ensure_loaded?, a listener
    # not yet loaded (interactive mode, i.e. dev/test) would be skipped silently.
    if Code.ensure_loaded?(listener) and function_exported?(listener, fun, length(args)) do
      apply(listener, fun, args)
    end

    :ok
  catch
    kind, reason ->
      Logger.warning(
        "Snapcast listener #{inspect(listener)}.#{fun} failed: " <>
          Exception.format_banner(kind, reason)
      )

      :ok
  end

  defp notify(_state, _fun, _args), do: :ok

  defp now_ms, do: System.monotonic_time(:millisecond)

  # --- acceptor --------------------------------------------------------------

  # Runs in the acceptor process (which owns `listen`) until the socket closes.
  defp accept_loop(listen) do
    case :gen_tcp.accept(listen) do
      {:ok, socket} ->
        hand_off(socket)
        accept_loop(listen)

      {:error, :closed} ->
        :ok

      {:error, reason} ->
        Logger.warning("Snapcast.Server accept failed: #{inspect(reason)}")
        Process.sleep(@accept_retry_ms)
        accept_loop(listen)
    end
  end

  # Starts a Session for an accepted socket and hands the socket over to it.
  # A failed child start is logged and the socket closed, rather than crashing
  # the acceptor (which is linked to, and would take down, the server).
  defp hand_off(socket) do
    case DynamicSupervisor.start_child(@sup, {Session, socket: socket}) do
      {:ok, pid} ->
        activate_session(socket, pid)

      {:ok, pid, _info} ->
        activate_session(socket, pid)

      other ->
        Logger.error("Snapcast.Server could not start a session: #{inspect(other)}")
        :gen_tcp.close(socket)
    end
  end

  defp activate_session(socket, pid) do
    :gen_tcp.controlling_process(socket, pid)
    Session.activate(pid)
  end

  defp format_ip({a, b, c, d}), do: Enum.join([a, b, c, d], ".")
  defp format_ip(other), do: inspect(other)
end