// c_src/onnxruntime.cpp

#include <fine.hpp>
#include <onnxruntime_cxx_api.h>

#include <algorithm>
#include <cstring>
#include <limits>
#include <numeric>
#include <sstream>

namespace atoms {
auto cpu = fine::Atom("cpu");
auto f = fine::Atom("f");
auto bf = fine::Atom("bf");
auto s = fine::Atom("s");
auto u = fine::Atom("u");
auto pred = fine::Atom("pred");
} // namespace atoms

using DTypeTerm = std::tuple<fine::Atom, uint64_t>;
using ShapeTerm = std::vector<std::optional<int64_t>>;
using IOTerm = std::tuple<std::string, std::string, ShapeTerm>;

// Returns the single ONNX Runtime environment used by this library. It is a
// function-local static, so it is built lazily on the first call and then
// reused by every Session constructed afterwards.
static Ort::Env &ort_env() {
  static Ort::Env shared_env{ORT_LOGGING_LEVEL_WARNING, "onnxruntime_elixir"};
  return shared_env;
}

// Width in bytes of a single element of `type` in a dense CPU buffer.
// Any element type not listed below raises std::invalid_argument.
static size_t element_size(ONNXTensorElementDataType type) {
  size_t width = 0;

  switch (type) {
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
    width = 8;
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
    width = 4;
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
    width = 2;
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
    width = 1;
    break;
  default:
    throw std::invalid_argument("unsupported ONNX tensor element type");
  }

  return width;
}

// Translates an Nx-style {kind, bits} type (e.g. {:f, 32}) into the matching
// ONNX element type. Combinations without a counterpart here (for example
// {:f, 8}) raise std::invalid_argument.
static ONNXTensorElementDataType decode_dtype(const DTypeTerm &term) {
  auto [family_atom, width] = term;
  const auto family = family_atom.to_string();

  if (family == "f") {
    switch (width) {
    case 16:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
    case 32:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    case 64:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
    }
  } else if (family == "bf") {
    if (width == 16)
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
  } else if (family == "s") {
    switch (width) {
    case 8:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
    case 16:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
    case 32:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
    case 64:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
    }
  } else if (family == "u") {
    switch (width) {
    case 8:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
    case 16:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
    case 32:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
    case 64:
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
    }
  } else if (family == "pred") {
    if (width == 8)
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
  }

  throw std::invalid_argument("unsupported Nx tensor type");
}

// Inverse of decode_dtype: turns an ONNX element type into the Nx-style
// {kind atom, bit width} tuple. Element types absent from the table raise
// std::invalid_argument.
static DTypeTerm encode_dtype(ONNXTensorElementDataType type) {
  struct Mapping {
    ONNXTensorElementDataType onnx;
    const fine::Atom *family;
    uint64_t bits;
  };

  static const Mapping mappings[] = {
      {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, &atoms::f, 16},
      {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &atoms::f, 32},
      {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, &atoms::f, 64},
      {ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, &atoms::bf, 16},
      {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, &atoms::s, 8},
      {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, &atoms::s, 16},
      {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, &atoms::s, 32},
      {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, &atoms::s, 64},
      {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, &atoms::u, 8},
      {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, &atoms::u, 16},
      {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, &atoms::u, 32},
      {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, &atoms::u, 64},
      {ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, &atoms::pred, 8},
  };

  for (const auto &entry : mappings) {
    if (entry.onnx == type)
      return {*entry.family, entry.bits};
  }

  throw std::invalid_argument("unsupported ONNX tensor element type");
}

// Human-readable element-type label used in session metadata (io_term).
// Unlike element_size/encode_dtype this never throws: unlisted types are
// reported as "Unknown".
static std::string dtype_name(ONNXTensorElementDataType type) {
  const char *label = "Unknown";

  switch (type) {
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
    label = "Float16";
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
    label = "Float32";
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
    label = "Float64";
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
    label = "BFloat16";
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
    label = "Int8";
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
    label = "Int16";
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
    label = "Int32";
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
    label = "Int64";
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
    label = "UInt8";
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
    label = "UInt16";
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
    label = "UInt32";
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
    label = "UInt64";
    break;
  case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
    label = "Bool";
    break;
  default:
    break;
  }

  return label;
}

