lib/splitter.ex

defmodule Strom.Splitter do
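  @moduledoc """
  Splits one source stream into several output streams, one per partition
  function. A producer task pushes events into per-partition buffers held by
  this GenServer, and the returned streams drain those buffers lazily.

  Illustrative sketch (the unfolded source and the even/odd predicates below
  are made up for the example; only `start/3`, `stream/1` and `stop/1` are
  part of this module). Note that `start/3` expects a function-backed stream
  (see the `is_function/1` guard) and that the order of the returned streams
  follows `Map.keys/1`, not the order of the partition list:

      source =
        Stream.unfold(1, fn
          n when n <= 10 -> {n, n + 1}
          _ -> nil
        end)

      splitter = Strom.Splitter.start(source, [&(rem(&1, 2) == 0), &(rem(&1, 2) == 1)])

      results =
        splitter
        |> Strom.Splitter.stream()
        # consume the partitions concurrently in separate tasks
        |> Enum.map(fn stream -> Task.async(fn -> Enum.to_list(stream) end) end)
        |> Enum.map(&Task.await(&1, :infinity))

      # results holds the even and the odd numbers as two lists (in unspecified order)
      Strom.Splitter.stop(splitter)
  """
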
  use GenServer

  # pid         - the splitter GenServer process
  # stream      - the source stream to split
  # partitions  - %{partition_fun => buffered_events}
  # running     - whether the producer task is currently consuming the source
  # chunk_every - how many events the producer reads from the source per push
  defstruct [:pid, :stream, :partitions, :running, :chunk_every]

  @chunk_every 100

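  @doc """
  Starts a splitter process for the given source `stream` (which must be a
  function-backed stream, see the guard) and list of partition functions.
  Returns the `%Strom.Splitter{}` struct of the started process.

  Options:

    * `:chunk_every` - how many events the producer reads from the source
      per push to the splitter process. Defaults to `100`.
  """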
  def start(stream, partitions, opts \\ [])
      when is_function(stream) and is_list(partitions) and is_list(opts) do
    state = %__MODULE__{
      stream: stream,
      running: false,
      partitions: Enum.reduce(partitions, %{}, &Map.put(&2, &1, [])),
      chunk_every: Keyword.get(opts, :chunk_every, @chunk_every)
    }

    {:ok, pid} = GenServer.start_link(__MODULE__, state)
    __state__(pid)
  end

  def init(%__MODULE__{} = splitter) do
    {:ok, %{splitter | pid: self()}}
  end

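  @doc """
  Returns a list of lazy streams, one per partition function.

  Enumerating any of them starts consumption of the source stream; each
  event is buffered under every partition function it satisfies and handed
  out to the consumer of the corresponding stream.
  """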
  def stream(%__MODULE__{partitions: partitions} = splitter) do
    partitions
    |> Map.keys()
    |> Enum.map(fn partition ->
      Stream.resource(
        fn -> GenServer.call(splitter.pid, :run_stream) end,
        fn splitter ->
          case GenServer.call(splitter.pid, {:get_data, partition}) do
            {:ok, data} ->
              {data, splitter}

            {:error, :done} ->
              {:halt, splitter}
          end
        end,
        fn splitter -> splitter end
      )
    end)
  end

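  @doc """
  Stops the splitter process.
  """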
  def stop(%__MODULE__{pid: pid}), do: GenServer.call(pid, :stop)

  def __state__(pid) when is_pid(pid), do: GenServer.call(pid, :__state__)

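  # Producer: consumes the source stream in a separate task, pushing chunks of
  # `chunk_every` events to the splitter process and backing off (see
  # `maybe_wait/2`) when the consumers fall behind. Reports `:done` once the
  # source stream is exhausted.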
  defp async_run_stream(stream, chunk_every, pid) do
    Task.async(fn ->
      stream
      |> Stream.chunk_every(chunk_every)
      |> Stream.each(fn chunk ->
        data_size = GenServer.call(pid, {:new_data, chunk})
        maybe_wait(data_size, chunk_every)
      end)
      |> Stream.run()

      GenServer.call(pid, :done)
    end)
  end

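  # Backpressure for the producer: once the total buffered data exceeds
  # 10 * chunk_every events, sleep for 2^(data_length div (10 * chunk_every))
  # milliseconds. For example, with the default chunk_every of 100 and 3_500
  # buffered events: div(3500, 1000) = 3, so the producer sleeps 2^3 = 8 ms.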
  defp maybe_wait(data_length, chunk_every) do
    if data_length > 10 * chunk_every do
      factor = div(data_length, 10 * chunk_every)
      to_sleep = trunc(:math.pow(2, factor))
      Process.sleep(to_sleep)
    end
  end

  def handle_call({:new_data, data}, _from, %__MODULE__{} = splitter) do
    # Buffer each incoming event under every partition function it satisfies;
    # events matching none of the functions are dropped.
    new_partitions =
      Enum.reduce(splitter.partitions, %{}, fn {fun, prev_data}, acc ->
        {matching, _rest} = Enum.split_with(data, fun)
        Map.put(acc, fun, prev_data ++ matching)
      end)

    # Reply with the total buffered size so the producer can apply backpressure.
    data_size =
      Enum.reduce(new_partitions, 0, fn {_key, data}, acc -> acc + length(data) end)

    {:reply, data_size, %{splitter | partitions: new_partitions}}
  end

  def handle_call(:run_stream, _from, %__MODULE__{} = splitter) do
    # Only the first consumer to subscribe actually starts the producer task.
    if splitter.running do
      {:reply, splitter, splitter}
    else
      async_run_stream(splitter.stream, splitter.chunk_every, splitter.pid)
      splitter = %{splitter | running: true}
      {:reply, splitter, splitter}
    end
  end

  def handle_call(:done, _from, %__MODULE__{} = splitter) do
    {:reply, :ok, %{splitter | running: false}}
  end

  def handle_call(
        {:get_data, partition},
        _from,
        %__MODULE__{partitions: partitions, running: running} = splitter
      ) do
    data = Map.get(partitions, partition)

    if data == [] && !running do
      # Buffer is empty and the source stream has finished - tell the consumer to halt.
      {:reply, {:error, :done}, splitter}
    else
      # Hand over everything buffered for this partition and reset its buffer.
      {:reply, {:ok, data}, %{splitter | partitions: Map.put(partitions, partition, [])}}
    end
  end

  def handle_call(:stop, _from, %__MODULE__{} = splitter) do
    {:stop, :normal, :ok, %{splitter | running: false}}
  end

  def handle_call(:__state__, _from, splitter), do: {:reply, splitter, splitter}

  # The producer task reports its result (`:ok` from the `:done` call); nothing to do.
  def handle_info({_task_ref, :ok}, splitter) do
    {:noreply, splitter}
  end

  # The producer task exited normally; nothing to do.
  def handle_info({:DOWN, _task_ref, :process, _task_pid, :normal}, splitter) do
    {:noreply, splitter}
  end
end