lib/csrf_plus.ex

defmodule CsrfPlus do
  @moduledoc """
  A CSRF (Cross-Site Request Forgery) protection Plug that keeps track of user
  accesses in a configurable store.

  For requests whose method is not in `:allowed_methods`, the plug requires a
  session `:access_id`, a session token, a stored access for that id and a
  signed token (from the `x-csrf-token` request header or the `"_csrf_token"`
  body param) that all agree. Otherwise a `CsrfPlus.Exception` is raised; unless
  `:raise_exception?` is set, it is turned into an error response by the
  configured error mapper.
  """

  @behaviour Plug

  import Plug.Conn

  alias CsrfPlus.Token
  alias CsrfPlus.UserAccess
  alias CsrfPlus.UserAccessInfo

  require Logger

  @default_csrf_key "_csrf_token_"
  @non_csrf_request_methods ["GET", "HEAD", "OPTIONS"]

  # Max token age in milliseconds (24 hours).
  @default_token_max_age 24 * 60 * 60 * 1000

  @doc false
  def init(opts \\ []) do
    %{
      csrf_key: Keyword.get(opts, :csrf_key, @default_csrf_key),
      allowed_methods: Keyword.get(opts, :allowed_methods, @non_csrf_request_methods),
      raise_exception?: Keyword.get(opts, :raise_exception?, false),
      error_mapper: Keyword.get(opts, :error_mapper, CsrfPlus.ErrorMapper),
      # NOTE(review): read whenever init/1 runs, which may be at compile time if the
      # plug is compiled into a pipeline.
      store: configured_store()
    }
  end

  @doc false
  def call(%Plug.Conn{halted: true} = conn, _opts), do: conn

  def call(%Plug.Conn{} = conn, opts) do
    allowed_method? = allowed_method?(conn, opts)

    try do
      conn
      |> put_private(:plug_csrf_plus, token_putters(opts))
      |> check_token(allowed_method?, opts)
    rescue
      exception -> handle_exception(conn, exception, __STACKTRACE__, opts)
    end
  end

  @doc """
  Digs into the connection data to build a `CsrfPlus.UserAccessInfo` struct
  holding the client IP and user agent (either may be `nil`).
  """
  @spec get_user_info(Plug.Conn.t()) :: struct()
  def get_user_info(conn) do
    %UserAccessInfo{
      ip: get_conn_ip(conn),
      user_agent: get_conn_user_agent(conn)
    }
  end

  @doc "The default max age for a token, in milliseconds."
  @spec default_token_max_age() :: pos_integer()
  def default_token_max_age, do: @default_token_max_age

  @doc """
  Uses the plug configuration to put the token and its signed version
  into the store, session and `x-csrf-token` response header.

  Under the hood this goes through the same machinery as `put_session_token/3`,
  `put_header_token/2` and `put_store_token/3`, so have a look at them for more
  details about each destination.

  ## Params
    * `conn` - the connection struct.
    * `opts` - the options.

  ### Options
    The options are a keyword list with the following keys:

    * `:access_id` - the id of the access. If none is given, CsrfPlus generates one (a UUID v4).
    * `:token_tuple` - a tuple with the token and its signed version in the format
      `{token, signed_token}`. This option is required.
    * `:exclude` - a list of destinations to skip. An excluded destination does
      not receive its token.

  ### Exclude list
    * `:session` - do not put the session token.
    * `:header` - do not put the header token.
    * `:store` - do not put the store token.

  ## Raises
    * `CsrfPlus.Exception` - when `:token_tuple` is missing, or when a non-excluded
      token is put before the plug has run on `conn`.
  """
  @spec put_token(Plug.Conn.t(), keyword()) :: Plug.Conn.t()
  def put_token(%Plug.Conn{} = conn, opts \\ []) do
    # Lazy so a UUID is only generated when the caller did not provide an id.
    access_id = Keyword.get_lazy(opts, :access_id, fn -> UUID.uuid4() end)

    token_tuple =
      Keyword.get(opts, :token_tuple) ||
        raise CsrfPlus.Exception,
              "CsrfPlus.put_token/2 options requires a :token_tuple to be given"

    {token, signed_token} = token_tuple
    excludes = Keyword.get(opts, :exclude, [])

    conn
    |> put_a_token_optional(access_id, token, :session, :session in excludes)
    |> put_a_token_optional(access_id, signed_token, :header, :header in excludes)
    |> put_a_token_optional(access_id, token, :store, :store in excludes)
  end

  @doc """
  Put the token and the given `access_id` in the session. Uses the conn struct to
  determine the needed keys.

  ## Params
  * `conn` - the connection struct.
  * `token` - the CSRF unsigned token.
  * `access_id` - the access id. If none is given, no access id is put in the session. Defaults to nil.

  ## Raises
  * `CsrfPlus.Exception` - when the plug has not run on `conn`.
  """
  @spec put_session_token(Plug.Conn.t(), term(), term()) :: Plug.Conn.t()
  def put_session_token(conn, token, access_id \\ nil) do
    put_a_token(conn, access_id, token, :session)
  end

  @doc """
  Put the signed token in the `x-csrf-token` response header.

  ## Params
  * `conn` - the connection struct.
  * `signed_token` - the signed version of the CSRF token.

  ## Raises
  * `CsrfPlus.Exception` - when the plug has not run on `conn`.
  """
  @spec put_header_token(Plug.Conn.t(), binary()) :: Plug.Conn.t()
  def put_header_token(conn, signed_token) do
    put_a_token(conn, nil, signed_token, :header)
  end

  @doc """
  Put the token in the store. If a `conn_to_access` function is implemented in the
  configured store, that function will be called with the given params to generate
  the `CsrfPlus.UserAccess` to be put into the store. Also, have a look at `CsrfPlus.Store.Behaviour`
  to see more about the `conn_to_access` callback.

  ## Params
  * `conn` - the connection struct.
  * `token` - the CSRF unsigned token.
  * `access_id` - the access id. It's required here because a token must be associated with an identifier.

  ## Raises
  * `CsrfPlus.Exception` - when `access_id` is `nil` or the plug has not run on `conn`.
  """
  @spec put_store_token(Plug.Conn.t(), term(), term()) :: Plug.Conn.t()
  def put_store_token(_conn, _token, nil),
    do:
      raise(
        CsrfPlus.Exception,
        "CsrfPlus.put_store_token/3 requires the access_id parameter to be given"
      )

  def put_store_token(conn, token, access_id) do
    put_a_token(conn, access_id, token, :store)
  end

  @doc """
  Get the token of the session's access from the store, or generate a new token pair
  when there is no such access.

  When the access exists, the signed token from the `x-csrf-token` request header is
  reused only if it verifies AND its verified value equals the stored token; otherwise
  the stored token is signed again.

  ## Params
    * `conn` - the connection struct.

  ## Returns
    A tuple with the token and its signed version in the format `{token, signed_token}`.

  ## Raises
    * `CsrfPlus.Exception` - when no store is configured.
  """
  def get_token_tuple(conn) do
    store = configured_store() || raise CsrfPlus.Exception, CsrfPlus.Exception.StoreException

    access_id = get_session(conn, :access_id)

    access =
      if access_id != nil do
        store.get_access(access_id)
      end

    if access == nil do
      Token.generate()
    else
      {access.token, signed_token_for(conn, access.token)}
    end
  end

  # --- Plug internals --------------------------------------------------------

  # Reads the configured store module from the application environment (nil when unset).
  defp configured_store do
    :csrf_plus
    |> Application.get_env(CsrfPlus, [])
    |> Keyword.get(:store)
  end

  # Builds the closures stored under `conn.private.plug_csrf_plus`; they capture `opts`
  # so the public put_* functions can work without receiving the plug options.
  defp token_putters(opts) do
    %{
      put_session_token: fn conn, access_id, token ->
        conn
        |> put_session_access_id(access_id)
        |> put_session(opts.csrf_key, token)
      end,
      put_header_token: fn conn, _access_id, signed_token ->
        put_resp_header(conn, "x-csrf-token", signed_token)
      end,
      put_store_token: fn conn, access_id, token ->
        put_access_into_store(Map.get(opts, :store), conn, access_id, token)
        conn
      end
    }
  end

  # Only writes the access id into the session when one was given.
  defp put_session_access_id(conn, nil), do: conn
  defp put_session_access_id(conn, access_id), do: put_session(conn, :access_id, access_id)

  # Puts an access into the store; does nothing when no store is configured.
  defp put_access_into_store(nil, _conn, _access_id, _token), do: :ok

  defp put_access_into_store(store, conn, access_id, token) do
    # function_exported?/3 returns false for modules that are not loaded yet, so load
    # the store first or an implemented conn_to_access/2 could be silently skipped.
    access =
      if Code.ensure_loaded?(store) and function_exported?(store, :conn_to_access, 2) do
        store.conn_to_access(conn, %{token: token, access_id: access_id})
      else
        %UserAccess{}
        |> Map.put(:access_id, access_id)
        |> Map.put(:token, token)
      end

    store.put_access(access)
  end

  # Turns an exception raised while checking the token into an error response, or re-raises it.
  defp handle_exception(conn, exception, stacktrace, opts) do
    error_mapper = opts.error_mapper

    cond do
      Map.get(opts, :raise_exception?) == true ->
        reraise exception, stacktrace

      not CsrfPlus.Exception.csrf_plus_exception?(exception) ->
        reraise exception, stacktrace

      error_mapper_available?(error_mapper) ->
        {status_code, error} = error_mapper.map(exception)

        # After any CSRF exception, clear the session keys and return the mapped error.
        conn
        |> delete_session(:access_id)
        |> delete_session(opts.csrf_key)
        |> put_resp_content_type("application/json")
        |> send_resp(status_code, Jason.encode!(error))
        |> halt()

      true ->
        Logger.debug(
          "The given error_mapper #{inspect(error_mapper)} does not exist or doesn't implements the map/1 function"
        )

        reraise exception, stacktrace
    end
  end

  # True when the error mapper module can be compiled/loaded and exports map/1.
  defp error_mapper_available?(error_mapper) do
    # BEWARE: Code.ensure_compiled/1 at runtime may lead to deadlocks.
    match?({:module, _}, Code.ensure_compiled(error_mapper)) and
      function_exported?(error_mapper, :map, 1)
  end

  defp put_a_token_optional(conn, access_id, token, what, false) do
    put_a_token(conn, access_id, token, what)
  end

  defp put_a_token_optional(conn, _access_id, _token, _what, true), do: conn

  # Dispatches to the closure installed by call/2 for the given destination.
  defp put_a_token(%Plug.Conn{private: private} = conn, access_id, token, what) do
    putter_key =
      case what do
        :session -> :put_session_token
        :header -> :put_header_token
        :store -> :put_store_token
      end

    putter =
      private
      |> Map.get(:plug_csrf_plus, %{})
      |> Map.get(putter_key)

    if is_nil(putter) do
      raise CsrfPlus.Exception,
            "CsrfPlus.put_token/3 must be called after CsrfPlus is plugged"
    end

    putter.(conn, access_id, token)
  end

  # Reuses the request's signed header token only when it verifies to `token`.
  defp signed_token_for(conn, token) do
    with header_token when is_binary(header_token) <- first_req_header(conn, "x-csrf-token"),
         {:ok, verified_token} <- Token.verify(header_token),
         true <- same_token?(verified_token, token) do
      header_token
    else
      _ -> Token.sign_token(token)
    end
  end

  defp allowed_method?(%Plug.Conn{method: method}, %{allowed_methods: allowed_methods}) do
    method in allowed_methods
  end

  defp check_token(%Plug.Conn{} = conn, true = _allowed_method?, _opts), do: conn

  defp check_token(%Plug.Conn{body_params: body_params} = conn, false = _allowed_method?, opts) do
    access_id = get_session(conn, :access_id)
    header_token = first_req_header(conn, "x-csrf-token")
    body_token = Map.get(body_params, "_csrf_token", nil)

    cond do
      is_nil(access_id) ->
        raise CsrfPlus.Exception, {CsrfPlus.Exception.SessionException, :missing_id}

      header_token == nil && body_token == nil ->
        raise CsrfPlus.Exception, CsrfPlus.Exception.SignedException

      true ->
        store = Map.get(opts, :store)
        signed_token = header_token || body_token

        check_token_store(conn, store, {opts, access_id, signed_token})
    end
  end

  defp check_token_store(_conn, nil, _to_check) do
    Logger.debug("CsrfPlus: No token store configured")

    raise CsrfPlus.Exception, CsrfPlus.Exception.StoreException
  end

  defp check_token_store(conn, store, {opts, access_id, signed_token}) do
    csrf_key = Map.get(opts, :csrf_key, nil)
    store_access = store.get_access(access_id)
    store_token = Map.get(store_access || %{}, :token, nil)
    session_token = get_session(conn, csrf_key)

    cond do
      session_token == nil ->
        Logger.debug("Missing token in the request session")

        raise CsrfPlus.Exception, CsrfPlus.Exception.SessionException

      store_access == nil ->
        Logger.debug("The access with id: #{access_id} was not found")

        raise CsrfPlus.Exception, {CsrfPlus.Exception.StoreException, :token_not_found}

      match?(%UserAccess{expired?: true}, store_access) ->
        Logger.debug("The access with id: #{access_id} has expired")

        raise CsrfPlus.Exception, {CsrfPlus.Exception.StoreException, :token_expired}

      not same_token?(session_token, store_token) ->
        Logger.debug(
          "Token mismatch session:#{inspect(session_token)} != store:#{inspect(store_token)}"
        )

        raise CsrfPlus.Exception, CsrfPlus.Exception.MismatchException

      true ->
        result = Token.verify(signed_token)
        check_token_store_verified(conn, result, store_token)
    end
  end

  defp check_token_store_verified(conn, {:ok, verified_token}, store_token) do
    if same_token?(verified_token, store_token) do
      conn
    else
      Logger.debug(
        "Token mismatch: verified_token:#{inspect(verified_token)} != store_token:#{inspect(store_token)}"
      )

      raise CsrfPlus.Exception, CsrfPlus.Exception.MismatchException
    end
  end

  defp check_token_store_verified(_conn, {:error, error}, _store_token) do
    Logger.debug("Token validation error: #{inspect(error)}")

    raise CsrfPlus.Exception, error
  end

  # Compares tokens in constant time when both are binaries, so response timing does
  # not leak how much of a client-supplied token matched; other terms fall back to ==.
  defp same_token?(left, right) when is_binary(left) and is_binary(right) do
    Plug.Crypto.secure_compare(left, right)
  end

  defp same_token?(left, right), do: left == right

  # Prefers `remote_ip` (set by the adapter, or rewritten by a proxy-aware plug earlier
  # in the pipeline) because clients cannot spoof it; the raw `x-real-ip` and
  # `x-forwarded-for` header values are only used as fallbacks. Never raises.
  # Note: `remote_ip` is a tuple while the header fallbacks are strings.
  defp get_conn_ip(%Plug.Conn{remote_ip: remote_ip} = conn) do
    remote_ip || first_req_header(conn, "x-real-ip") ||
      first_req_header(conn, "x-forwarded-for")
  end

  defp get_conn_user_agent(%Plug.Conn{} = conn), do: first_req_header(conn, "user-agent")

  # First value of the given request header, or nil when the header is absent.
  defp first_req_header(conn, name) do
    conn
    |> get_req_header(name)
    |> List.first()
  end
end