// Number of elements in a tensor with the given static `shape`. An empty
// shape (a scalar) has one element.
//
// Throws std::invalid_argument if any dimension is negative (dynamic), or if
// the product does not fit in int64_t. The product is overflow-checked
// because shapes arrive from Elixir and signed overflow is undefined
// behaviour.
static int64_t element_count(const std::vector<int64_t> &shape) {
  // Reject dynamic dims up front so the error does not depend on dim order.
  if (std::any_of(shape.begin(), shape.end(),
                  [](int64_t dim) { return dim < 0; })) {
    throw std::invalid_argument(
        "tensor shape cannot contain dynamic dimensions");
  }

  // A zero-sized dimension makes the whole tensor empty, however large the
  // other dimensions are, so it must not trip the overflow check below.
  if (std::find(shape.begin(), shape.end(), 0) != shape.end()) {
    return 0;
  }

  int64_t count = 1;
  for (auto dim : shape) {
    // dim >= 1 here, so the division is safe.
    if (count > std::numeric_limits<int64_t>::max() / dim) {
      throw std::invalid_argument("tensor shape has too many elements");
    }
    count *= dim;
  }

  return count;
}

/// Dense, CPU-resident tensor: an owned copy of the element bytes plus the
/// shape and ONNX element type. Exposed to Elixir as a NIF resource
/// (FINE_RESOURCE below). It is built either from an Elixir binary (model
/// inputs) or by copying an Ort::Value (model outputs).
class Tensor {
public:
  /// Builds a tensor from an Elixir binary, copying its bytes.
  ///
  /// @param binary     raw element data; its size must equal
  ///                   element_count(shape) * element_size(dtype).
  /// @param shape      dimensions; negative (dynamic) dims are rejected.
  /// @param dtype_term Nx-style {kind, bits} element type.
  /// @throws std::invalid_argument for an unsupported type, a dynamic or
  ///         oversized shape, or a binary whose size does not match.
  Tensor(ErlNifBinary binary, std::vector<int64_t> shape, DTypeTerm dtype_term)
      : shape(std::move(shape)), dtype(decode_dtype(dtype_term)),
        data(binary.data, binary.data + binary.size) {
    const auto count = static_cast<uint64_t>(element_count(this->shape));
    const auto width = element_size(dtype);

    // Check before multiplying: an unchecked count * width can wrap size_t
    // and spuriously "match" a short (even empty) binary, leaving a tensor
    // whose shape claims far more bytes than it owns.
    if (count > std::numeric_limits<size_t>::max() / width) {
      throw std::invalid_argument("tensor shape is too large");
    }
    const auto expected_size = static_cast<size_t>(count) * width;

    if (expected_size != data.size()) {
      std::ostringstream message;
      message << "binary size " << data.size() << " does not match tensor size "
              << expected_size;
      throw std::invalid_argument(message.str());
    }
  }

  /// Copies a model output out of ONNX Runtime into an owned buffer.
  /// Throws (via element_size) for element types this binding cannot copy.
  explicit Tensor(Ort::Value &&output) {
    auto info = output.GetTensorTypeAndShapeInfo();
    dtype = info.GetElementType();
    shape = info.GetShape();

    const auto bytes = info.GetElementCount() * element_size(dtype);
    data.resize(bytes);
    // memcpy is undefined for null pointers even when copying zero bytes, and
    // an empty output is not guaranteed to expose a non-null buffer.
    if (bytes > 0) {
      std::memcpy(data.data(), output.GetTensorRawData(), bytes);
    }
  }

  /// Wraps this tensor's bytes in an Ort::Value without copying them.
  /// The returned value borrows `data`, so it must not outlive this Tensor.
  /// The const_cast only satisfies the mutable-pointer signature of
  /// CreateTensor. In this file the value is used solely as a Run input
  /// (Session::run).
  Ort::Value ort_value() const {
    auto memory = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    return Ort::Value::CreateTensor(memory, const_cast<uint8_t *>(data.data()),
                                    data.size(), shape.data(), shape.size(),
                                    dtype);
  }

