defmodule Ueberauth.Strategy.Helpers do
  @moduledoc """
  Provides helper functions for use within your strategy.

  These helpers are provided as a convenience for accessing the options passed
  to the specific pipelined strategy, considering the pipelined options and
  falling back to defaults.

  All option lookups read from the `:ueberauth_request_options` entry of
  `conn.private`; when that entry is absent the lookups return `nil`.
  """

  # `request_url/1` is excluded from the import because this module defines its
  # own `request_url/1,2`.
  import Plug.Conn, except: [request_url: 1]

  alias Ueberauth.Failure
  alias Ueberauth.Failure.Error

  @doc """
  Provides the name of the strategy or provider name.

  This is defined in your configuration as the provider name.
  """
  @spec strategy_name(Plug.Conn.t()) :: String.t()
  def strategy_name(conn), do: from_private(conn, :strategy_name)

  @doc """
  The strategy module that is being used for the request.
  """
  @spec strategy(Plug.Conn.t()) :: module
  def strategy(conn), do: from_private(conn, :strategy)

  @doc """
  The request path for the strategy to hit.

  Requests to this path will trigger the `request_phase` of the strategy.
  """
  @spec request_path(Plug.Conn.t()) :: String.t()
  def request_path(conn), do: from_private(conn, :request_path)

  @doc """
  The callback path for the requests strategy.

  When a client hits this path, the callback phase will be triggered for the
  strategy.
  """
  @spec callback_path(Plug.Conn.t()) :: String.t()
  def callback_path(conn), do: from_private(conn, :callback_path)

  @doc """
  The full URL for the request phase for the requests strategy.

  The URL is based on the current requests host and scheme, unless the
  `:request_scheme` / `:request_port` options override them. Any
  `query_params` are encoded into the query string.
  """
  @spec request_url(Plug.Conn.t()) :: String.t()
  @spec request_url(Plug.Conn.t(), Enumerable.t()) :: String.t()
  def request_url(conn, query_params \\ []) do
    opts = [
      scheme: from_private(conn, :request_scheme),
      port: from_private(conn, :request_port),
      path: request_path(conn),
      query_params: query_params
    ]

    full_url(conn, opts)
  end

  @doc """
  The full URL for the callback phase for the requests strategy.

  If a `:callback_url` option is set, it is returned as-is and `query_params`
  are not applied. Otherwise the URL is based on the current requests host and
  scheme (or the `:callback_scheme` / `:callback_port` options), with the
  query string built by `callback_params/2`.
  """
  @spec callback_url(Plug.Conn.t()) :: String.t()
  @spec callback_url(Plug.Conn.t(), keyword()) :: String.t()
  def callback_url(conn, query_params \\ []) do
    if url = from_private(conn, :callback_url) do
      url
    else
      opts = [
        scheme: from_private(conn, :callback_scheme),
        port: from_private(conn, :callback_port),
        path: callback_path(conn),
        query_params: callback_params(conn, query_params)
      ]

      full_url(conn, opts)
    end
  end

  @doc """
  Builds the params for the callback.

  Looks up each name listed in the `:callback_params` option in `conn.params`
  and keeps the ones that have a non-nil value. The `"provider"` param is
  never forwarded, even when it is listed. The result is merged over
  `query_params`, so a whitelisted value wins when both define the same key.
  """
  @spec callback_params(Plug.Conn.t()) :: keyword()
  @spec callback_params(Plug.Conn.t(), keyword()) :: keyword()
  def callback_params(conn, query_params \\ []) do
    allowed_names = from_private(conn, :callback_params) || []

    callback_params =
      allowed_names
      # Compare while the name is still a string: once it has been converted
      # to an atom it can never equal "provider".
      |> Enum.reject(fn name -> name == "provider" end)
      # The names originate from the `:callback_params` option rather than
      # from the request parameters themselves.
      |> Enum.map(fn name -> {String.to_atom(name), conn.params[name]} end)
      |> Enum.reject(fn {_name, value} -> is_nil(value) end)

    Keyword.merge(query_params, callback_params)
  end

  @doc """
  The configured allowed callback http methods.

  This will use any supplied options from the configuration, but fallback to
  the default options.
  """
  @spec allowed_callback_methods(Plug.Conn.t()) :: list(String.t())
  def allowed_callback_methods(conn), do: from_private(conn, :callback_methods)

  @doc """
  Is the current request http method one of the allowed callback methods?

  The request method is upcased before it is compared.
  """
  @spec allowed_callback_method?(Plug.Conn.t()) :: boolean
  def allowed_callback_method?(%{method: method} = conn) do
    callback_method =
      method
      |> to_string()
      |> String.upcase()

    conn
    |> allowed_callback_methods()
    |> Enum.member?(callback_method)
  end

  @doc """
  The full list of options passed to the strategy in the configuration.
  """
  @spec options(Plug.Conn.t()) :: Keyword.t()
  def options(conn), do: from_private(conn, :options)

  @doc """
  A helper for constructing error entries on failure.

  The `message_key` is intended for use by machines for translations etc.
  The message is a human readable error message.

  ## Example

      error("something_bad", "Something really bad happened")
  """
  @spec error(String.t(), String.t()) :: Error.t()
  def error(key, message), do: struct(Error, message_key: key, message: message)

  @doc """
  Sets a failure onto the connection containing a list of errors.

  During your callback phase, this should be called to 'fail' the
  authentication request and include a collection of errors outlining what the
  problem is.

  `errors` may be `nil`, a single error struct, or a list whose entries are
  error structs, maps or keyword lists; maps and keyword lists are converted
  into error structs.

  Note this changes the conn object and should be part of your returned
  connection of the `callback_phase!`. The failure is stored in the
  `:ueberauth_failure` assign.
  """
  @spec set_errors!(Plug.Conn.t(), list(Error.t())) :: Plug.Conn.t()
  def set_errors!(conn, errors) do
    failure =
      struct(
        Failure,
        provider: strategy_name(conn),
        strategy: strategy(conn),
        errors: map_errors(errors)
      )

    Plug.Conn.assign(conn, :ueberauth_failure, failure)
  end

  @doc """
  Redirects to a url and halts the plug pipeline.

  The response uses the status already set on the connection, or 302 when
  none is set, and carries a small HTML body linking to the escaped URL.
  """
  @spec redirect!(Plug.Conn.t(), String.t()) :: Plug.Conn.t()
  def redirect!(conn, url) do
    html = Plug.HTML.html_escape(url)
    body = "<html><body>You are being <a href=\"#{html}\">redirected</a>.</body></html>"

    conn
    |> put_resp_header("location", url)
    |> send_resp(conn.status || 302, body)
    |> halt()
  end

  @doc """
  Add state parameter to the `%Plug.Conn{}`.

  The value is stored under the `:ueberauth_state_param` private key.
  """
  @spec add_state_param(Plug.Conn.t(), String.t()) :: Plug.Conn.t()
  def add_state_param(conn, value) do
    Plug.Conn.put_private(conn, :ueberauth_state_param, value)
  end

  @doc """
  Add state parameter to the options.

  Returns `opts` unchanged when no state parameter has been stored on the
  connection; otherwise sets the `:state` key.
  """
  @spec with_state_param(
          keyword(),
          Plug.Conn.t()
        ) :: keyword()
  def with_state_param(opts, conn) do
    case conn.private[:ueberauth_state_param] do
      nil -> opts
      state -> Keyword.put(opts, :state, state)
    end
  end

  # Reads `key` from the request options stored on the connection, or returns
  # nil when no request options are present.
  defp from_private(conn, key) do
    opts = conn.private[:ueberauth_request_options]
    if opts, do: opts[key], else: nil
  end

  # Assembles an absolute URL from the connection and `opts`
  # (`:scheme`, `:port`, `:path` and `:query_params`; `:path` is required).
  defp full_url(conn, opts) do
    # Precedence: explicit option, then `x-forwarded-proto`, then the conn.
    scheme =
      cond do
        scheme = Keyword.get(opts, :scheme) -> scheme
        scheme = get_forwarded_proto_header(conn) -> scheme
        true -> to_string(conn.scheme)
      end

    host = get_host_header(conn) || conn.host

    # NOTE(review): a host containing more than one ":" (for example an IPv6
    # literal) does not fit the two-element pattern below and fails to match.
    [host, port] =
      if String.contains?(host, ":"),
        do: String.split(host, ":"),
        else: [host, to_string(conn.port)]

    port = Keyword.get(opts, :port) || normalize_port(scheme, port)
    path = Keyword.fetch!(opts, :path)

    query =
      opts
      |> Keyword.get(:query_params, [])
      |> encode_query()

    %URI{
      host: host,
      port: port,
      path: path,
      query: query,
      scheme: scheme
    }
    |> to_string()
  end

  # First `x-forwarded-proto` request header value, or nil when absent.
  defp get_forwarded_proto_header(conn) do
    conn
    |> get_req_header("x-forwarded-proto")
    |> List.first()
  end

  # Prefers `x-forwarded-host` and falls back to the `host` header; nil when
  # neither header is present.
  defp get_host_header(conn) do
    case get_req_header(conn, "x-forwarded-host") do
      [] ->
        conn
        |> get_req_header("host")
        |> List.first()

      [host | _] ->
        host
    end
  end

  # Port "80" (or a missing port) is replaced by the scheme's default port;
  # any other port string is parsed as an integer.
  defp normalize_port(scheme, "80"), do: URI.default_port(scheme)
  defp normalize_port(scheme, nil), do: URI.default_port(scheme)
  defp normalize_port(_, port), do: String.to_integer(port)

  # An empty param list yields no query string at all.
  defp encode_query([]), do: nil
  defp encode_query(query_params), do: URI.encode_query(query_params)

  # Normalizes the accepted error inputs into a list of error structs.
  defp map_errors(nil), do: []
  defp map_errors([]), do: []
  defp map_errors(%Error{} = error), do: [error]
  defp map_errors(errors), do: Enum.map(errors, &p_error/1)

  # Converts a single entry (struct, map or keyword list) to an error struct.
  defp p_error(%Error{} = error), do: error
  defp p_error(%{} = error), do: struct(Error, error)
  defp p_error(error) when is_list(error), do: struct(Error, error)
end