defmodule Phoenix.Socket.Transport do
@moduledoc """
Outlines the Socket <-> Transport communication.
This module specifies a behaviour that all sockets must implement.
`Phoenix.Socket` is just one possible implementation of a socket
that multiplexes events over multiple channels. Developers can
implement their own sockets as long as they implement the behaviour
outlined here.
Developers interested in implementing custom transports must invoke
the socket API defined in this module. This module also provides
many conveniences that invoke the underlying socket API to make
it easier to build custom transports.
## Booting sockets
Whenever your endpoint starts, it will automatically invoke the
`child_spec/1` on each listed socket and start that specification
under the endpoint supervisor.
Since the socket supervision tree is started by the endpoint,
any custom transport must be started after the endpoint in a
supervision tree.
## Operating sockets
Sockets are operated by a transport. When a transport is defined,
it usually receives a socket module and the module will be invoked
when certain events happen at the transport level.
Whenever the transport receives a new connection, it should invoke
the `c:connect/1` callback with a map of metadata. Different sockets
may require different metadata.
If the connection is accepted, the transport can move the connection
to another process, if it so desires, or keep using the same process. The
process responsible for managing the socket should then call `c:init/1`.
For each message received from the client, the transport must call
`c:handle_in/2` on the socket. For each informational message the
transport receives, it should call `c:handle_info/2` on the socket.
Transports can optionally implement `c:handle_control/2` for handling
control frames such as `:ping` and `:pong`.
On termination, `c:terminate/2` must be called. The special reason
`:closed` can be used to specify that the client terminated
the connection.
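
The call order described above can be sketched with plain modules. This is
only an illustration under stated assumptions: `LifecycleSketch.UpcaseSocket`
and `LifecycleSketch.Driver` are hypothetical names, the "transport" is a
function that feeds a fixed list of frames, and every frame is assumed to
produce a reply.

```elixir
defmodule LifecycleSketch.UpcaseSocket do
  # Hypothetical socket used only for this sketch.
  def connect(info), do: {:ok, %{params: info.params, seen: 0}}
  def init(state), do: {:ok, state}

  def handle_in({text, _opts}, state) do
    {:reply, :ok, {:text, String.upcase(text)}, %{state | seen: state.seen + 1}}
  end

  def terminate(_reason, _state), do: :ok
end

defmodule LifecycleSketch.Driver do
  # Hypothetical transport: connect, init, one handle_in per frame, terminate.
  def run(socket, info, frames) do
    with {:ok, state} <- socket.connect(info),
         {:ok, state} <- socket.init(state) do
      {replies, state} =
        Enum.map_reduce(frames, state, fn frame, state ->
          # Assumption of this sketch: every frame yields a reply.
          {:reply, :ok, reply, state} = socket.handle_in({frame, [opcode: :text]}, state)
          {reply, state}
        end)

      # :closed signals that the client ended the connection.
      :ok = socket.terminate(:closed, state)
      {:ok, replies, state}
    end
  end
end
```

A real transport would also forward process messages to `c:handle_info/2`
and handle the `{:ok, state}` and `{:stop, reason, state}` return values.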
## Example
Here is a simple echo socket implementation:

    defmodule EchoSocket do
      @behaviour Phoenix.Socket.Transport

      def child_spec(opts) do
        # We won't spawn any process, so let's return a dummy task
        %{id: __MODULE__, start: {Task, :start_link, [fn -> :ok end]}, restart: :transient}
      end

      def connect(state) do
        # Callback to retrieve relevant data from the connection.
        # The map contains options, params, transport and endpoint keys.
        {:ok, state}
      end

      def init(state) do
        # Now we are effectively inside the process that maintains the socket.
        {:ok, state}
      end

      def handle_in({text, _opts}, state) do
        {:reply, :ok, {:text, text}, state}
      end

      def handle_info(_, state) do
        {:ok, state}
      end

      def terminate(_reason, _state) do
        :ok
      end
    end

It can be mounted in your endpoint like any other socket:

    socket "/socket", EchoSocket, websocket: true, longpoll: true

You can now interact with the socket under `/socket/websocket`
and `/socket/longpoll`.
## Security
This module also provides functions to enable a secure environment
on transports that, at some point, have access to a `Plug.Conn`.
The functionality provided by this module helps in performing "origin"
header checks and ensuring only SSL connections are allowed.
"""
@type state :: term()
@doc """
Returns a child specification for socket management.
This is invoked only once per socket regardless of
the number of transports and should be responsible
for setting up any process structure used exclusively
by the socket regardless of transports.
Each socket connection is started by the transport
and the process that controls the socket likely
belongs to the transport. However, some sockets spawn
new processes, such as `Phoenix.Socket` which spawns
channels, and this gives the ability to start a
supervision tree associated to the socket.
It receives the socket options from the endpoint,
for example:

    socket "/my_app", MyApp.Socket, shutdown: 5000

means `child_spec([shutdown: 5000])` will be invoked.
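
As an illustration, a socket that needs its own supervision tree could
return a supervisor specification. This is a minimal sketch, not how
`Phoenix.Socket` is implemented: the module name, the `Workers` supervisor
name, and the way `:shutdown` is reused are all assumptions.

```elixir
defmodule ChildSpecSketch.Socket do
  # Hypothetical: the socket owns a DynamicSupervisor for per-connection workers.
  def child_spec(opts) do
    # Assumption: reuse the :shutdown option given in the endpoint.
    shutdown = Keyword.get(opts, :shutdown, 5_000)

    %{
      id: __MODULE__,
      start:
        {DynamicSupervisor, :start_link,
         [[name: __MODULE__.Workers, strategy: :one_for_one]]},
      type: :supervisor,
      shutdown: shutdown
    }
  end
end
```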
"""
@callback child_spec(keyword) :: :supervisor.child_spec
@doc """
Connects to the socket.
The transport passes a map of metadata and the socket
returns `{:ok, state}` or `:error`. The state must be
stored by the transport and returned in all future
operations.
This function is used for authorization purposes and it
may be invoked outside of the process that effectively
runs the socket.
In the default `Phoenix.Socket` implementation, the
metadata expects the following keys:
* `:endpoint` - the application endpoint
* `:transport` - the transport name
* `:params` - the connection parameters
* `:options` - a keyword list of transport options, often
given by developers when configuring the transport.
It must include a `:serializer` field with the list of
serializers and their requirements
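
A custom socket is free to read whatever it needs from the metadata. The
following sketch assumes a hypothetical `"token"` connection parameter and
a hard-coded expected value, and accepts the connection only when they match:

```elixir
defmodule ConnectSketch.Socket do
  # Hypothetical shared secret, used only for illustration.
  @expected_token "secret"

  # Accept the connection when the params carry the expected token.
  def connect(%{params: %{"token" => @expected_token} = params}) do
    {:ok, %{params: params}}
  end

  # Reject everything else.
  def connect(_transport_info), do: :error
end
```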
"""
@callback connect(transport_info :: map) :: {:ok, state} | :error
@doc """
Initializes the socket state.
This must be executed from the process that will effectively
operate the socket.
"""
@callback init(state) :: {:ok, state}
@doc """
Handles incoming socket messages.
The message is represented as `{payload, options}`. It must
return one of:
* `{:ok, state}` - continues the socket with no reply
* `{:reply, status, reply, state}` - continues the socket with reply
* `{:stop, reason, state}` - stops the socket
The `reply` is a tuple containing an `opcode` atom and a message that can
be any term. The built-in websocket transport supports both the `:text` and
`:binary` opcodes, and the message must always be iodata. Long polling only
supports the text opcode.
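
The three return values can be illustrated with a small sketch. The module
name and the `"ping"` and `"quit"` payloads are hypothetical:

```elixir
defmodule HandleInSketch.Socket do
  # Reply to the client with a text frame.
  def handle_in({"ping", _opts}, state), do: {:reply, :ok, {:text, "pong"}, state}

  # Stop the socket.
  def handle_in({"quit", _opts}, state), do: {:stop, :normal, state}

  # Continue without replying.
  def handle_in({_payload, _opts}, state), do: {:ok, state}
end
```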
"""
@callback handle_in({message :: term, opts :: keyword}, state) ::
{:ok, state}
| {:reply, :ok | :error, {opcode :: atom, message :: term}, state}
| {:stop, reason :: term, state}
@doc """
Handles incoming control frames.
The message is represented as `{payload, options}`. It must
return one of:
* `{:ok, state}` - continues the socket with no reply
* `{:reply, status, reply, state}` - continues the socket with reply
* `{:stop, reason, state}` - stops the socket
Control frames are only supported when using websockets.
The `options` contain an `opcode` key, which will be either `:ping` or
`:pong`.
If a control frame doesn't have a payload, then the payload value
will be `nil`.
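
For instance, a socket could count the pings it receives. This is a sketch
with a hypothetical module name and a hypothetical `:pings` key in the state:

```elixir
defmodule HandleControlSketch.Socket do
  # Count :ping frames; ignore every other control frame.
  def handle_control({_payload, opts}, state) do
    case Keyword.get(opts, :opcode) do
      :ping -> {:ok, Map.update(state, :pings, 1, &(&1 + 1))}
      _other -> {:ok, state}
    end
  end
end
```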
"""
@callback handle_control({message :: term, opts :: keyword}, state) ::
{:ok, state}
| {:reply, :ok | :error, {opcode :: atom, message :: term}, state}
| {:stop, reason :: term, state}
@doc """
Handles info messages.
The message is a term. It must return one of:
* `{:ok, state}` - continues the socket with no reply
* `{:push, reply, state}` - continues the socket with reply
* `{:stop, reason, state}` - stops the socket
The `reply` is a tuple containing an `opcode` atom and a message that can
be any term. The built-in websocket transport supports both the `:text` and
`:binary` opcodes, and the message must always be iodata. Long polling only
supports the text opcode.
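
A short sketch of the three return values, assuming hypothetical
`{:broadcast, text}` and `:shutdown` messages sent to the socket process:

```elixir
defmodule HandleInfoSketch.Socket do
  # Push a text frame to the client.
  def handle_info({:broadcast, text}, state) when is_binary(text) do
    {:push, {:text, text}, state}
  end

  # Stop the socket.
  def handle_info(:shutdown, state), do: {:stop, :shutdown, state}

  # Ignore anything else.
  def handle_info(_message, state), do: {:ok, state}
end
```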
"""
@callback handle_info(message :: term, state) ::
{:ok, state}
| {:push, {opcode :: atom, message :: term}, state}
| {:stop, reason :: term, state}
@doc """
Invoked on termination.
If `reason` is `:closed`, it means the client closed the socket. This is
considered a `:normal` exit signal, so linked processes will not automatically
exit. See `Process.exit/2` for more details on exit signals.
"""
@callback terminate(reason :: term, state) :: :ok
@optional_callbacks handle_control: 2
require Logger
@doc false
def load_config(true, module),
do: module.default_config()
def load_config(config, module),
do: module.default_config() |> Keyword.merge(config) |> load_config()
@doc false
def load_config(config) do
{connect_info, config} = Keyword.pop(config, :connect_info, [])
connect_info =
Enum.map(connect_info, fn
key when key in [:peer_data, :trace_context_headers, :uri, :user_agent, :x_headers] ->
key
{:session, session} ->
{:session, init_session(session)}
{_, _} = pair ->
pair
other ->
raise ArgumentError,
":connect_info keys are expected to be one of :peer_data, :trace_context_headers, :x_headers, :uri, :user_agent, or {:session, config}, " <>
"optionally followed by custom keyword pairs, got: #{inspect(other)}"
end)
[connect_info: connect_info] ++ config
end
defp init_session(session_config) when is_list(session_config) do
key = Keyword.fetch!(session_config, :key)
store = Plug.Session.Store.get(Keyword.fetch!(session_config, :store))
init = store.init(Keyword.drop(session_config, [:store, :key]))
{key, store, init}
end
defp init_session({_, _, _} = mfa) do
{:mfa, mfa}
end
@doc """
Runs the code reloader if enabled.
"""
def code_reload(conn, endpoint, opts) do
reload? = Keyword.get(opts, :code_reloader, endpoint.config(:code_reloader))
reload? && Phoenix.CodeReloader.reload(endpoint)
conn
end
@doc """
Forces SSL in the socket connection.
Uses the `:force_ssl` transport option, falling back to the endpoint
configuration, to decide whether to do so. It is a noop if the
connection has been halted.
"""
def force_ssl(%{halted: true} = conn, _socket, _endpoint, _opts) do
conn
end
def force_ssl(conn, socket, endpoint, opts) do
if force_ssl = force_ssl_config(socket, endpoint, opts) do
Plug.SSL.call(conn, force_ssl)
else
conn
end
end
defp force_ssl_config(socket, endpoint, opts) do
Phoenix.Config.cache(endpoint, {:force_ssl, socket}, fn _ ->
opts =
if force_ssl = Keyword.get(opts, :force_ssl, endpoint.config(:force_ssl)) do
force_ssl
|> Keyword.put_new(:host, {endpoint, :host, []})
|> Plug.SSL.init()
end
{:cache, opts}
end)
end
@doc """
Logs the transport request.
Available for transports that generate a connection.
"""
def transport_log(conn, level) do
if level do
Plug.Logger.call(conn, Plug.Logger.init(log: level))
else
conn
end
end
@doc """
Checks the origin request header against the list of allowed origins.
Should be called by transports before connecting when appropriate.
If the origin header matches the allowed origins, no origin header was
sent or no origin was configured, it will return the given connection.
Otherwise a 403 Forbidden response will be sent and the connection halted.
It is a noop if the connection has been halted.
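
The comparison against a list of allowed origins can be sketched as follows.
This is a simplified, self-contained illustration of the idea rather than
the code used by this module: `OriginSketch` is a hypothetical name, and it
assumes that a scheme or port missing from an allowed origin matches any
value and that a leading `*.` in the host matches the domain and its
subdomains.

```elixir
defmodule OriginSketch do
  # Returns true when `origin` matches at least one entry in `allowed_origins`.
  def allowed?(origin, allowed_origins) do
    request = URI.parse(origin)

    Enum.any?(allowed_origins, fn allowed ->
      allowed = URI.parse(allowed)

      same?(request.scheme, allowed.scheme) and
        same?(request.port, allowed.port) and
        host?(request.host, allowed.host)
    end)
  end

  # A component missing from the allowed origin matches anything.
  defp same?(_request, nil), do: true
  defp same?(request, allowed), do: request == allowed

  # An origin without a host never matches.
  defp host?(nil, _allowed), do: false

  # "*.example.com" matches "example.com" and any of its subdomains.
  defp host?(host, "*." <> parent),
    do: host == parent or String.ends_with?(host, "." <> parent)

  defp host?(host, allowed), do: host == allowed
end
```

With `["https://example.com", "//another.com:888"]` as the allowed list,
`"https://example.com"` and `"http://another.com:888"` are accepted, while
`"https://evil.com"` is rejected.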
"""
def check_origin(conn, handler, endpoint, opts, sender \\ &Plug.Conn.send_resp/1)
def check_origin(%Plug.Conn{halted: true} = conn, _handler, _endpoint, _opts, _sender),
do: conn
def check_origin(conn, handler, endpoint, opts, sender) do
import Plug.Conn
origin = conn |> get_req_header("origin") |> List.first()
check_origin = check_origin_config(handler, endpoint, opts)
cond do
is_nil(origin) or check_origin == false ->
conn
origin_allowed?(check_origin, URI.parse(origin), endpoint, conn) ->
conn
true ->
Logger.error """
Could not check origin for Phoenix.Socket transport.
Origin of the request: #{origin}
This happens when you are attempting a socket connection to
a different host than the one configured in your config/
files. For example, in development the host is configured
to "localhost" but you may be trying to access it from
"127.0.0.1". To fix this issue, you may either:
1. update [url: [host: ...]] to your actual host in the
config file for your current environment (recommended)
2. pass the :check_origin option when configuring your
endpoint or when configuring the transport in your
UserSocket module, explicitly outlining which origins
are allowed:
check_origin: ["https://example.com",
"//another.com:888", "//other.com"]
"""
resp(conn, :forbidden, "")
|> sender.()
|> halt()
end
end
@doc """
Checks the Websocket subprotocols request header against the allowed subprotocols.
Should be called by transports before connecting when appropriate.
If the `sec-websocket-protocol` header matches the allowed subprotocols,
it will set the `sec-websocket-protocol` response header and return the given connection.
If no `sec-websocket-protocol` header was sent, it will return the given connection.
Otherwise a 403 Forbidden response will be sent and the connection halted.
It is a noop if the connection has been halted.
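
The selection rule can be sketched without a `Plug.Conn`. In this
illustration `SubprotocolSketch` is a hypothetical module, and the header
is parsed with a plain comma split as a stand-in for the real header
parsing. The first configured subprotocol that the client also requested
is chosen:

```elixir
defmodule SubprotocolSketch do
  # Picks the first supported subprotocol that appears in the request header.
  def negotiate(header, supported) when is_binary(header) and is_list(supported) do
    # Stand-in for header parsing: split on commas and trim whitespace.
    requested =
      header
      |> String.split(",", trim: true)
      |> Enum.map(&String.trim/1)

    case Enum.find(supported, &(&1 in requested)) do
      nil -> :error
      subprotocol -> {:ok, subprotocol}
    end
  end
end
```

Note that the order of the configured list decides the outcome when the
client offers several acceptable subprotocols.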
"""
def check_subprotocols(conn, subprotocols)
def check_subprotocols(%Plug.Conn{halted: true} = conn, _subprotocols), do: conn
def check_subprotocols(conn, nil), do: conn
def check_subprotocols(conn, subprotocols) when is_list(subprotocols) do
case Plug.Conn.get_req_header(conn, "sec-websocket-protocol") do
[] ->
conn
[subprotocols_header | _] ->
request_subprotocols = subprotocols_header |> Plug.Conn.Utils.list()
subprotocol = Enum.find(subprotocols, &(&1 in request_subprotocols))
if subprotocol do
Plug.Conn.put_resp_header(conn, "sec-websocket-protocol", subprotocol)
else
subprotocols_error_response(conn, subprotocols)
end
end
end
def check_subprotocols(conn, subprotocols), do: subprotocols_error_response(conn, subprotocols)
@doc """
Extracts connection information from `conn` and returns a map.
Keys are retrieved from the optional transport option `:connect_info`.
This functionality is transport specific. Please refer to your transport's
documentation for more information.
The supported keys are:
* `:peer_data` - the result of `Plug.Conn.get_peer_data/1`
* `:trace_context_headers` - a list of all trace context headers
* `:x_headers` - a list of all request headers that have an "x-" prefix
* `:uri` - a `%URI{}` derived from the conn
* `:user_agent` - the value of the "user-agent" request header
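
As a configuration sketch, the keys are typically listed under the
transport options when mounting the socket in the endpoint. Here
`MyApp.UserSocket` is a hypothetical socket module, and whether a given
transport honours `:connect_info` depends on that transport:

```elixir
socket "/socket", MyApp.UserSocket,
  websocket: [connect_info: [:peer_data, :x_headers, :uri, :user_agent]]
```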
"""
def connect_info(conn, endpoint, keys) do
for key <- keys, into: %{} do
case key do
:peer_data ->
{:peer_data, Plug.Conn.get_peer_data(conn)}
:trace_context_headers ->
{:trace_context_headers, fetch_trace_context_headers(conn)}
:x_headers ->
{:x_headers, fetch_x_headers(conn)}
:uri ->
{:uri, fetch_uri(conn)}
:user_agent ->
{:user_agent, fetch_user_agent(conn)}
{:session, session} ->
{:session, connect_session(conn, endpoint, session)}
{key, val} ->
{key, val}
end
end
end
defp connect_session(conn, endpoint, {key, store, store_config}) do
conn = Plug.Conn.fetch_cookies(conn)
with csrf_token when is_binary(csrf_token) <- conn.params["_csrf_token"],
cookie when is_binary(cookie) <- conn.cookies[key],
conn = put_in(conn.secret_key_base, endpoint.config(:secret_key_base)),
{_, session} <- store.get(conn, cookie, store_config),
csrf_state when is_binary(csrf_state) <- Plug.CSRFProtection.dump_state_from_session(session["_csrf_token"]),
true <- Plug.CSRFProtection.valid_state_and_csrf_token?(csrf_state, csrf_token) do
session
else
_ -> nil
end
end
defp connect_session(conn, endpoint, {:mfa, {module, function, args}}) do
case apply(module, function, args) do
session_config when is_list(session_config) ->
connect_session(conn, endpoint, init_session(session_config))
other ->
raise ArgumentError,
"the MFA given to `session_config` must return a keyword list, got: #{inspect other}"
end
end
defp subprotocols_error_response(conn, subprotocols) do
import Plug.Conn
request_headers = get_req_header(conn, "sec-websocket-protocol")
Logger.error """
Could not check Websocket subprotocols for Phoenix.Socket transport.
Subprotocols of the request: #{inspect(request_headers)}
Configured supported subprotocols: #{inspect(subprotocols)}
This happens when you are attempting a socket connection with
subprotocols different from the ones configured in your endpoint
or when you incorrectly configured supported subprotocols.
To fix this issue, you may either:
1. update websocket: [subprotocols: [..]] to your actual subprotocols
in your endpoint socket configuration.
2. check the correctness of the `sec-websocket-protocol` request header
sent from the client.
3. remove `websocket` option from your endpoint socket configuration
if you don't use Websocket subprotocols.
"""
resp(conn, :forbidden, "")
|> send_resp()
|> halt()
end
defp fetch_x_headers(conn) do
for {header, _} = pair <- conn.req_headers,
String.starts_with?(header, "x-"),
do: pair
end
defp fetch_trace_context_headers(conn) do
for {header, _} = pair <- conn.req_headers,
header in ["traceparent", "tracestate"],
do: pair
end
defp fetch_uri(conn) do
%URI{
scheme: to_string(conn.scheme),
query: conn.query_string,
port: conn.port,
host: conn.host,
authority: conn.host,
path: conn.request_path
}
end
defp fetch_user_agent(conn) do
with {_, value} <- List.keyfind(conn.req_headers, "user-agent", 0) do
value
end
end
defp check_origin_config(handler, endpoint, opts) do
Phoenix.Config.cache(endpoint, {:check_origin, handler}, fn _ ->
check_origin =
case Keyword.get(opts, :check_origin, endpoint.config(:check_origin)) do
origins when is_list(origins) ->
Enum.map(origins, &parse_origin/1)
boolean when is_boolean(boolean) ->
boolean
{module, function, arguments} ->
{module, function, arguments}
:conn ->
:conn
invalid ->
raise ArgumentError, ":check_origin expects a boolean, list of hosts, :conn, or MFA tuple, got: #{inspect(invalid)}"
end
{:cache, check_origin}
end)
end
defp parse_origin(origin) do
case URI.parse(origin) do
%{host: nil} ->
raise ArgumentError,
"invalid :check_origin option: #{inspect origin}. " <>
"Expected an origin with a host that is parsable by URI.parse/1. For example: " <>
"[\"https://example.com\", \"//another.com:888\", \"//other.com\"]"
%{scheme: scheme, port: port, host: host} ->
{scheme, host, port}
end
end
defp origin_allowed?({module, function, arguments}, uri, _endpoint, _conn),
do: apply(module, function, [uri | arguments])
defp origin_allowed?(:conn, uri, _endpoint, %Plug.Conn{} = conn) do
uri.host == conn.host and
uri.scheme == Atom.to_string(conn.scheme) and
uri.port == conn.port
end
defp origin_allowed?(_check_origin, %{host: nil}, _endpoint, _conn),
do: false
defp origin_allowed?(true, uri, endpoint, _conn),
do: compare?(uri.host, host_to_binary(endpoint.config(:url)[:host]))
defp origin_allowed?(check_origin, uri, _endpoint, _conn) when is_list(check_origin),
do: origin_allowed?(uri, check_origin)
defp origin_allowed?(uri, allowed_origins) do
%{scheme: origin_scheme, host: origin_host, port: origin_port} = uri
Enum.any?(allowed_origins, fn {allowed_scheme, allowed_host, allowed_port} ->
compare?(origin_scheme, allowed_scheme) and
compare?(origin_port, allowed_port) and
compare_host?(origin_host, allowed_host)
end)
end
defp compare?(request_val, allowed_val) do
is_nil(allowed_val) or request_val == allowed_val
end
defp compare_host?(_request_host, nil),
do: true
defp compare_host?(request_host, "*." <> allowed_host),
do: request_host == allowed_host or String.ends_with?(request_host, "." <> allowed_host)
defp compare_host?(request_host, allowed_host),
do: request_host == allowed_host
# TODO: Deprecate {:system, env_var} once we require Elixir v1.9+
defp host_to_binary({:system, env_var}), do: host_to_binary(System.get_env(env_var))
defp host_to_binary(host), do: host
end