  const std::vector<int64_t> &dims() const { return shape; }
  ONNXTensorElementDataType element_type() const { return dtype; }

  /// Copies the tensor bytes into a new Elixir binary.
  ///
  /// @param limit maximum number of *elements* to copy; 0 copies everything.
  fine::Term to_binary(ErlNifEnv *env, uint64_t limit) const {
    auto bytes = data.size();

    if (limit > 0) {
      // Compare in elements instead of computing limit * width, which can
      // wrap for very large limits and silently truncate the result.
      const auto width = element_size(dtype);
      if (limit < bytes / width) {
        bytes = static_cast<size_t>(limit) * width;
      }
    }

    return fine::make_new_binary(
        env, reinterpret_cast<const char *>(data.data()), bytes);
  }

private:
  std::vector<int64_t> shape;
  ONNXTensorElementDataType dtype;
  std::vector<uint8_t> data;
};

FINE_RESOURCE(Tensor);

class Session {
public:
  Session(std::string path, std::vector<fine::Atom> execution_providers,
          int64_t optimization_level)
      : session(nullptr) {
    Ort::SessionOptions options;
    options.SetGraphOptimizationLevel(graph_optimization_level(optimization_level));

    for (auto &provider : execution_providers) {
      if (provider.to_string() != atoms::cpu.to_string()) {
        throw std::invalid_argument(
            "only the :cpu execution provider is currently supported");
      }
    }

    session = Ort::Session(ort_env(), path.c_str(), options);
    read_io_metadata();
  }

  std::vector<
      std::tuple<fine::ResourcePtr<Tensor>, std::vector<int64_t>, fine::Atom,
                 uint64_t>>
  run(std::vector<fine::ResourcePtr<Tensor>> inputs) {
    if (inputs.size() != input_name_ptrs.size()) {
      std::ostringstream message;
      message << "expected " << input_name_ptrs.size() << " inputs, got "
              << inputs.size();
      throw std::invalid_argument(message.str());
    }

    std::vector<Ort::Value> input_values;
    input_values.reserve(inputs.size());

    for (auto &input : inputs) {
      input_values.emplace_back(input->ort_value());
    }

    auto outputs = session.Run(Ort::RunOptions{nullptr}, input_name_ptrs.data(),
                               input_values.data(), input_values.size(),
                               output_name_ptrs.data(), output_name_ptrs.size());

    std::vector<std::tuple<fine::ResourcePtr<Tensor>, std::vector<int64_t>,
                           fine::Atom, uint64_t>>
        result;
    result.reserve(outputs.size());

    for (auto &output : outputs) {
      auto tensor = fine::make_resource<Tensor>(std::move(output));
      auto [kind, bits] = encode_dtype(tensor->element_type());
      result.emplace_back(tensor, tensor->dims(), kind, bits);
    }

    return result;
  }

  std::tuple<std::vector<IOTerm>, std::vector<IOTerm>> show() const {
    return {inputs, outputs};
  }

private:
  static GraphOptimizationLevel graph_optimization_level(int64_t level) {
    switch (level) {
    case 1:
      return ORT_ENABLE_BASIC;
    case 2:
      return ORT_ENABLE_EXTENDED;
    case 3:
      return ORT_ENABLE_ALL;
    default:
      return ORT_DISABLE_ALL;
    }
  }

  static ShapeTerm shape_term(const std::vector<int64_t> &shape) {
    ShapeTerm result;
    result.reserve(shape.size());

    for (auto dim : shape) {
      if (dim < 0) {
        result.push_back(std::nullopt);
      } else {
        result.push_back(dim);
      }
    }

    return result;
  }

  static IOTerm io_term(const std::string &name, const Ort::TypeInfo &type_info) {
    if (type_info.GetONNXType() != ONNX_TYPE_TENSOR) {
      return {name, "NonTensor", {}};
    }

    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    return {name, dtype_name(tensor_info.GetElementType()),
            shape_term(tensor_info.GetShape())};
  }

