defmodule CsrfPlus do
@moduledoc """
A CSRF (Cross-Site Request Forgery) protection Plug with accesses storing support.
"""
@behaviour Plug
alias CsrfPlus.UserAccess
alias CsrfPlus.Token
alias CsrfPlus.UserAccessInfo
require Logger
@default_csrf_key "_csrf_token_"
@non_csrf_request_methods ["GET", "HEAD", "OPTIONS"]
# Max token age in milliseconds
@default_token_max_age 24 * 60 * 60 * 1000
import Plug.Conn
@doc false
def init(opts \\ []) do
csrf_key = Keyword.get(opts, :csrf_key, @default_csrf_key)
allowed_methods = Keyword.get(opts, :allowed_methods, @non_csrf_request_methods)
error_mapper = Keyword.get(opts, :error_mapper, CsrfPlus.ErrorMapper)
store =
:csrf_plus
|> Application.get_env(CsrfPlus, [])
|> Keyword.get(:store)
%{
csrf_key: csrf_key,
allowed_methods: allowed_methods,
error_mapper: error_mapper,
store: store
}
end
@doc false
def call(%Plug.Conn{halted: true} = conn, _opts) do
conn
end
def call(%Plug.Conn{} = conn, opts) do
allowed_method? = allowed_method?(conn, opts)
try do
conn
|> put_private(
:plug_csrf_plus,
%{
put_session_token: fn conn, token ->
csrf_token = opts.csrf_key
Plug.Conn.put_session(conn, csrf_token, token)
end
}
)
|> check_token(allowed_method?, opts)
rescue
exception ->
if CsrfPlus.Exception.csrf_plus_exception?(exception) do
error_mapper = opts.error_mapper
# Ensure the configured error mapper module is compiled
# BEWARE as this function call may lead to deadlocks.
module_compiled = Code.ensure_compiled(error_mapper)
if match?({:module, _}, module_compiled) &&
Kernel.function_exported?(error_mapper, :map, 1) do
{status_code, error} = error_mapper.map(exception)
# After any CSRF exception, must clear session and return the mapped error and status code
conn
|> delete_session(:access_id)
|> delete_session(opts.csrf_key)
|> send_resp(status_code, Jason.encode!(error))
|> halt()
else
Logger.debug(
"The given error_mapper #{inspect(error_mapper)} does not exist or doesn't implements the map/1 function"
)
reraise exception, __STACKTRACE__
end
else
reraise exception, __STACKTRACE__
end
end
end
@doc "Digs into the connection data to make an user access information struct."
def get_user_info(conn) do
%UserAccessInfo{
ip: get_conn_ip(conn),
user_agent: get_conn_user_agent(conn)
}
end
@doc "The default max age for a token"
def default_token_max_age do
@default_token_max_age
end
@doc "Uses the plug configuration to put a token into the session"
def put_session_token(%Plug.Conn{private: private} = conn, token) do
state = Map.get(private, :plug_csrf_plus, %{})
fun = Map.get(state, :put_session_token, nil)
if fun == nil do
raise CsrfPlus.Exception,
"CsrfPlus.put_session_token/2 must be called after CsrfPlus is plugged"
else
fun.(conn, token)
end
end
defp allowed_method?(
%Plug.Conn{method: method},
%{allowed_methods: allowed_methods} = _opts
) do
method in allowed_methods
end
defp check_token(%Plug.Conn{} = conn, true = _allowed_method?, _opts) do
conn
end
defp check_token(%Plug.Conn{} = conn, false = _allowed_method?, opts) do
access_id = get_session(conn, :access_id)
header_token =
conn
|> get_req_header("x-csrf-token")
|> Enum.at(0)
cond do
is_nil(access_id) ->
raise CsrfPlus.Exception, {CsrfPlus.Exception.SessionException, :missing_id}
header_token == nil ->
raise CsrfPlus.Exception, CsrfPlus.Exception.HeaderException
true ->
store = Map.get(opts, :store)
check_token_store(conn, store, {opts, access_id, header_token})
end
end
defp check_token_store(_conn, nil, _to_check) do
Logger.debug("CsrfPlus: No token store configured")
raise CsrfPlus.Exception, CsrfPlus.Exception.StoreException
end
defp check_token_store(conn, store, {opts, access_id, header_token}) do
csrf_key = Map.get(opts, :csrf_key, nil)
store_access = store.get_access(access_id)
store_token = Map.get(store_access || %{}, :token, nil)
session_token = get_session(conn, csrf_key)
cond do
session_token == nil ->
Logger.debug("Missing token in the request session")
raise CsrfPlus.Exception, CsrfPlus.Exception.SessionException
store_access == nil ->
Logger.debug("The access with id: #{access_id} was not found")
raise CsrfPlus.Exception, {CsrfPlus.Exception.StoreException, :token_not_found}
match?(%UserAccess{expired?: true}, store_access) ->
Logger.debug("The access with id: #{access_id} has expired")
raise CsrfPlus.Exception, {CsrfPlus.Exception.StoreException, :token_expired}
session_token != store_token ->
Logger.debug(
"Token mismatch session:#{inspect(session_token)} != store:#{inspect(store_token)}"
)
raise CsrfPlus.Exception, CsrfPlus.Exception.MismatchException
true ->
result = Token.verify(header_token)
check_token_store_verified(conn, result, store_token)
end
end
defp check_token_store_verified(conn, {:ok, verified_token}, store_token) do
if verified_token != store_token do
Logger.debug(
"Token mismatch: verified_token:#{inspect(verified_token)} != store_token:#{inspect(store_token)}"
)
raise CsrfPlus.Exception, CsrfPlus.Exception.MismatchException
else
conn
end
end
defp check_token_store_verified(_conn, {:error, error}, _store_token) do
Logger.debug("Token validation error: #{inspect(error)}")
raise CsrfPlus.Exception, error
end
defp get_conn_ip(%Plug.Conn{remote_ip: remote_ip, req_headers: req_headers}) do
[x_real_ip | _] = List.keyfind(req_headers, "x-real-ip", 0, [nil])
[x_forwarded_for | _] = List.keyfind(req_headers, "x-forwarded-for", 0, [nil])
case {remote_ip, x_real_ip, x_forwarded_for} do
{nil, nil, nil} ->
nil
{nil, nil, x_forwarded_for} ->
x_forwarded_for
{nil, x_real_ip, nil} ->
x_real_ip
{remote_ip, nil, nil} ->
remote_ip
end
end
defp get_conn_user_agent(%Plug.Conn{} = conn) do
user_agent = get_req_header(conn, "user-agent")
if Enum.empty?(user_agent) do
nil
else
hd(user_agent)
end
end
end