Skip to main content

src/internal/protocol.gleam

//// SMTP client protocol state machine, implementing
//// [RFC 5321](https://www.rfc-editor.org/rfc/rfc5321.html).

import gleam/bit_array
import gleam/bool
import gleam/crypto
import gleam/function
import gleam/int
import gleam/list
import gleam/option.{type Option, None, Some}
import gleam/order
import gleam/result
import gleam/string
import internal/encoder/encoding
import internal/encoder/idna
import internal/renderer/internet_message
import sendr/message.{type Message}
import sendr/message/mailbox.{type Mailbox}

/// SASL mechanisms this client can use for SMTP AUTH.
type SaslMechanism {
  /// CRAM-MD5: answers a server challenge with an HMAC-MD5 of the password.
  CramMd5
  /// LOGIN: base64 username and password sent in reply to server prompts.
  Login
  /// PLAIN: username and password sent together in the AUTH command.
  Plain
  /// XOAUTH2: the second credential element is sent as a bearer token.
  XOauth2
}

/// What the server advertised in its EHLO reply, reduced to what this session
/// will actually use.
type Extensions {
  Extensions(
    // Value of the SIZE keyword; 0 means no size limit is enforced.
    maximum_data_size: Int,
    // Utf8 only when the server advertises SMTPUTF8 and the message needs it.
    command_encoding_mode: encoding.EncodingMode,
    // Utf8 only when the server advertises 8BITMIME and the body needs it.
    data_encoding_mode: encoding.EncodingMode,
    // True when the server advertises STARTTLS.
    start_tls: Bool,
    // Supported mechanisms in the order the server advertised them.
    auth: List(SaslMechanism),
  )
}

/// Problems with the message or configuration supplied by the caller, as
/// opposed to problems with the server's replies.
pub type ArgumentError {
  /// A domain or address could not be encoded for the wire.
  EncodingError(encoding.EncoderError)
  /// An address has no "@" or needs SMTPUTF8 that the server does not offer.
  InvalidAddress(String)
  /// The message has no `from` mailbox.
  NoFromMailboxSpecified
  /// The message has no to/cc/bcc recipients.
  NoRecipientsSpecified
  /// The rendered message exceeds the server's advertised SIZE limit.
  DataSizeError(actual: Int, allowed: Int)
  /// The message could not be rendered to its wire form.
  DataRenderError(internet_message.RenderError)
}

/// Errors that can occur during SMTP protocol communication.
pub type ProtocolError {
  /// A reply line could not be parsed, or its text was not acceptable.
  InvalidResponse(response: String)
  /// The server replied with a different code than the step expected.
  UnexpectedResponseCode(expected: Int, actual: Int)
  /// The caller's message or configuration was rejected before sending.
  InvalidRequestError(ArgumentError)
}

/// Callback type for the `Send` action. Called after a command is sent.
/// It is also the continuation stored in `Upgrade`, called once TLS is up.
pub type SendCallback =
  fn() -> Result(Action, ProtocolError)

/// Callback type for the `Receive` action. Called with the server response.
/// It is invoked once per reply line; a multi-line reply produces a chain of
/// `Receive` actions until its final line arrives.
pub type ReceiveCallback =
  fn(String) -> Result(Action, ProtocolError)

/// Represents the next action to take in the SMTP protocol state machine.
///
/// - `Send(cmd, next)`: Send an SMTP command, then continue with `next`.
/// - `Receive(callback)`: Wait for an SMTP response, then call `callback`.
/// - `Upgrade(next)`: Upgrade the connection to TLS, then continue with `next`.
/// - `Done`: The protocol session has completed successfully.
pub type Action {
  /// The command string already ends with "\r\n".
  Send(String, SendCallback)
  /// The callback receives one reply line at a time.
  Receive(ReceiveCallback)
  /// Issued after the server accepted STARTTLS with 220.
  Upgrade(SendCallback)
  /// Issued after the server acknowledged QUIT with 221.
  Done
}