  void read_io_metadata() {
    Ort::AllocatorWithDefaultOptions allocator;

    auto input_count = session.GetInputCount();
    input_names.reserve(input_count);
    input_name_ptrs.reserve(input_count);
    inputs.reserve(input_count);

    for (size_t i = 0; i < input_count; ++i) {
      auto name = session.GetInputNameAllocated(i, allocator);
      input_names.emplace_back(name.get());
      input_name_ptrs.push_back(input_names.back().c_str());
      inputs.push_back(io_term(input_names.back(), session.GetInputTypeInfo(i)));
    }

    auto output_count = session.GetOutputCount();
    output_names.reserve(output_count);
    output_name_ptrs.reserve(output_count);
    outputs.reserve(output_count);

    for (size_t i = 0; i < output_count; ++i) {
      auto name = session.GetOutputNameAllocated(i, allocator);
      output_names.emplace_back(name.get());
      output_name_ptrs.push_back(output_names.back().c_str());
      outputs.push_back(io_term(output_names.back(), session.GetOutputTypeInfo(i)));
    }
  }

  Ort::Session session;
  std::vector<std::string> input_names;
  std::vector<std::string> output_names;
  std::vector<const char *> input_name_ptrs;
  std::vector<const char *> output_name_ptrs;
  std::vector<IOTerm> inputs;
  std::vector<IOTerm> outputs;
};

FINE_RESOURCE(Session);

/// NIF entry point: loads the ONNX model at `path` into a Session resource.
/// Registered on a dirty CPU scheduler (see FINE_NIF below) because reading
/// and optimizing a model can take far longer than a normal scheduler slot
/// should be held. This also matches `run`.
fine::ResourcePtr<Session> init(ErlNifEnv *, std::string path,
                                std::vector<fine::Atom> execution_providers,
                                int64_t optimization_level) {
  // Both by-value parameters are sinks; move them rather than copying.
  return fine::make_resource<Session>(std::move(path),
                                      std::move(execution_providers),
                                      optimization_level);
}

/// NIF entry point: runs `session` on `inputs` (given in model input order).
std::vector<std::tuple<fine::ResourcePtr<Tensor>, std::vector<int64_t>,
                       fine::Atom, uint64_t>>
run(ErlNifEnv *, fine::ResourcePtr<Session> session,
    std::vector<fine::ResourcePtr<Tensor>> inputs) {
  return session->run(std::move(inputs));
}

/// NIF entry point: returns the session's {inputs, outputs} metadata.
std::tuple<std::vector<IOTerm>, std::vector<IOTerm>>
show_session(ErlNifEnv *, fine::ResourcePtr<Session> session) {
  return session->show();
}

/// NIF entry point: copies `binary` into a new Tensor resource with the given
/// shape and Nx type. The Tensor constructor raises on a size mismatch.
/// NOTE(review): this copies the whole binary on a normal scheduler; very
/// large tensors may warrant a dirty scheduler.
fine::ResourcePtr<Tensor> from_binary(ErlNifEnv *, ErlNifBinary binary,
                                      std::vector<int64_t> shape,
                                      DTypeTerm dtype) {
  return fine::make_resource<Tensor>(binary, std::move(shape), dtype);
}

/// NIF entry point: copies the tensor bytes into a new binary. `limit` is an
/// element count; 0 means the whole tensor.
fine::Term to_binary(ErlNifEnv *env, fine::ResourcePtr<Tensor> tensor,
                     uint64_t limit) {
  return tensor->to_binary(env, limit);
}

FINE_NIF(init, ERL_NIF_DIRTY_JOB_CPU_BOUND);
FINE_NIF(run, ERL_NIF_DIRTY_JOB_CPU_BOUND);
FINE_NIF(show_session, 0);
FINE_NIF(from_binary, 0);
FINE_NIF(to_binary, 0);

FINE_INIT("Elixir.OnnxRuntime.Native");