Skip to main content

examples/burn_bridge_ops.exs

# BurnBridge Direct Tensor Operations with ExBurn
# Run with: elixir examples/burn_bridge_ops.exs
#   (the script calls Mix.install/2, which cannot be used inside a Mix
#   project, so `mix run` will not work here)
#
# Demonstrates the ExBurn.BurnBridge module for direct Burn tensor
# operations, bypassing the Nx abstraction layer:
#   1. Creating tensors (zeros, ones, rand, from_nx)
#   2. Arithmetic operations (add, sub, mul, div, neg, abs)
#   3. Math functions (exp, log, sqrt, sigmoid, relu)
#   4. Linear algebra (matmul, transpose, dot)
#   5. Reductions (sum, mean)
#   6. Shape manipulation (reshape, softmax)
#   7. Loss functions (mse, cross_entropy)
#   8. Device management (gpu_available?, device_name, device_info)
#   9. Converting between Burn tensors and Nx tensors
#  10. Tensor inspection (shape, type, rank, numel)

Mix.install([
  {:nx, "~> 0.7"},
  {:axon, "~> 0.7"},
  {:ex_burn, path: Path.expand("..", __DIR__)}
])

defmodule BurnBridgeOps do
  @moduledoc """
  Showcases ExBurn.BurnBridge for direct tensor operations.

  `run/0` prints one section per feature area: device, tensor creation,
  arithmetic, math functions, linear algebra, reductions, shape manipulation,
  losses, the Nx round-trip and tensor inspection. All tensor math goes
  through `ExBurn.BurnBridge`; Nx is only used to build inputs and to read
  results back.
  """

  alias ExBurn.{BurnBridge, Tensor}

  @doc """
  Runs every demo section in order, printing results to stdout.

  Returns `:ok` (the result of the final `IO.puts/1`).
  """
  def run do
    IO.puts("=== BurnBridge Direct Tensor Operations ===\n")

    show_device()
    show_creation()
    show_arithmetic()
    show_math_functions()
    show_linear_algebra()
    show_reductions()
    show_shape_manipulation()
    show_losses()
    show_roundtrip()
    show_inspection()

    IO.puts("\n=== Done ===")
  end

  # 1. Device info: GPU availability, device name and raw device info.
  defp show_device do
    IO.puts("Device:")
    IO.puts("  GPU available: #{BurnBridge.gpu_available?()}")
    IO.puts("  Device name:   #{BurnBridge.device_name()}")
    IO.puts("  Device info:   #{inspect(BurnBridge.device_info())}\n")
  end

  # 2. Tensor creation via zeros/ones/rand and conversion from Nx.
  defp show_creation do
    IO.puts("Tensor creation:")

    # Labels carry their own trailing padding so the columns line up.
    print_shape_and_type("zeros([2,3]):", BurnBridge.zeros([2, 3], :f32))
    print_shape_and_type("ones([2,3]): ", BurnBridge.ones([2, 3], :f32))
    print_shape_and_type("rand([2,3]): ", BurnBridge.rand([2, 3], :f32, 0.0, 1.0))

    from_nx = BurnBridge.from_nx(Nx.tensor([[1.0, 2.0], [3.0, 4.0]]))
    IO.puts("  from_nx([[1,2],[3,4]]): shape=#{inspect(Tensor.shape(from_nx))}")
  end

  # 3. Element-wise arithmetic on two 2x2 tensors.
  defp show_arithmetic do
    IO.puts("\nArithmetic:")
    a = BurnBridge.from_nx(Nx.tensor([[1.0, 2.0], [3.0, 4.0]]))
    b = BurnBridge.from_nx(Nx.tensor([[5.0, 6.0], [7.0, 8.0]]))

    IO.puts("  add: #{inspect(to_list(BurnBridge.add(a, b)))}")
    IO.puts("  sub: #{inspect(to_list(BurnBridge.sub(a, b)))}")
    IO.puts("  mul (element-wise): #{inspect(to_list(BurnBridge.mul(a, b)))}")
    IO.puts("  div (element-wise): #{inspect(to_list(BurnBridge.div(a, b)))}")
    IO.puts("  neg: #{inspect(to_list(BurnBridge.neg(a)))}")
    IO.puts("  abs(neg(a)): #{inspect(to_list(BurnBridge.abs(BurnBridge.neg(a))))}")
  end

  # 4. Unary math functions.
  defp show_math_functions do
    IO.puts("\nMath functions:")
    c = BurnBridge.from_nx(Nx.tensor([[1.0, 2.0, 3.0]]))

    IO.puts("  exp([1,2,3]): #{inspect(to_list(BurnBridge.exp(c)))}")
    IO.puts("  log([1,2,3]): #{inspect(to_list(BurnBridge.log(c)))}")
    IO.puts("  sqrt([1,2,3]): #{inspect(to_list(BurnBridge.sqrt(c)))}")
    IO.puts("  sigmoid([1,2,3]): #{inspect(to_list(BurnBridge.sigmoid(c)))}")

    # Separate input so relu's clamping of negatives is visible.
    relu_input = BurnBridge.from_nx(Nx.tensor([[-1.0, 0.0, 1.0]]))
    IO.puts("  relu([-1,0,1]): #{inspect(to_list(BurnBridge.relu(relu_input)))}")
  end

  # 5. Linear algebra: matmul, transpose and a dot product.
  defp show_linear_algebra do
    IO.puts("\nLinear algebra:")
    m1 = BurnBridge.from_nx(Nx.tensor([[1.0, 2.0], [3.0, 4.0]]))
    m2 = BurnBridge.from_nx(Nx.tensor([[5.0, 6.0], [7.0, 8.0]]))

    matmul_result = to_list(BurnBridge.matmul(m1, m2))
    IO.puts("  matmul([[1,2],[3,4]], [[5,6],[7,8]]): #{inspect(matmul_result)}")
    IO.puts("  transpose([[1,2],[3,4]]): #{inspect(to_list(BurnBridge.transpose(m1)))}")

    # The dot product is expressed as a (1x3) x (3x1) matmul so it runs
    # through BurnBridge like every other demo here, instead of falling back
    # to Nx.dot/2. The result is a single-element tensor: 1*4 + 2*5 + 3*6.
    row = BurnBridge.from_nx(Nx.tensor([[1.0, 2.0, 3.0]]))
    col = BurnBridge.from_nx(Nx.tensor([[4.0], [5.0], [6.0]]))
    dot_result = to_scalar(BurnBridge.matmul(row, col))
    IO.puts("  dot([1,2,3], [4,5,6]): #{dot_result}")
  end

  # 6. Full-tensor reductions.
  defp show_reductions do
    IO.puts("\nReductions:")
    d = BurnBridge.from_nx(Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))

    IO.puts("  sum([[1,2,3],[4,5,6]]): #{to_scalar(BurnBridge.sum(d))}")
    IO.puts("  mean([[1,2,3],[4,5,6]]): #{to_scalar(BurnBridge.mean(d))}")
  end

  # 7. Reshape and softmax.
  defp show_shape_manipulation do
    IO.puts("\nShape manipulation:")
    e = BurnBridge.from_nx(Nx.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))

    reshaped = BurnBridge.reshape(e, [2, 3])
    IO.puts("  reshape([1..6], [2,3]): shape=#{inspect(Tensor.shape(reshaped))}")

    logits = BurnBridge.from_nx(Nx.tensor([[1.0, 2.0, 3.0]]))
    IO.puts("  softmax([1,2,3]): #{inspect(to_list(BurnBridge.softmax(logits)))}")
  end

  # 8. Loss functions, printed rounded to 4 decimal places.
  defp show_losses do
    IO.puts("\nLoss functions:")
    pred = BurnBridge.from_nx(Nx.tensor([[0.1, 0.2, 0.7], [0.8, 0.1, 0.1]]))
    target = BurnBridge.from_nx(Nx.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]))

    ce_result = to_scalar(BurnBridge.cross_entropy(pred, target))
    IO.puts("  cross_entropy: #{Float.round(ce_result, 4)}")

    pred_mse = BurnBridge.from_nx(Nx.tensor([[1.0, 2.0], [3.0, 4.0]]))
    target_mse = BurnBridge.from_nx(Nx.tensor([[1.5, 2.5], [2.5, 3.5]]))

    mse_result = to_scalar(BurnBridge.mse(pred_mse, target_mse))
    IO.puts("  mse: #{Float.round(mse_result, 4)}")
  end

  # 9. Nx -> Burn -> Nx round-trip and an element-wise equality check.
  defp show_roundtrip do
    IO.puts("\nNx <-> Burn round-trip:")
    original = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
    roundtrip = original |> BurnBridge.from_nx() |> BurnBridge.to_nx()

    IO.puts("  original:  #{inspect(Nx.to_flat_list(original))}")
    IO.puts("  roundtrip: #{inspect(Nx.to_flat_list(roundtrip))}")

    identical = Nx.to_number(Nx.all(Nx.equal(original, roundtrip))) == 1
    IO.puts("  match: #{identical}")
  end

  # 10. Tensor metadata accessors.
  defp show_inspection do
    IO.puts("\nTensor inspection:")
    t = BurnBridge.from_nx(Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))

    IO.puts("  shape: #{inspect(Tensor.shape(t))}")
    IO.puts("  type:  #{Tensor.type(t)}")
    IO.puts("  rank:  #{Tensor.rank(t)}")
    IO.puts("  numel: #{Tensor.numel(t)}")
  end

  # Prints "  <label> shape=<shape>, type=<type>" for a Burn tensor.
  defp print_shape_and_type(label, tensor) do
    IO.puts("  #{label} shape=#{inspect(Tensor.shape(tensor))}, type=#{Tensor.type(tensor)}")
  end

  # Converts a Burn tensor back to Nx and flattens it into a plain list.
  defp to_list(burn_tensor) do
    burn_tensor |> BurnBridge.to_nx() |> Nx.to_flat_list()
  end

  # Converts a single-element Burn tensor into a number. Nx.to_number/1
  # raises for anything that is not rank 0. Burn reductions/losses may hand
  # back a shape-[1] tensor instead (NOTE(review): confirm what ExBurn's
  # to_nx/1 yields), so size-1 axes are squeezed away first; a true scalar
  # passes through unchanged.
  defp to_scalar(burn_tensor) do
    burn_tensor |> BurnBridge.to_nx() |> Nx.squeeze() |> Nx.to_number()
  end
end

# Script entry point: run every demo section when this file is executed.
BurnBridgeOps.run()