/// Configuration for an SMTP protocol session.
///
/// - `helo_host`: The hostname to use in the EHLO/HELO command. It is
///   IDNA-encoded to ASCII before use.
/// - `credentials`: Optional SMTP authentication credentials as
///   `#(username, password)`; for XOAUTH2 the second element is sent as the
///   bearer token. Authentication is skipped when this is `None` or when the
///   server advertises no supported mechanism.
pub type ProtocolConfig {
  ProtocolConfig(helo_host: String, credentials: Option(#(String, String)))
}

/// Describes how to receive one SMTP reply; built by `receive` and consumed
/// by `ignore`, `check`, `map` and `recover`.
type Receiver {
  Receiver(
    // Reply code every line of the reply must carry.
    expected_code: Int,
    // Transforms the reply texts on success (`receive` uses `Ok`).
    mapper: fn(List(String)) -> Result(List(String), ProtocolError),
    // Handles reply/parse errors (`receive` uses `Error`, i.e. fails).
    recover: fn(ProtocolError) -> Result(Action, ProtocolError),
  )
}

/// Start an SMTP protocol session for sending a message.
///
/// Executes the full SMTP session: server greeting, EHLO/HELO, STARTTLS,
/// authentication (if configured), and mail transaction (MAIL FROM, RCPT TO, DATA).
///
/// - `message`: The `sendr/message.Message` to deliver.
/// - `config`: The `ProtocolConfig` for the session.
///
/// Returns an `Action` representing the first action in the protocol sequence.
pub fn start_session(
  message message: Message,
  config config: ProtocolConfig,
) -> Result(Action, ProtocolError) {
  // Validate everything that can be checked before talking to the server.
  use _ <- result.try(has_from_and_recipients(message))
  use needs_smtputf8 <- result.try(has_smtputf8(message))
  use needs_8bitmime <- result.try(has_8bitmime_body(message))
  use helo_host <- result.try(encode_domain(config.helo_host))

  // The client greeting is bound once so STARTTLS can repeat it after the
  // connection has been upgraded.
  let greet = fn(next) {
    initiate_client(helo_host, needs_smtputf8, needs_8bitmime, next)
  }

  use <- initiate_server()
  use extensions <- greet()
  use extensions <- start_tls(extensions, greet)
  use <- authenticate(extensions, config.credentials)
  use <- mail_transaction(extensions, message)
  quit()
}

/// Fail early when the message has no sender or no recipient at all.
///
/// The sender is checked first, so a message missing both reports
/// `NoFromMailboxSpecified`.
fn has_from_and_recipients(message: Message) -> Result(Nil, ProtocolError) {
  case option.is_some(message.from) {
    False -> Error(InvalidRequestError(NoFromMailboxSpecified))
    True ->
      case list.is_empty(recipients(message)) {
        True -> Error(InvalidRequestError(NoRecipientsSpecified))
        False -> Ok(Nil)
      }
  }
}

/// Report whether any sender or recipient local part contains characters
/// outside printable ASCII, meaning the SMTPUTF8 extension is needed.
///
/// Only the local part (before "@") is inspected; domains are IDNA-encoded
/// separately. Fails with `InvalidAddress` for the first address that has
/// no "@".
fn has_smtputf8(message: Message) -> Result(Bool, ProtocolError) {
  let sender = option.unwrap(message.from, mailbox.empty)

  [sender, ..recipients(message)]
  |> list.try_map(fn(entry) {
    case string.split_once(entry.address, "@") {
      Ok(#(local, _domain)) -> Ok(!is_ascii_string(local))
      Error(_) -> Error(InvalidRequestError(InvalidAddress(entry.address)))
    }
  })
  |> result.map(list.any(_, function.identity))
}

/// Report whether any body part contains non-ASCII text, meaning the
/// message should be sent with `BODY=8BITMIME`.
///
/// Only `message.body.html` and `message.body.text` are inspected. Line
/// breaks, tabs and other 7-bit control characters do not count as 8-bit
/// data. A message with no body part never needs 8BITMIME.
fn has_8bitmime_body(message: Message) -> Result(Bool, ProtocolError) {
  [message.body.html, message.body.text]
  |> option.values()
  |> list.any(fn(body) { bool.negate(is_7bit_text(body)) })
  |> Ok
}

/// True when every code point of `text` is below 0x80 (7-bit ASCII,
/// control characters included).
fn is_7bit_text(text: String) -> Bool {
  text
  |> string.to_utf_codepoints()
  |> list.all(fn(codepoint) { string.utf_codepoint_to_int(codepoint) < 0x80 })
}

/// IDNA-encode `domain` to its ASCII form for use in SMTP commands.
///
/// Encoder failures are reported as `InvalidRequestError(EncodingError(_))`.
fn encode_domain(domain: String) -> Result(String, ProtocolError) {
  case idna.encode_domain(domain, encoding.Ascii) {
    Ok(encoded) -> Ok(encoded)
    Error(reason) -> Error(InvalidRequestError(EncodingError(reason)))
  }
}

/// Wait for the server's 220 greeting, discard its text, then run `next`.
fn initiate_server(
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  ignore(receive(220), next)
}

/// Greet the server with EHLO and collect its advertised extensions.
///
/// If the server rejects EHLO with one of the replies that RFC 5321 §4.1.4
/// lists for an unacceptable EHLO (500, 501, 502 or 550), the client falls
/// back to HELO. No extensions are available in that case. Any other error
/// is returned unchanged.
///
/// - `helo_host`: the (already ASCII-encoded) host name sent with EHLO/HELO.
/// - `has_smtputf8` / `has_8bitmime`: whether the message needs SMTPUTF8 /
///   8BITMIME; forwarded to `parse_extensions`.
/// - `next`: continuation receiving the parsed `Extensions`.
fn initiate_client(
  helo_host: String,
  has_smtputf8: Bool,
  has_8bitmime: Bool,
  next: fn(Extensions) -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let next = fn(extensions) {
    extensions
    |> parse_extensions(has_smtputf8, has_8bitmime)
    |> next()
  }

  use <- send("EHLO " <> helo_host)
  use extensions <- recover(receive(250), fn(error) {
    case error {
      // Servers that do not accept EHLO answer with one of these codes; the
      // client must be able to recover by issuing HELO instead.
      UnexpectedResponseCode(actual: 500, ..)
      | UnexpectedResponseCode(actual: 501, ..)
      | UnexpectedResponseCode(actual: 502, ..)
      | UnexpectedResponseCode(actual: 550, ..) -> {
        use <- send("HELO " <> helo_host)
        use <- ignore(receive(250))
        next([])
      }
      UnexpectedResponseCode(..) as error
      | InvalidResponse(..) as error
      | InvalidRequestError(..) as error -> Error(error)
    }
  })
  next(extensions)
}

/// Upgrade the connection to TLS when the server advertises STARTTLS.
///
/// Without STARTTLS, `next` receives the current extensions unchanged.
/// Otherwise it sends STARTTLS, expects 220, emits `Upgrade`, and then
/// re-runs `initiate_client` over the encrypted connection. `next`
/// receives the freshly advertised extensions; the pre-TLS list is
/// discarded.
fn start_tls(
  extensions: Extensions,
  initiate_client: fn(fn(Extensions) -> Result(Action, ProtocolError)) ->
    Result(Action, ProtocolError),
  next: fn(Extensions) -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  use <- bool.lazy_guard(!extensions.start_tls, fn() { next(extensions) })
  use <- send("STARTTLS")
  use <- ignore(receive(220))
  use <- upgrade()
  initiate_client(next)
}

/// Authenticate with the first supported mechanism the server advertised.
///
/// Authentication is skipped (and `next` runs directly) when no credentials
/// are configured or when the server offers none of the supported
/// mechanisms.
fn authenticate(
  extensions: Extensions,
  credentials: Option(#(String, String)),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  case credentials {
    None -> next()
    Some(credentials) ->
      case list.first(extensions.auth) {
        Error(_) -> next()
        Ok(CramMd5) -> authenticate_cram_md5(credentials, next)
        Ok(Login) -> authenticate_login(credentials, next)
        Ok(Plain) -> authenticate_plain(credentials, next)
        Ok(XOauth2) -> authenticate_xoauth2(credentials, next)
      }
  }
}

/// Authenticate with CRAM-MD5.
///
/// The 334 challenge is base64-decoded and HMAC-MD5'd with the password as
/// key. The reply is `"<username> <lowercase hex digest>"`, base64-encoded.
/// An undecodable challenge fails with `InvalidResponse(challenge)`. A 235
/// reply completes authentication.
fn authenticate_cram_md5(
  credentials: #(String, String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let #(username, password) = credentials

  let answer_challenge = fn(lines: List(String)) {
    let challenge = string.concat(lines)
    case bit_array.base64_decode(challenge) {
      Error(_) -> Error(InvalidResponse(challenge))
      Ok(decoded) -> {
        let digest =
          crypto.hmac(decoded, crypto.Md5, <<password:utf8>>)
          |> bit_array.base16_encode()
          |> string.lowercase()
        let answer = username <> " " <> digest
        Ok([bit_array.base64_encode(<<answer:utf8>>, True)])
      }
    }
  }

  use <- send("AUTH CRAM-MD5")
  use answer <- map(receive(334), answer_challenge)
  use <- send(string.concat(answer))
  use <- ignore(receive(235))
  next()
}

/// Authenticate with the LOGIN mechanism.
///
/// The server must prompt with base64 "Username:" and then "Password:"
/// (either capitalisation). Each prompt is answered with the base64-encoded
/// credential. A 235 reply completes authentication.
fn authenticate_login(
  credentials: #(String, String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let #(username, password) = credentials
  let to_base64 = fn(text: String) {
    bit_array.base64_encode(<<text:utf8>>, True)
  }

  use <- send("AUTH LOGIN")
  use <- check(receive(334), ["VXNlcm5hbWU6", "dXNlcm5hbWU6"])
  use <- send(to_base64(username))
  use <- check(receive(334), ["UGFzc3dvcmQ6", "cGFzc3dvcmQ6"])
  use <- send(to_base64(password))
  use <- ignore(receive(235))
  next()
}

/// Authenticate with the PLAIN mechanism as an initial response.
///
/// The payload is `NUL username NUL password` (empty authorization
/// identity), base64-encoded. A 235 reply completes authentication.
fn authenticate_plain(
  credentials: #(String, String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let #(username, password) = credentials
  let payload = <<0, username:utf8, 0, password:utf8>>

  send("AUTH PLAIN " <> bit_array.base64_encode(payload, True), fn() {
    ignore(receive(235), next)
  })
}

/// Authenticate with XOAUTH2, using the second credential as bearer token.
///
/// The payload is `user=<username>^Aauth=Bearer <token>^A^A` (^A = 0x01),
/// base64-encoded. A 235 reply completes authentication.
fn authenticate_xoauth2(
  credentials: #(String, String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let #(username, token) = credentials
  let payload =
    "user=" <> username <> "\u{1}auth=Bearer " <> token <> "\u{1}\u{1}"
  let encoded = bit_array.base64_encode(<<payload:utf8>>, True)

  send("AUTH XOAUTH2 " <> encoded, fn() { ignore(receive(235), next) })
}

/// Parse the lines of an EHLO reply into the `Extensions` this session uses.
///
/// The first line (the server's greeting) is skipped. Keywords are matched
/// case-insensitively, and unknown keywords and unsupported SASL mechanisms
/// are ignored.
///
/// - `has_smtputf8`: only when `True` does an advertised `SMTPUTF8` switch
///   `command_encoding_mode` to UTF-8.
/// - `has_8bitmime`: only when `True` does an advertised `8BITMIME` switch
///   `data_encoding_mode` to UTF-8.
///
/// Both the standard `AUTH <mechanisms>` keyword and the legacy
/// `AUTH=<mechanisms>` spelling, which some servers still emit (often next to
/// the standard one), are accepted. Mechanisms keep the server's advertised
/// order with duplicates removed. An unparsable `SIZE` value yields 0, which
/// is treated as "no limit".
fn parse_extensions(
  extensions: List(String),
  has_smtputf8: Bool,
  has_8bitmime: Bool,
) -> Extensions {
  let defaults =
    Extensions(
      auth: [],
      maximum_data_size: 0,
      command_encoding_mode: encoding.Ascii,
      data_encoding_mode: encoding.Ascii,
      start_tls: False,
    )

  extensions
  |> list.drop(1)
  |> list.fold(defaults, fn(extensions, line) {
    case string.uppercase(line) {
      "8BITMIME" if has_8bitmime ->
        Extensions(..extensions, data_encoding_mode: encoding.Utf8)
      "AUTH " <> mechanisms | "AUTH=" <> mechanisms -> {
        let advertised =
          mechanisms
          |> string.split(" ")
          |> list.filter_map(parse_sasl_mechanism)
        Extensions(
          ..extensions,
          auth: list.unique(list.append(extensions.auth, advertised)),
        )
      }
      "SIZE " <> size ->
        Extensions(
          ..extensions,
          maximum_data_size: result.unwrap(int.parse(size), 0),
        )
      "SMTPUTF8" if has_smtputf8 ->
        Extensions(..extensions, command_encoding_mode: encoding.Utf8)
      "STARTTLS" -> Extensions(..extensions, start_tls: True)
      _ -> extensions
    }
  })
}

/// Map an uppercased SASL mechanism name to a supported mechanism.
fn parse_sasl_mechanism(name: String) -> Result(SaslMechanism, Nil) {
  case name {
    "CRAM-MD5" -> Ok(CramMd5)
    "LOGIN" -> Ok(Login)
    "PLAIN" -> Ok(Plain)
    "XOAUTH2" -> Ok(XOauth2)
    _ -> Error(Nil)
  }
}

/// All envelope recipients: `to`, then `cc`, then `bcc`, skipping any list
/// that is not set.
fn recipients(message: Message) -> List(Mailbox) {
  list.flat_map([message.to, message.cc, message.bcc], option.unwrap(_, []))
}

/// Run one mail transaction: MAIL FROM, one RCPT TO per recipient, then
/// DATA, and finally continue with `next`.
fn mail_transaction(
  extensions: Extensions,
  message: Message,
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  mail_from(message, extensions, fn() {
    mail_recipients(message, extensions, fn() {
      mail_data(message, extensions, next)
    })
  })
}

/// Send `MAIL FROM:<sender>` and expect 250.
///
/// `SMTPUTF8` and `BODY=8BITMIME` parameters are appended when the
/// corresponding encoding mode was negotiated. The sender is validated and
/// its address encoded first. Failures are wrapped in `InvalidRequestError`.
fn mail_from(
  message: Message,
  extensions: Extensions,
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let encoding_parameters =
    list.flatten([
      case extensions.command_encoding_mode {
        encoding.Utf8 -> ["SMTPUTF8"]
        encoding.Ascii -> []
      },
      case extensions.data_encoding_mode {
        encoding.Utf8 -> ["BODY=8BITMIME"]
        encoding.Ascii -> []
      },
    ])

  let reverse_path =
    {
      use sender <- result.try(option.to_result(
        message.from,
        NoFromMailboxSpecified,
      ))
      use sender <- result.try(validate_mailbox(sender, extensions))
      use address <- result.map(
        idna.encode_email_address(
          sender.address,
          extensions.command_encoding_mode,
        )
        |> result.map_error(EncodingError),
      )
      string.join(["<" <> address <> ">", ..encoding_parameters], " ")
    }
    |> result.map_error(InvalidRequestError)

  use reverse_path <- result.try(reverse_path)
  use <- send("MAIL FROM:" <> reverse_path)
  use <- ignore(receive(250))
  next()
}

/// Validate and encode every recipient, then issue one RCPT TO per address.
///
/// Fails with the first invalid recipient, or with `NoRecipientsSpecified`
/// when the message has none. Errors are wrapped in `InvalidRequestError`.
fn mail_recipients(
  message: Message,
  extensions: Extensions,
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let encode_recipient = fn(recipient: Mailbox) {
    recipient
    |> validate_mailbox(extensions)
    |> result.try(fn(valid) {
      idna.encode_email_address(valid.address, extensions.command_encoding_mode)
      |> result.map_error(EncodingError)
    })
    |> result.map_error(InvalidRequestError)
  }

  use addresses <- result.try(list.try_map(
    recipients(message),
    encode_recipient,
  ))
  use <- bool.guard(
    list.is_empty(addresses),
    Error(InvalidRequestError(NoRecipientsSpecified)),
  )
  do_mail_recipients(addresses, next)
}

/// Send `RCPT TO:<address>` for each address in order, expecting 250 each
/// time, then continue with `next`.
fn do_mail_recipients(
  addresses: List(String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  // Build the chain back to front so each step's continuation is the step
  // for the following address, ending in `next`.
  let run =
    list.fold_right(addresses, next, fn(after, address) {
      fn() {
        use <- send("RCPT TO:<" <> address <> ">")
        use <- ignore(receive(250))
        after()
      }
    })
  run()
}

/// Render the message and transmit it with the DATA command.
///
/// The message is rendered with `internet_message.encode` using the
/// negotiated `data_encoding_mode`. When the server advertised a non-zero
/// SIZE, the rendered size is checked before DATA is sent. Lines are
/// dot-stuffed by `do_mail_data`, and the content is terminated with a lone
/// "." line.
///
/// Errors (wrapped in `InvalidRequestError`): `DataRenderError` when
/// rendering fails, `DataSizeError(actual, allowed)` when the message is too
/// large.
fn mail_data(
  message: Message,
  extensions: Extensions,
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let data =
    message
    |> internet_message.encode(extensions.data_encoding_mode)
    |> result.map_error(DataRenderError)
    |> result.try(fn(body) {
      // RFC 1870 counts every octet of the message including CR-LF pairs,
      // but not dot-stuffing or the terminating "." line. `send` appends
      // "\r\n" to each line on the wire, so count those 2 octets per line.
      let body_size =
        list.fold(body, 0, fn(accumulator, line) {
          accumulator + string.byte_size(line) + 2
        })

      case extensions.maximum_data_size {
        // 0 means the server declared no fixed maximum (or no SIZE at all).
        size if size == 0 || body_size <= size -> Ok(body)
        size -> Error(DataSizeError(body_size, size))
      }
    })
    |> result.map_error(InvalidRequestError)

  use data <- result.try(data)
  use <- send("DATA")
  use <- ignore(receive(354))
  use <- do_mail_data(data)
  use <- send(".")
  use <- ignore(receive(250))
  next()
}

/// Send each content line in order, then continue with `next`.
///
/// A line starting with "." gets an extra leading "." (dot-stuffing), so it
/// cannot be mistaken for the end-of-data marker.
fn do_mail_data(
  data: List(String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  case data {
    [] -> next()
    [line, ..remaining] -> {
      let escaped = case line {
        "." <> _ -> "." <> line
        _ -> line
      }
      send(escaped, fn() { do_mail_data(remaining, next) })
    }
  }
}

/// End the session: send QUIT, expect 221, and finish with `Done`.
fn quit() -> Result(Action, ProtocolError) {
  send("QUIT", fn() { ignore(receive(221), fn() { Ok(Done) }) })
}

/// Ask the driver to switch the connection to TLS, then continue with `next`.
fn upgrade(
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  next
  |> Upgrade
  |> Ok
}

/// Emit `command` terminated by CRLF, then continue with `next`.
fn send(
  command: String,
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let line = string.append(command, "\r\n")
  Ok(Send(line, next))
}

/// Receive a reply with `receiver`, discard its text, then run `next`.
fn ignore(
  receiver: Receiver,
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let Receiver(expected_code: code, mapper: transform, recover: on_error) =
    receiver
  do_receive(code, Ok([]), transform, on_error, fn(_lines) { next() })
}

/// Receive a reply and require every line's text to be one of
/// `valid_responses`, then run `next`.
///
/// Any other text fails with `InvalidResponse` carrying the lines joined by
/// "\n". The receiver's own mapper is replaced by this check.
fn check(
  receiver: Receiver,
  valid_responses: List(String),
  next: fn() -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let validate = fn(lines: List(String)) {
    let all_valid =
      list.all(lines, fn(line) { list.contains(valid_responses, line) })
    case all_valid {
      True -> Ok(lines)
      False -> Error(InvalidResponse(string.join(lines, "\n")))
    }
  }

  map(receiver, validate, fn(_) { next() })
}

/// Receive a reply, transform its lines with `mapper` (replacing the
/// receiver's own mapper), and pass the result to `next`.
fn map(
  receiver: Receiver,
  mapper: fn(List(String)) -> Result(List(String), ProtocolError),
  next: fn(List(String)) -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let Receiver(expected_code: code, recover: on_error, ..) = receiver
  do_receive(code, Ok([]), mapper, on_error, next)
}

/// Receive a reply, handling errors with `recover` instead of the receiver's
/// own handler, and pass the reply lines to `next` on success.
fn recover(
  receiver: Receiver,
  recover: fn(ProtocolError) -> Result(Action, ProtocolError),
  next: fn(List(String)) -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let Receiver(expected_code: code, mapper: transform, ..) = receiver
  do_receive(code, Ok([]), transform, recover, next)
}

/// A receiver expecting `expected_code` that passes reply lines through
/// unchanged and turns every error into a session failure.
fn receive(expected_code: Int) -> Receiver {
  Receiver(
    expected_code: expected_code,
    mapper: fn(lines) { Ok(lines) },
    recover: fn(error) { Error(error) },
  )
}

/// Build a `Receive` action that reads one SMTP reply, which may span
/// several lines, and then either continues with `next` or hands an error to
/// `recover`.
///
/// - `expected_code`: the reply code every line must carry.
/// - `accumulator`: `Ok` with the texts collected so far (most recent first),
///   or the first `Error` seen. Once in error, the remaining lines of the
///   reply are still consumed before `recover` is called, so the next reply
///   starts cleanly.
/// - `mapper`: applied to the collected texts (in arrival order) after the
///   final line. An `Error` from it is returned directly, without going
///   through `recover`.
/// - `recover`: called with the error after the final line, or immediately
///   when a line cannot be parsed.
/// - `next`: continuation receiving the mapped texts.
fn do_receive(
  expected_code: Int,
  accumulator: Result(List(String), ProtocolError),
  mapper: fn(List(String)) -> Result(List(String), ProtocolError),
  recover: fn(ProtocolError) -> Result(Action, ProtocolError),
  next: fn(List(String)) -> Result(Action, ProtocolError),
) -> Result(Action, ProtocolError) {
  let callback = fn(response) {
    case parse_response(response), accumulator {
      // Already failed: keep draining continuation lines, keeping the first error.
      Ok(#(True, _, _)), Error(_) as error ->
        do_receive(expected_code, error, mapper, recover, next)
      // Already failed and the reply is complete: report the stored error.
      Ok(#(False, _, _)), Error(error) -> recover(error)
      // Wrong code on a continuation line: remember it, drain the rest.
      Ok(#(True, response_code, _)), Ok(_) if response_code != expected_code ->
        do_receive(
          expected_code,
          Error(UnexpectedResponseCode(
            expected: expected_code,
            actual: response_code,
          )),
          mapper,
          recover,
          next,
        )
      // Wrong code on the final line: report it right away.
      Ok(#(False, response_code, _)), Ok(_) if response_code != expected_code ->
        recover(UnexpectedResponseCode(
          expected: expected_code,
          actual: response_code,
        ))
      // Good continuation line: collect its text (prepended, reversed later).
      Ok(#(True, _, message)), Ok(messages) ->
        do_receive(
          expected_code,
          Ok([message, ..messages]),
          mapper,
          recover,
          next,
        )
      // Good final line: restore arrival order, map, and continue.
      Ok(#(False, _, message)), Ok(messages) ->
        list.reverse([message, ..messages])
        |> mapper()
        |> result.try(next)
      // Unparseable line: recover immediately, whatever was collected so far.
      Error(error), _ -> recover(error)
    }
  }

  Ok(Receive(callback))
}

/// Parse a single SMTP reply line (RFC 5321 §4.2).
///
/// Returns `#(more_lines_follow, reply_code, text)`. `more_lines_follow` is
/// `True` for a `NNN-text` continuation line and `False` for a final
/// `NNN text` or bare `NNN` line. A trailing CRLF is stripped before parsing,
/// so a bare `"250\r\n"` reply is accepted.
///
/// Fails with `InvalidResponse(response)` unless the line starts with a
/// three-character reply code in 100..599, followed by " ", "-" or the end
/// of the line.
fn parse_response(
  response: String,
) -> Result(#(Bool, Int, String), ProtocolError) {
  // Strip CRLF up front: "\r\n" is a single grapheme, so leaving it in place
  // made the separator of a bare "250\r\n" reply look like "\r\n".
  let line = string.remove_suffix(response, "\r\n")
  {
    use response_code <- result.try(int.parse(string.slice(line, 0, 3)))
    // A 3-character signed value such as "+25" parses but is below 100.
    use <- bool.guard(response_code < 100 || response_code > 599, Error(Nil))
    use need_more <- result.map(case string.slice(line, 3, 1) {
      " " | "" -> Ok(False)
      "-" -> Ok(True)
      _ -> Error(Nil)
    })
    #(need_more, response_code, string.drop_start(line, 4))
  }
  |> result.replace_error(InvalidResponse(response))
}

/// Check that `mailbox` has an "@" and that its local part can be sent.
///
/// A non-ASCII local part is allowed only when SMTPUTF8 was negotiated
/// (`command_encoding_mode` is `Utf8`); otherwise it fails with
/// `InvalidAddress`.
fn validate_mailbox(
  mailbox: Mailbox,
  extensions: Extensions,
) -> Result(Mailbox, ArgumentError) {
  let rejection = Error(InvalidAddress(mailbox.address))
  case string.split_once(mailbox.address, "@") {
    Error(_) -> rejection
    Ok(#(local, _domain)) -> {
      let utf8_allowed = extensions.command_encoding_mode == encoding.Utf8
      use <- bool.guard(utf8_allowed || is_ascii_string(local), Ok(mailbox))
      rejection
    }
  }
}

/// Report whether the text consists only of printable ASCII characters
/// (U+0020 through U+007E).
///
/// Control characters (tab, CR and LF included) and every non-ASCII code
/// point make this return `False`. The empty string returns `True`.
///
/// Each code point is checked on its own. Comparing whole grapheme clusters
/// would wrongly accept a decomposed character such as "e" followed by
/// U+0301 COMBINING ACUTE ACCENT, because that cluster sorts just after "e".
fn is_ascii_string(text: String) -> Bool {
  text
  |> string.to_utf_codepoints()
  |> list.all(fn(codepoint) {
    let value = string.utf_codepoint_to_int(codepoint)
    int.compare(value, 0x1f) == order.Gt
    && int.compare(value, 0x7f) == order.Lt
  })
}