Introducing NABLA

Today we are releasing a research preview of NABLA, a framework for differentiable programming in Mojo. Nabla aims to bring to Mojo some of what JAX and PyTorch brought to Python: a high-level API for general program transformations, including vmap, jit, vjp, jvp, and grad.

Unlike previous attempts (e.g. Endia) that failed by trying to rebuild the entire stack, Nabla was built from the ground up as a thin wrapper around Mojo and MAX, so it offers the same performance guarantees as Mojo and MAX themselves. (The Nabla core does NOT include any low-level kernels, for example.) There are still many rough edges and missing features (operator coverage, GPU support, etc.), but the core AD engine has proven effective in initial tests. We hope you like it!
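
To make the JAX analogy concrete, here is a rough sketch of the functional transformation style, written to match the README example quoted further down. Treat it as illustrative only: the exact signatures of nabla.grad and nabla.jit and the argument-selection convention shown here are assumed and may differ from what the preview currently implements.

import nabla

# Illustrative sketch only: transformation names follow the JAX analogy; signatures are assumed.
def loss_fn(weight, input, label):
  # Tiny single-layer forward pass returning a scalar loss
  logits = nabla.relu(input @ weight)
  return nabla.sum((logits - label) ** 2)

def main():
  weight = nabla.randn((3, 4), DType.float32)
  input = nabla.randn((2, 3), DType.float32)
  label = nabla.randn((2, 4), DType.float32)

  # Assumed here: grad differentiates with respect to the first argument (as in JAX),
  # and jit compiles the transformed function through MAX.
  grad_fn = nabla.jit(nabla.grad(loss_fn))
  weight_grad = grad_fn(weight, input, label)
  print("dLoss/dWeight:", weight_grad)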

Docs and homepage: nablaml.com

Examples and roadmap: github.com/nabla-ml/nabla

Follow us on X for related posts: @nablaml

20 Likes

This is so cool @TilliFe, congratulations!

3 Likes
import nabla

def main():
  # Init params with gradient computation enabled
  weight = nabla.randn((3, 4), DType.float32, requires_grad=True)
  bias = nabla.randn((2, 4), DType.float32, requires_grad=True)
  label = nabla.randn((2, 4), DType.float32)
  input = nabla.randn((2, 3), DType.float32)

  # Compute forward pass (single layer MLP)
  logits = nabla.relu(input @ weight + bias)
  loss = nabla.sum((logits - label) ** 2)
  print("Loss:", loss)

  # Backward pass to compute gradients
  loss.backward()

  # Update parameters à la SGD
  weight -= 0.01 * weight.grad()
  bias -= 0.01 * bias.grad()
  print("weight:", weight, "bias:", bias)

I love this example you have in the README. It's literally the textbook description of a neural net and backprop, in the minimum number of lines of code in which such a thing could be described.

:clap: :clap: :clap:

2 Likes