//// Implementation of https://www.rfc-editor.org/rfc/rfc5321.html
import gleam/bit_array
import gleam/bool
import gleam/crypto
import gleam/function
import gleam/int
import gleam/list
import gleam/option.{type Option, None, Some}
import gleam/order
import gleam/result
import gleam/string
import internal/encoder/encoding
import internal/encoder/idna
import internal/renderer/internet_message
import sendr/message.{type Message}
import sendr/message/mailbox.{type Mailbox}
/// SASL authentication mechanisms this client can perform. They are
/// recognised in the EHLO `AUTH` line as `CRAM-MD5`, `LOGIN`, `PLAIN` and
/// `XOAUTH2` respectively.
type SaslMechanism {
CramMd5
Login
Plain
XOauth2
}
/// Server capabilities parsed from the EHLO reply, restricted to what the
/// message being sent actually needs.
type Extensions {
Extensions(
// Advertised `SIZE` limit in octets; 0 when absent or unparseable, in which
// case no size limit is enforced.
maximum_data_size: Int,
// `Utf8` only when the server offers SMTPUTF8 and an address needs it.
command_encoding_mode: encoding.EncodingMode,
// `Utf8` only when the server offers 8BITMIME and the body needs it.
data_encoding_mode: encoding.EncodingMode,
// Whether the server advertised STARTTLS.
start_tls: Bool,
// Supported SASL mechanisms, in the order the server advertised them.
auth: List(SaslMechanism),
)
}
/// Problems with the message or configuration found while building SMTP
/// commands. Always reported wrapped in `InvalidRequestError`.
pub type ArgumentError {
/// Encoding a domain or an email address (via IDNA) failed.
EncodingError(encoding.EncoderError)
/// The address has no `@`, or its local part is non-ASCII while the session
/// cannot use UTF-8 commands.
InvalidAddress(String)
/// The message has no sender.
NoFromMailboxSpecified
/// The message has no To, Cc or Bcc recipients.
NoRecipientsSpecified
/// The rendered message (`actual` octets) exceeds the server's advertised
/// `SIZE` limit (`allowed` octets).
DataSizeError(actual: Int, allowed: Int)
/// Rendering the message into its wire form failed.
DataRenderError(internet_message.RenderError)
}
/// Errors that can occur during SMTP protocol communication.
pub type ProtocolError {
/// A reply line could not be parsed, or its text was rejected by a response
/// check (for example an unexpected AUTH LOGIN prompt or an undecodable
/// CRAM-MD5 challenge).
InvalidResponse(response: String)
/// The server replied with code `actual` where `expected` was required.
UnexpectedResponseCode(expected: Int, actual: Int)
/// The message or configuration could not be turned into a valid request.
InvalidRequestError(ArgumentError)
}
/// Callback type for the `Send` action. Called after a command is sent.
/// Returns the next action of the session, or the error that ends it.
pub type SendCallback =
fn() -> Result(Action, ProtocolError)
/// Callback type for the `Receive` action. Called with the server response.
/// The caller delivers one reply line per call; multi-line replies produce a
/// further `Receive` action for each continuation line.
pub type ReceiveCallback =
fn(String) -> Result(Action, ProtocolError)
/// Represents the next action to take in the SMTP protocol state machine.
///
/// - `Send(cmd, next)`: Send an SMTP command, then continue with `next`.
///   `cmd` already ends with CRLF, so write it to the socket as-is.
/// - `Receive(callback)`: Wait for an SMTP response, then call `callback`.
/// - `Upgrade(next)`: Upgrade the connection to TLS, then continue with `next`.
/// - `Done`: The protocol session has completed successfully.
pub type Action {
Send(String, SendCallback)
Receive(ReceiveCallback)
Upgrade(SendCallback)
Done
}
/// Configuration for an SMTP protocol session.
///
/// - `helo_host`: The hostname to use in the EHLO/HELO command. It is
///   IDNA-encoded to ASCII before being sent.
/// - `credentials`: Optional SMTP authentication credentials as
///   `#(username, secret)`. The secret is the password, or the bearer token
///   when XOAUTH2 is selected.
pub type ProtocolConfig {
ProtocolConfig(helo_host: String, credentials: Option(#(String, String)))
}
/// Describes how to read one SMTP reply; consumed by `do_receive`.
type Receiver {
Receiver(
// Reply code that every line of the reply must carry.
expected_code: Int,
// Validates or transforms the reply's text lines (in order) before they
// reach the continuation.
mapper: fn(List(String)) -> Result(List(String), ProtocolError),
// Called with the error when a reply code is wrong or a line cannot be
// parsed; may return an alternative action (e.g. the HELO fallback).
recover: fn(ProtocolError) -> Result(Action, ProtocolError),
)
}
/// Start an SMTP protocol session that delivers `message`.
///
/// The message is validated first: it needs a sender and at least one
/// recipient, and the HELO host must IDNA-encode. The returned action chain
/// then performs, in order:
/// 1. the server greeting;
/// 2. EHLO (falling back to HELO);
/// 3. STARTTLS with a second EHLO, if advertised;
/// 4. authentication, if credentials are configured and a mechanism matches;
/// 5. MAIL FROM / RCPT TO / DATA;
/// 6. QUIT.
///
/// - `message`: The `sendr/message.Message` to deliver.
/// - `config`: The `ProtocolConfig` for the session.
///
/// Returns the first `Action` of the sequence, or a validation error.
pub fn start_session(
  message message: Message,
  config config: ProtocolConfig,
) -> Result(Action, ProtocolError) {
  use _ <- result.try(has_from_and_recipients(message))
  use needs_smtputf8 <- result.try(has_smtputf8(message))
  use needs_8bit_body <- result.try(has_8bitmime_body(message))
  use helo_domain <- result.try(encode_domain(config.helo_host))
  // The greeting runs twice when STARTTLS is used, so capture it once.
  let greet = fn(on_extensions) {
    initiate_client(helo_domain, needs_smtputf8, needs_8bit_body, on_extensions)
  }
  use <- initiate_server()
  use plain_extensions <- greet()
  use session_extensions <- start_tls(plain_extensions, greet)
  use <- authenticate(session_extensions, config.credentials)
  use <- mail_transaction(session_extensions, message)
  quit()
}
/// Ensures the message has a sender and at least one recipient (To, Cc or
/// Bcc). The missing sender is reported first.
fn has_from_and_recipients(message: Message) -> Result(Nil, ProtocolError) {
  use <- bool.guard(
    option.is_none(message.from),
    Error(InvalidRequestError(NoFromMailboxSpecified)),
  )
  use <- bool.guard(
    list.is_empty(recipients(message)),
    Error(InvalidRequestError(NoRecipientsSpecified)),
  )
  Ok(Nil)
}
/// Reports whether any address (sender or recipient) has a non-ASCII local
/// part and therefore needs SMTPUTF8.
///
/// Fails with `InvalidAddress` for the first address without an `@`.
fn has_smtputf8(message: Message) -> Result(Bool, ProtocolError) {
  let sender = option.unwrap(message.from, mailbox.empty)
  let local_part_is_unicode = fn(entry: Mailbox) {
    case string.split_once(entry.address, "@") {
      Ok(#(local_part, _domain)) -> Ok(!is_ascii_string(local_part))
      Error(_) -> Error(InvalidRequestError(InvalidAddress(entry.address)))
    }
  }
  [sender, ..recipients(message)]
  |> list.try_map(local_part_is_unicode)
  |> result.map(list.any(_, function.identity))
}
/// Reports whether the message body needs 8-bit transport (8BITMIME), i.e.
/// whether any present body part (HTML or plain text) contains non-ASCII
/// text.
///
/// Absent parts are skipped, so a message without body parts needs no
/// 8BITMIME. Only `body.html` and `body.text` are inspected. Never returns
/// an error.
fn has_8bitmime_body(message: Message) -> Result(Bool, ProtocolError) {
  // The body is 7-bit safe only if *every* present part is ASCII; one
  // non-ASCII part is enough to require 8-bit transport.
  option.values([message.body.html, message.body.text])
  |> list.all(is_ascii_string)
  |> bool.negate()
  |> Ok
}
/// IDNA-encodes `domain` to its ASCII form for use in EHLO/HELO.
/// Encoding failures are reported as `InvalidRequestError(EncodingError(_))`.
fn encode_domain(domain: String) -> Result(String, ProtocolError) {
  case idna.encode_domain(domain, encoding.Ascii) {
    Ok(ascii_domain) -> Ok(ascii_domain)
    Error(reason) -> Error(InvalidRequestError(EncodingError(reason)))
  }
}
/// Waits for the server's opening 220 greeting (text discarded), then
/// continues with `next`.
fn initiate_server(
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  receive(220)
  |> ignore(next)
}
/// Greets the server with `EHLO helo_host` and passes the parsed extensions
/// to `next`.
///
/// If the server rejects EHLO with 500 (unrecognised) or 502 (not
/// implemented), the client falls back to `HELO`. In that case no
/// extensions are available and `next` receives the defaults. Any other
/// failure is returned unchanged.
fn initiate_client(
  helo_host: String,
  has_smtputf8: Bool,
  has_8bitmime: Bool,
  next: fn(Extensions) -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let next = fn(extensions) {
    extensions
    |> parse_extensions(has_smtputf8, has_8bitmime)
    |> next()
  }
  use <- send("EHLO " <> helo_host)
  use extensions <- recover(receive(250), fn(error) {
    case error {
      // RFC 5321 §3.2: a server that does not accept EHLO answers with 500,
      // 501, 502 or 550. 500 and 502 mean "no ESMTP here"; retry with the
      // classic HELO greeting.
      UnexpectedResponseCode(actual: 500, ..)
      | UnexpectedResponseCode(actual: 502, ..) -> {
        use <- send("HELO " <> helo_host)
        use <- ignore(receive(250))
        next([])
      }
      other -> Error(other)
    }
  })
  next(extensions)
}
/// Switches the connection to TLS when the server advertised STARTTLS.
///
/// On the TLS path it sends STARTTLS and expects 220. It then requests the
/// upgrade and re-runs the greeting via `initiate_client`, so `next` sees
/// the extensions advertised over the secured channel. A non-220 reply fails
/// the session. Without STARTTLS, `next` gets the original extensions.
fn start_tls(
  extensions: Extensions,
  initiate_client: fn(fn(Extensions) -> Result(Action, ProtocolError)) ->
    Result(Action, ProtocolError),
  next: fn(Extensions) -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  case extensions.start_tls {
    True -> {
      use <- send("STARTTLS")
      use <- ignore(receive(220))
      use <- upgrade()
      initiate_client(next)
    }
    False -> next(extensions)
  }
}
/// Authenticates with the first supported mechanism the server advertised.
///
/// Authentication is skipped (and `next` is called directly) when no
/// credentials are configured, or when the server offers none of the
/// supported mechanisms.
fn authenticate(
  extensions: Extensions,
  credentials: Option(#(String, String)),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  case credentials {
    None -> next()
    Some(login) ->
      case extensions.auth {
        [] -> next()
        [CramMd5, ..] -> authenticate_cram_md5(login, next)
        [Login, ..] -> authenticate_login(login, next)
        [Plain, ..] -> authenticate_plain(login, next)
        [XOauth2, ..] -> authenticate_xoauth2(login, next)
      }
  }
}
/// Performs `AUTH CRAM-MD5`.
///
/// The server's 334 challenge is base64-decoded and signed with HMAC-MD5
/// keyed by the password. The client answers with
/// base64("username lowercase-hex-digest") and expects 235. A challenge that
/// is not valid base64 yields `InvalidResponse(challenge)`.
fn authenticate_cram_md5(
  credentials: #(String, String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let #(username, password) = credentials
  let answer_challenge = fn(lines: List(String)) {
    let challenge = string.concat(lines)
    case bit_array.base64_decode(challenge) {
      Error(_) -> Error(InvalidResponse(challenge))
      Ok(decoded) -> {
        let digest =
          crypto.hmac(decoded, crypto.Md5, <<password:utf8>>)
          |> bit_array.base16_encode()
          |> string.lowercase()
        let reply = username <> " " <> digest
        Ok([bit_array.base64_encode(<<reply:utf8>>, True)])
      }
    }
  }
  use <- send("AUTH CRAM-MD5")
  use reply <- map(receive(334), answer_challenge)
  use <- send(string.concat(reply))
  use <- ignore(receive(235))
  next()
}
/// Performs `AUTH LOGIN`: answers the server's base64 "Username:" and
/// "Password:" prompts (either capitalisation) with the base64-encoded
/// username and password, then expects 235.
fn authenticate_login(
  credentials: #(String, String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let to_base64 = fn(text: String) {
    bit_array.base64_encode(<<text:utf8>>, True)
  }
  let #(username, password) = credentials
  use <- send("AUTH LOGIN")
  // base64("Username:") / base64("username:")
  use <- check(receive(334), ["VXNlcm5hbWU6", "dXNlcm5hbWU6"])
  use <- send(to_base64(username))
  // base64("Password:") / base64("password:")
  use <- check(receive(334), ["UGFzc3dvcmQ6", "cGFzc3dvcmQ6"])
  use <- send(to_base64(password))
  use <- ignore(receive(235))
  next()
}
/// Performs `AUTH PLAIN` with an initial response of
/// base64(NUL username NUL password) (empty authorization identity), then
/// expects 235.
fn authenticate_plain(
  credentials: #(String, String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let #(username, password) = credentials
  let initial_response =
    <<0, username:utf8, 0, password:utf8>>
    |> bit_array.base64_encode(True)
  use <- send("AUTH PLAIN " <> initial_response)
  use <- ignore(receive(235))
  next()
}
/// Performs `AUTH XOAUTH2`. The credential's second element is the OAuth
/// bearer token.
///
/// The initial response is
/// base64("user=" username ^A "auth=Bearer " token ^A ^A), and 235 is
/// expected.
fn authenticate_xoauth2(
  credentials: #(String, String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let #(username, token) = credentials
  let payload = <<
    "user=", username:utf8, 1, "auth=Bearer ", token:utf8, 1, 1,
  >>
  use <- send("AUTH XOAUTH2 " <> bit_array.base64_encode(payload, True))
  use <- ignore(receive(235))
  next()
}
/// Folds the lines of an EHLO reply into an `Extensions` record.
///
/// - The first line (the server's greeting) is skipped, and keywords are
///   matched case-insensitively.
/// - `8BITMIME` and `SMTPUTF8` are only enabled when the message needs them.
/// - Unknown keywords and unknown SASL mechanisms are ignored.
/// - `AUTH` mechanisms accumulate across lines in advertised order.
/// - An unparseable `SIZE` value yields 0 (no limit).
fn parse_extensions(
  extensions: List(String),
  has_smtputf8: Bool,
  has_8bitmime: Bool,
) -> Extensions {
  let defaults =
    Extensions(
      maximum_data_size: 0,
      command_encoding_mode: encoding.Ascii,
      data_encoding_mode: encoding.Ascii,
      start_tls: False,
      auth: [],
    )
  let to_mechanism = fn(name) {
    case name {
      "CRAM-MD5" -> Ok(CramMd5)
      "LOGIN" -> Ok(Login)
      "PLAIN" -> Ok(Plain)
      "XOAUTH2" -> Ok(XOauth2)
      _ -> Error(Nil)
    }
  }
  let apply_keyword = fn(acc: Extensions, line: String) {
    case string.uppercase(line) {
      "8BITMIME" if has_8bitmime ->
        Extensions(..acc, data_encoding_mode: encoding.Utf8)
      "AUTH " <> names -> {
        let offered =
          names
          |> string.split(" ")
          |> list.filter_map(to_mechanism)
        Extensions(..acc, auth: list.append(acc.auth, offered))
      }
      "SIZE " <> size ->
        Extensions(..acc, maximum_data_size: result.unwrap(int.parse(size), 0))
      "SMTPUTF8" if has_smtputf8 ->
        Extensions(..acc, command_encoding_mode: encoding.Utf8)
      "STARTTLS" -> Extensions(..acc, start_tls: True)
      _ -> acc
    }
  }
  case extensions {
    [] -> defaults
    [_greeting, ..keywords] -> list.fold(keywords, defaults, apply_keyword)
  }
}
/// All envelope recipients in To, Cc, Bcc order; absent lists contribute
/// nothing.
fn recipients(message: Message) -> List(Mailbox) {
  list.flat_map([message.to, message.cc, message.bcc], option.unwrap(_, []))
}
/// Runs one mail transaction (MAIL FROM, then RCPT TO for every recipient,
/// then DATA) and continues with `next`.
fn mail_transaction(
  extensions: Extensions,
  message: Message,
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  mail_from(message, extensions, fn() {
    mail_recipients(message, extensions, fn() {
      mail_data(message, extensions, next)
    })
  })
}
/// Sends `MAIL FROM:<sender>` and expects 250 before continuing with `next`.
///
/// The sender is validated against the negotiated extensions and
/// IDNA-encoded. ` SMTPUTF8` is appended when UTF-8 commands were
/// negotiated, and ` BODY=8BITMIME` when 8-bit data was negotiated.
fn mail_from(
  message: Message,
  extensions: Extensions,
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let parameters =
    option.values([
      case extensions.command_encoding_mode {
        encoding.Ascii -> None
        encoding.Utf8 -> Some("SMTPUTF8")
      },
      case extensions.data_encoding_mode {
        encoding.Ascii -> None
        encoding.Utf8 -> Some("BODY=8BITMIME")
      },
    ])
  let reverse_path = {
    use sender <- result.try(option.to_result(
      message.from,
      NoFromMailboxSpecified,
    ))
    use sender <- result.try(validate_mailbox(sender, extensions))
    sender.address
    |> idna.encode_email_address(extensions.command_encoding_mode)
    |> result.map_error(EncodingError)
  }
  use address <- result.try(result.map_error(reverse_path, InvalidRequestError))
  let argument = string.join(["<" <> address <> ">", ..parameters], " ")
  use <- send("MAIL FROM:" <> argument)
  use <- ignore(receive(250))
  next()
}
/// Validates and IDNA-encodes every recipient, then sends one `RCPT TO` per
/// recipient before continuing with `next`.
///
/// All addresses are checked before any command is sent. The first invalid
/// address aborts the session, as does an empty recipient list.
fn mail_recipients(
  message: Message,
  extensions: Extensions,
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let encode = fn(recipient: Mailbox) -> Result(String, ArgumentError) {
    use valid <- result.try(validate_mailbox(recipient, extensions))
    idna.encode_email_address(valid.address, extensions.command_encoding_mode)
    |> result.map_error(EncodingError)
  }
  let encoded =
    message
    |> recipients()
    |> list.try_map(fn(recipient) {
      recipient
      |> encode()
      |> result.map_error(InvalidRequestError)
    })
  use addresses <- result.try(encoded)
  use <- bool.guard(
    list.is_empty(addresses),
    Error(InvalidRequestError(NoRecipientsSpecified)),
  )
  do_mail_recipients(addresses, next)
}
/// Sends `RCPT TO:<address>` for each address in order, expecting 250 after
/// each, then continues with `next`.
fn do_mail_recipients(
  addresses: List(String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  case addresses {
    [] -> next()
    [recipient, ..remaining] ->
      send("RCPT TO:<" <> recipient <> ">", fn() {
        ignore(receive(250), fn() { do_mail_recipients(remaining, next) })
      })
  }
}
/// Sends the message content with `DATA`, then continues with `next`.
///
/// The message is rendered with the negotiated data encoding and checked
/// against the server's advertised `SIZE` limit (0 means no limit) before
/// any command is sent. A 354 reply is expected after `DATA`. The lines are
/// then sent dot-stuffed, followed by the terminating `.` line and a 250
/// reply.
fn mail_data(
  message: Message,
  extensions: Extensions,
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  // `send` terminates every line with CRLF, so each line costs two extra
  // octets on the wire.
  let line_terminator_size = 2
  let data =
    message
    |> internet_message.encode(extensions.data_encoding_mode)
    |> result.map_error(DataRenderError)
    |> result.try(fn(body) {
      // RFC 1870 defines the message size as the octets transmitted after
      // the 354 reply *including* CRLF pairs, but excluding dot-stuffing and
      // the terminating "." line.
      let body_size =
        list.fold(body, 0, fn(accumulator, line) {
          accumulator + string.byte_size(line) + line_terminator_size
        })
      case extensions.maximum_data_size {
        size if size == 0 || body_size <= size -> Ok(body)
        size -> Error(DataSizeError(body_size, size))
      }
    })
    |> result.map_error(InvalidRequestError)
  use data <- result.try(data)
  use <- send("DATA")
  use <- ignore(receive(354))
  use <- do_mail_data(data)
  use <- send(".")
  use <- ignore(receive(250))
  next()
}
/// Sends each data line in order, then continues with `next`.
///
/// Lines starting with "." get an extra leading "." (dot-stuffing, RFC 5321
/// §4.5.2) so they cannot be mistaken for the end-of-data marker.
fn do_mail_data(
  data: List(String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  case data {
    [] -> next()
    [line, ..remaining] -> {
      let escaped = case line {
        "." <> _ -> "." <> line
        _ -> line
      }
      use <- send(escaped)
      do_mail_data(remaining, next)
    }
  }
}
/// Ends the session: sends QUIT, expects 221, then finishes with `Done`.
fn quit() -> Result(Action, ProtocolError) {
  send("QUIT", fn() { ignore(receive(221), fn() { Ok(Done) }) })
}
/// Asks the caller to switch the connection to TLS before running `next`.
fn upgrade(
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  next
  |> Upgrade
  |> Ok
}
/// Emits `command` terminated by CRLF, continuing with `next` once it has
/// been written.
fn send(
  command: String,
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let line = string.append(command, "\r\n")
  Ok(Send(line, next))
}
/// Reads a reply with `receiver` and discards its text, continuing with
/// `next` on success.
fn ignore(
  receiver: Receiver,
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let Receiver(expected_code:, mapper:, recover:) = receiver
  let discard_lines = fn(_lines) { next() }
  do_receive(expected_code, Ok([]), mapper, recover, discard_lines)
}
/// Reads a reply and requires every text line to be one of
/// `valid_responses`. Otherwise the result is
/// `InvalidResponse(lines joined by "\n")`. The receiver's own mapper is
/// not used.
fn check(
  receiver: Receiver,
  valid_responses: List(String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let is_valid = fn(line) { list.contains(valid_responses, line) }
  let validate = fn(lines) {
    case list.all(lines, is_valid) {
      True -> Ok(lines)
      False -> Error(InvalidResponse(string.join(lines, "\n")))
    }
  }
  do_receive(receiver.expected_code, Ok([]), validate, receiver.recover, fn(_) {
    next()
  })
}
/// Reads a reply, transforms its text lines with `mapper` (replacing the
/// receiver's own mapper), and passes the result to `next`.
fn map(
  receiver: Receiver,
  mapper: fn(List(String)) -> Result(List(String), ProtocolError),
  next: fn(List(String)) -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let Receiver(expected_code:, recover:, ..) = receiver
  do_receive(expected_code, Ok([]), mapper, recover, next)
}
/// Reads a reply, handling protocol errors with `recover` (replacing the
/// receiver's own handler); on success passes the text lines to `next`.
fn recover(
  receiver: Receiver,
  recover: fn(ProtocolError) -> Result(Action, ProtocolError),
  next: fn(List(String)) -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let Receiver(expected_code:, mapper:, ..) = receiver
  do_receive(expected_code, Ok([]), mapper, recover, next)
}
/// A receiver expecting `expected_code` that passes reply lines through
/// unchanged and propagates any error as-is.
fn receive(expected_code: Int) -> Receiver {
  Receiver(
    expected_code: expected_code,
    mapper: fn(lines) { Ok(lines) },
    recover: fn(error) { Error(error) },
  )
}
/// Builds a `Receive` action that reads one (possibly multi-line) SMTP reply.
///
/// The callback handles a single reply line. A continuation line (`CODE-`)
/// returns another `Receive` for the next line. The state is threaded
/// through `accumulator`, which holds the collected text lines in reverse
/// order, or the first error seen.
///
/// On the final line:
/// - if any line carried a code other than `expected_code`, `recover` is
///   called with that (first) error;
/// - otherwise the texts, restored to their original order, go through
///   `mapper` and then to `next`. An error from `mapper` is returned
///   directly, without passing through `recover`.
///
/// A line that cannot be parsed calls `recover` immediately, even in the
/// middle of a multi-line reply. Any remaining lines are then not consumed
/// here.
fn do_receive(
expected_code: Int,
accumulator: Result(List(String), ProtocolError),
mapper: fn(List(String)) -> Result(List(String), ProtocolError),
recover: fn(ProtocolError) -> Result(Action, ProtocolError),
next: fn(List(String)) -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
let callback = fn(response) {
case parse_response(response), accumulator {
// An earlier line already failed: keep reading, keep the first error.
Ok(#(True, _, _)), Error(_) as error ->
do_receive(expected_code, error, mapper, recover, next)
// Final line after an earlier failure: report it.
Ok(#(False, _, _)), Error(error) -> recover(error)
// Wrong code on a continuation line: remember the error, keep reading.
Ok(#(True, response_code, _)), Ok(_) if response_code != expected_code ->
do_receive(
expected_code,
Error(UnexpectedResponseCode(
expected: expected_code,
actual: response_code,
)),
mapper,
recover,
next,
)
// Wrong code on the final line.
Ok(#(False, response_code, _)), Ok(_) if response_code != expected_code ->
recover(UnexpectedResponseCode(
expected: expected_code,
actual: response_code,
))
// Good continuation line: prepend its text and read the next line.
Ok(#(True, _, message)), Ok(messages) ->
do_receive(
expected_code,
Ok([message, ..messages]),
mapper,
recover,
next,
)
// Good final line: restore order, map, and continue.
Ok(#(False, _, message)), Ok(messages) ->
list.reverse([message, ..messages])
|> mapper()
|> result.try(next)
// Unparseable line: recover immediately.
Error(error), _ -> recover(error)
}
}
Ok(Receive(callback))
}
/// Parses one SMTP reply line into `#(continues, code, text)`.
///
/// Accepted forms (RFC 5321 §4.2) are `CODE SP text` (final line),
/// `CODE-text` (more lines follow) and a bare `CODE` (final line, no text).
/// A trailing CRLF, if present, is removed before parsing.
///
/// Returns `InvalidResponse(response)` when the line does not start with a
/// parsable code or the character after it is not a space or hyphen.
fn parse_response(
  response: String,
) -> Result(#(Bool, Int, String), ProtocolError) {
  // Strip CRLF first. `string.slice` works on grapheme clusters and "\r\n"
  // is a single cluster, so in a text-less reply such as "250\r\n" it would
  // otherwise sit in the separator position and the line would be rejected.
  let line = string.remove_suffix(response, "\r\n")
  {
    use response_code <- result.try(int.parse(string.slice(line, 0, 3)))
    use need_more <- result.map(case string.slice(line, 3, 1) {
      " " | "" -> Ok(False)
      "-" -> Ok(True)
      _ -> Error(Nil)
    })
    #(need_more, response_code, string.drop_start(line, 4))
  }
  |> result.replace_error(InvalidResponse(response))
}
/// Checks that `mailbox` can be used in the envelope.
///
/// The address must contain an `@`. Its local part must be ASCII unless
/// UTF-8 commands (SMTPUTF8) were negotiated. The domain is not checked
/// here. On success the mailbox is returned unchanged.
fn validate_mailbox(
  mailbox: Mailbox,
  extensions: Extensions,
) -> Result(Mailbox, ArgumentError) {
  let rejection = Error(InvalidAddress(mailbox.address))
  let unicode_allowed = extensions.command_encoding_mode == encoding.Utf8
  case string.split_once(mailbox.address, "@") {
    Error(_) -> rejection
    Ok(#(local_part, _domain)) ->
      case unicode_allowed || is_ascii_string(local_part) {
        True -> Ok(mailbox)
        False -> rejection
      }
  }
}
/// Reports whether `text` consists solely of 7-bit ASCII characters (code
/// points U+0000 through U+007F). The empty string is ASCII.
///
/// The check works on code points, not grapheme clusters:
/// - a combining sequence such as "e\u{301}" is reported as non-ASCII;
/// - CR, LF and TAB, which occur in any multi-line body, count as ASCII.
fn is_ascii_string(text: String) -> Bool {
  text
  |> string.to_utf_codepoints()
  |> list.all(fn(codepoint) { string.utf_codepoint_to_int(codepoint) < 0x80 })
}