
juju, from JAX to MAX.

juju is a bit of compiler middleware bridging (parts of) JAX to the world of MAX graphs. It allows:

  • users to write JAX programs (see caveat below), lower those programs to MAX graphs, and execute those graphs on MAX-supported hardware, including CPUs, GPUs, and (later on) xPUs (whatever MAX supports).
  • users to extend the Python JAX language (the primitives that JAX exposes to write numerical programs) with custom MAX kernels.

Danger, Will Robinson!

This package is a proof-of-concept, and very early in development. Simple programs only for now! It's not yet clear how much of JAX will be fully supported (and how many extensions via MAX kernels will be added).

JAX is a massive project, with tons of functionality! It's unlikely that this package will ever support all of JAX (all JAX primitives, and device semantics). The goal is to support enough JAX to be dangerous, and to provide ways to easily extend this package, e.g. to support more of JAX, or to plug in your own custom operations and define your own JAX-like language that compiles to MAX.

Example:

Using juju to transform and execute code with MAX.

import jax.numpy as jnp
from juju import jit

@jit
def jax_code(x, y):
    v = x + y
    v = v * v
    return jnp.sin(v)

print(jax_code(5, 10).to_numpy())
-0.93009484

Getting started

To get started with juju, you'll need to follow these steps:

  • First, install magic, the package and compiler manager for MAX and Mojo.
  • Then, clone this repository and run magic install at the top level. This will set up your environment, which you can enter via magic shell.
  • Then, run magic run kernels to build the custom MAX kernels provided as part of juju.

Basic APIs

To start out, let's examine the basic APIs, which let you execute functions using MAX and create MAX graphs.

juju.jit

jit(
    f: Optional[Callable[..., any]] = None,
    coerces_to_jnp: bool = False,
    engine: JITEngine = cpu_engine(),
)

Returns a function which JIT compiles the provided function using MAX by first creating a MAX graph, loading it into the MAX engine, and then executing it.

The first invocation of the JIT'd function is slow because it triggers compilation. Subsequent invocations are fast: the graph is cached by MAX, and juju stores a callable function that avoids repeating the lowering process.

Example:

import jax.numpy as jnp
from juju import jit


@jit
def foo(x):
    return x * x


print(foo(5).to_numpy())
25
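The caching behavior described above can be modeled with a toy sketch in plain Python. This is not juju's implementation: `ToyJIT` and its `lower` method are hypothetical, and keying the cache on each argument's Python type is an assumption made for illustration (a real JIT would typically key on shapes and dtypes).

```python
from typing import Any, Callable, Dict, Tuple


class ToyJIT:
    """Hypothetical stand-in for a JIT'd function: lowers once per signature, then reuses."""

    def __init__(self, f: Callable[..., Any]):
        self.f = f
        self.cache: Dict[Tuple[type, ...], Callable[..., Any]] = {}
        self.lowerings = 0  # counts how many times the expensive step ran

    def lower(self, signature: Tuple[type, ...]) -> Callable[..., Any]:
        # Stand-in for "trace to a Jaxpr, build a MAX graph, load it into the engine".
        self.lowerings += 1
        return self.f

    def __call__(self, *args: Any) -> Any:
        signature = tuple(type(a) for a in args)
        if signature not in self.cache:
            self.cache[signature] = self.lower(signature)  # slow path: first call only
        return self.cache[signature](*args)  # fast path afterwards


square = ToyJIT(lambda x: x * x)
print(square(5), square(6), square.lowerings)  # 25 36 1
```

Calling `square(2.0)` afterwards would trigger a second lowering, because a float argument produces a new cache key.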

Automatic conversion of MAX tensors to JAX arrays

juju.jit supports an option called coerces_to_jnp which can be used to automatically convert MAX tensors to JAX numpy arrays. By default, this option is set to False.

import jax.numpy as jnp
from juju import jit


@jit(coerces_to_jnp=True)
def foo(x):
    return x * x


print(foo(5))
25

Customizing the target platform

If you have a GPU available, you can run the code on it by passing the jit decorator an engine created with the gpu_engine function.

import jax.numpy as jnp
from juju import jit, gpu_engine


@jit(engine=gpu_engine())
def foo(x):
    return x * x
Source code in src/juju/compiler.py
def jit(
    f: Optional[Callable[..., any]] = None,
    coerces_to_jnp: bool = False,
    engine: JITEngine = cpu_engine(),
):
    """
    Returns a function which JIT compiles the provided function using MAX by first creating a MAX graph,
    loading it into the MAX engine, and then executing it.

    The first invocation of the JIT'd function will be slow to compile,
    but subsequent invocations will be fast, as the graph is cached by MAX,
    and `juju` stores a callable function which avoids repeating
    the lowering process.

    **Example:**

    ```python exec="on" source="material-block"
    import jax.numpy as jnp
    from juju import jit


    @jit
    def foo(x):
        return x * x


    print(foo(5).to_numpy())
    ```

    **Automatic conversion of MAX tensors to JAX arrays**

    `juju.jit` supports an option called `coerces_to_jnp`
    which can be used to automatically convert MAX tensors
    to JAX numpy arrays. By default, this option is set to `False`.

    ```python exec="on" source="material-block"
    import jax.numpy as jnp
    from juju import jit


    @jit(coerces_to_jnp=True)
    def foo(x):
        return x * x


    print(foo(5))
    ```

    **Customizing the target platform**

    If you have a GPU available, you can execute the code by using the
    `gpu_engine` function to create a JIT engine that uses the GPU.

    ```python
    import jax.numpy as jnp
    from juju import jit, gpu_engine


    @jit(engine=gpu_engine())
    def foo(x):
        return x * x
    ```
    """
    if f is None:
        return functools.partial(
            jit,
            coerces_to_jnp=coerces_to_jnp,
            engine=engine,
        )
    return JITFunction(f, coerces_to_jnp, engine)
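The `if f is None` branch above is what lets `jit` be used both bare (`@jit`) and with keyword arguments (`@jit(coerces_to_jnp=True)`). Below is a minimal stand-alone sketch of the same `functools.partial` pattern; `toy_jit` and its `convert` hook are hypothetical and stand in for `JITFunction` and the MAX-to-JAX conversion.

```python
import functools
from typing import Any, Callable, Optional


def toy_jit(
    f: Optional[Callable[..., Any]] = None,
    coerce: bool = False,
    convert: Callable[[Any], Any] = list,
):
    # Called as @toy_jit(coerce=True): no function yet, so return a decorator
    # that remembers the keyword arguments.
    if f is None:
        return functools.partial(toy_jit, coerce=coerce, convert=convert)

    # Called as @toy_jit (or via the partial above): wrap the function directly.
    @functools.wraps(f)
    def wrapped(*args: Any) -> Any:
        out = f(*args)
        return convert(out) if coerce else out

    return wrapped


@toy_jit
def bare(x):
    return (x, x)


@toy_jit(coerce=True)
def configured(x):
    return (x, x)


print(bare(3), configured(3))  # (3, 3) [3, 3]
```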

juju.make_max_graph

make_max_graph(
    f: Callable[..., Any],
) -> Callable[..., Graph]

Returns a function that constructs and returns a MAX graph for the provided function using JAX tracing.

Example:

import jax.numpy as jnp
from juju import make_max_graph


@make_max_graph
def foo(x):
    return x * x


print(foo(5))
mo.graph @foo(%arg0: !mo.tensor<[], si32>) -> !mo.tensor<[], si32> attributes {argument_names = ["input0"], result_names = ["output0"]} {
  %0 = mo.chain.create()
  %1 = rmo.mul(%arg0, %arg0) : (!mo.tensor<[], si32>, !mo.tensor<[], si32>) -> !mo.tensor<[], si32>
  mo.output %1 : !mo.tensor<[], si32>
}
Source code in src/juju/compiler.py
def make_max_graph(f: Callable[..., Any]) -> Callable[..., Graph]:
    """
    Returns a function that constructs and returns a MAX graph
    for the provided function using JAX tracing.

    **Example:**

    ```python exec="on" source="material-block"
    import jax.numpy as jnp
    from juju import make_max_graph


    @make_max_graph
    def foo(x):
        return x * x


    print(foo(5))
    ```
    """

    @functools.wraps(f)
    def wrapped(*args):
        _, graph = _max(f)(*args)
        return graph

    return wrapped
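make_max_graph relies on JAX tracing: the function runs on stand-in values that record each operation instead of computing it. The toy tracer below illustrates that idea in plain Python. `Tracer`, `trace`, and the tuple-based graph format are hypothetical and far simpler than a Jaxpr or a MAX graph (for instance, operands must both be tracers; constants are not handled).

```python
from typing import Callable, List, Tuple

Op = Tuple[str, str, str, str]  # (result name, op name, lhs name, rhs name)


class Tracer:
    """A symbolic value: arithmetic on it appends an op to a shared graph."""

    def __init__(self, name: str, graph: List[Op]):
        self.name = name
        self.graph = graph

    def _emit(self, op: str, other: "Tracer") -> "Tracer":
        result = f"%{len(self.graph)}"
        self.graph.append((result, op, self.name, other.name))
        return Tracer(result, self.graph)

    def __add__(self, other: "Tracer") -> "Tracer":
        return self._emit("add", other)

    def __mul__(self, other: "Tracer") -> "Tracer":
        return self._emit("mul", other)


def trace(f: Callable[..., Tracer], num_args: int) -> Tuple[List[Op], str]:
    graph: List[Op] = []
    args = [Tracer(f"%arg{i}", graph) for i in range(num_args)]
    out = f(*args)  # runs f once, recording operations instead of computing
    return graph, out.name


graph, output = trace(lambda x: x * x, 1)
print(graph, output)  # [('%0', 'mul', '%arg0', '%arg0')] %0
```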

Custom operations and primitives

A very nice feature of MAX is that its operation set is extensible, and operations are authored in Mojo, a language with higher-level ergonomics than, for instance, CUDA.

As a result, extending the operation set with new GPU computations is much more approachable than extending XLA with custom CUDA computations, and can be performed without leaving the juju project or introducing external compilers (besides the Mojo compiler, which is accessed via magic).

There are two steps to exposing custom operations to juju:

  • Writing a MAX kernel using Mojo.
  • Exposing the kernel to MAX, and providing the necessary information to JAX in the form of a new Primitive.

Writing a MAX kernel

A MAX kernel takes the form of a Mojo source code file. The MAX development team has kindly shared several of these kernels for study. Additionally, this article is worth reading to gain a general understanding of custom operations.

Let's examine a kernel, and imagine that we've placed it in a file at kernels/add_one.mojo:

kernels/add_one.mojo
import compiler
from utils.index import IndexList
from max.tensor import ManagedTensorSlice, foreach
from runtime.asyncrt import MojoCallContextPtr


@compiler.register("add_one", num_dps_outputs=1)
struct AddOneCustom:
    @staticmethod
    fn execute[
        # Parameter that if true, runs kernel synchronously in runtime
        synchronous: Bool,
        # e.g. "CUDA" or "CPU"
        target: StringLiteral,
    ](
        # as num_dps_outputs=1, the first argument is the "output"
        out: ManagedTensorSlice,
        # starting here are the list of inputs
        x: ManagedTensorSlice[out.type, out.rank],
        # the context is needed for some GPU calls
        ctx: MojoCallContextPtr,
    ):
        @parameter
        @always_inline
        fn func[width: Int](idx: IndexList[x.rank]) -> SIMD[x.type, width]:
            return x.load[width](idx) + 1

        foreach[func, synchronous, target](out, ctx)

    # You only need to implement this if you do not manually annotate
    # output shapes in the graph.
    @staticmethod
    fn shape(
        x: ManagedTensorSlice,
    ) raises -> IndexList[x.rank]:
        raise "NotImplemented"

Kernels are Mojo structs decorated with @compiler.register. Their execute method defines the kernel's execution semantics; the shape method only needs a real implementation if you do not annotate output shapes in the graph.
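For readers new to Mojo, the Python sketch below mimics what execute does for a one-dimensional tensor. The output buffer is passed in first (destination-passing style, which is what num_dps_outputs=1 declares), and a foreach helper calls the element function for every index and stores the result. All names here are hypothetical Python analogues; the sketch ignores the SIMD `width`, multi-dimensional indices, and the `synchronous`/`target` parameters.

```python
from typing import Callable, List


def foreach(func: Callable[[int], float], out: List[float]) -> None:
    # Visit every index of the output and store func's result there.
    for idx in range(len(out)):
        out[idx] = func(idx)


def add_one_execute(out: List[float], x: List[float]) -> None:
    # Destination-passing style: the caller allocates `out`; the kernel only fills it.
    def func(idx: int) -> float:
        return x[idx] + 1

    foreach(func, out)


x = [1.0, 2.0, 3.0]
out = [0.0] * len(x)
add_one_execute(out, x)
print(out)  # [2.0, 3.0, 4.0]
```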

To expose the kernel as a MAX operation, the kernel must live in a Mojo package, which means we need a kernels/__init__.mojo file:

kernels/__init__.mojo
from .add_one import *

We can then ask mojo to compile the package into kernels.mojopkg, which MAX's Python API can load to give MAX access to the kernels:

mojo package kernels -o kernels.mojopkg

Keep your kernels package up to date!

When implementing custom operations, make sure the kernels package you're using is up to date, i.e. rebuild it after every change to your kernels. Otherwise, during graph loading, MAX will complain that it cannot find your kernel.
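One way to catch a stale package early is to compare modification times before building a graph. The helper below is a hypothetical convenience, not part of juju or MAX; it assumes the layout used in this section (Mojo sources under kernels/, package at kernels.mojopkg) and only reports whether any .mojo file is newer than the package.

```python
from pathlib import Path


def kernels_package_is_stale(
    source_dir: str = "kernels", package: str = "kernels.mojopkg"
) -> bool:
    pkg = Path(package)
    if not pkg.exists():
        return True  # never built
    built_at = pkg.stat().st_mtime
    # Stale if any Mojo source was modified after the package was built.
    return any(src.stat().st_mtime > built_at for src in Path(source_dir).rglob("*.mojo"))


if kernels_package_is_stale():
    print("Rebuild with: mojo package kernels -o kernels.mojopkg")
```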

In the Python API, we can give access to the kernels by providing a custom_extensions argument to engine.InferenceSession:

from max import engine

# Make the custom kernels in kernels.mojopkg available to graphs loaded by this session.
session = engine.InferenceSession(
    custom_extensions="./kernels.mojopkg",
)

This is exactly how juju does it under the hood; its source code provides further details.

Exposing the kernel to JAX

Now, MAX is only one side of the coin. The other side is that we'd like to incorporate these computations in JAX source code.

JAX allows users to extend its program representation (the Jaxpr) by introducing new primitives: units of computation that accept and return arrays.

Interlude: the juju pipeline

juju plugs into JAX in the following way:

  • (Tracing) First, we use JAX to trace Python computations to produce Jaxprs.
  • (Lowering) Then, juju processes these Jaxprs with an interpreter to create MAX graphs.

Let's say we want to introduce a new primitive to JAX. The first tracing stage requires that the primitive communicate with JAX about the shapes and dtypes of the arrays it accepts as input, as well as the shapes and dtypes of the arrays it produces as output. As long as we tell JAX this information, it doesn't care about "what the primitive does". We'll call this information a jax_abstract_evaluation_rule.
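To make "shapes and dtypes only" concrete, here is a toy abstract evaluation in plain Python. `AbstractArray` is a hypothetical stand-in for JAX's `ShapedArray`, and the two rules are illustrative; note that neither rule looks at array contents, only at metadata.

```python
from typing import NamedTuple, Tuple


class AbstractArray(NamedTuple):
    # Hypothetical stand-in for jax.core.ShapedArray: metadata only, no data.
    shape: Tuple[int, ...]
    dtype: str


def add_one_abstract(x: AbstractArray) -> AbstractArray:
    # Elementwise op: the output has the input's shape and dtype.
    return AbstractArray(x.shape, x.dtype)


def matmul_abstract(a: AbstractArray, b: AbstractArray) -> AbstractArray:
    # (m, k) @ (k, n) -> (m, n); mismatches are caught before anything runs.
    (m, k1), (k2, n) = a.shape, b.shape
    if k1 != k2 or a.dtype != b.dtype:
        raise TypeError(f"incompatible operands: {a} and {b}")
    return AbstractArray((m, n), a.dtype)


print(add_one_abstract(AbstractArray((10,), "float32")))
print(matmul_abstract(AbstractArray((2, 3), "float32"), AbstractArray((3, 4), "float32")))
```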

The second lowering stage requires that we tell the juju interpreter how the primitive is going to be represented in the MAX graph. We'll call this information a max_lowering_rule.
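The lowering stage can be pictured as a small interpreter that walks the traced program equation by equation and looks up a rule for each primitive. The sketch below is a hypothetical, plain-Python model of that dispatch: juju's real interpreter operates on Jaxprs and emits MAX graph operations, whereas here each rule just builds a string.

```python
from typing import Callable, Dict, List, Tuple

# Each equation: (output variable, primitive name, input variables).
Equation = Tuple[str, str, List[str]]

# Hypothetical registry mapping primitive names to lowering rules.
lowering_rules: Dict[str, Callable[..., str]] = {
    "mul": lambda a, b: f"mul({a}, {b})",
    "sin": lambda a: f"sin({a})",
    "add_one": lambda a: f"custom[add_one]({a})",
}


def lower(equations: List[Equation], inputs: List[str], output: str) -> str:
    env: Dict[str, str] = {name: name for name in inputs}  # variable -> lowered value
    for out_var, prim, in_vars in equations:
        rule = lowering_rules.get(prim)
        if rule is None:
            raise NotImplementedError(f"no lowering rule for {prim!r}")
        env[out_var] = rule(*(env[v] for v in in_vars))
    return env[output]


program = [("b", "mul", ["a", "a"]), ("c", "add_one", ["b"])]
print(lower(program, ["a"], "c"))  # custom[add_one](mul(a, a))
```

A missing rule surfaces as an error at lowering time, which is why every new primitive needs a registered max_lowering_rule.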

To aid in the effort of coordination between JAX and MAX, juju exposes a function called juju.Primitive:

juju.Primitive

Primitive(
    name: str,
    max_lowering_rule: Callable,
    jax_abstract_evaluation_rule: Callable,
    multiple_results=True,
)

Construct a new JAX primitive, and register jax_abstract_evaluation_rule as the abstract evaluation rule for the primitive for JAX, and max_lowering_rule for juju's lowering interpreter.

Returns a function that invokes the primitive via JAX's Primitive.bind method.

Source code in src/juju/primitive.py
def Primitive(
    name: str,
    max_lowering_rule: Callable,
    jax_abstract_evaluation_rule: Callable,
    multiple_results=True,
):
    """
    Construct a new JAX primitive, and register `jax_abstract_evaluation_rule`
    as the abstract evaluation rule for the primitive for JAX, and `max_lowering_rule` for `juju`'s lowering interpreter.

    Returns a function that invokes the primitive via JAX's `Primitive.bind` method.
    """
    new_prim = JPrim(name)
    new_prim.def_abstract_eval(jax_abstract_evaluation_rule)
    max_rules.register(new_prim, max_lowering_rule)

    # JAX can't execute the code by itself!
    # We have to use MAX, so we raise an exception if JAX tries to evaluate the primitive.
    def _raise_impl(*args, **params):
        raise Exception(f"{name} is a MAX primitive, cannot be evaluated by JAX.")

    new_prim.def_impl(_raise_impl)

    def _invoke(*args, **params):
        return new_prim.bind(*args, **params)

    return _invoke
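Why does `_raise_impl` exist? A JAX primitive's `bind` routes to whichever interpreter is active: while tracing, JAX consults the abstract evaluation rule, while outside of any transformation it calls the `impl` rule. The toy model below (a hypothetical `ToyPrimitive` class with an explicit `tracing` flag, not JAX's actual machinery) shows why an `add_one` primitive can be traced but not evaluated eagerly.

```python
from typing import Any, Callable


class ToyPrimitive:
    def __init__(self, name: str, abstract_eval: Callable[..., Any]):
        self.name = name
        self.abstract_eval = abstract_eval

    def impl(self, *args: Any) -> Any:
        # Mirrors _raise_impl: only the MAX backend knows how to compute this.
        raise Exception(f"{self.name} is a MAX primitive, cannot be evaluated by JAX.")

    def bind(self, *args: Any, tracing: bool = False) -> Any:
        # Tracing only needs metadata; eager evaluation needs a real implementation.
        return self.abstract_eval(*args) if tracing else self.impl(*args)


# Abstract values here are (shape, dtype) tuples; add_one preserves both.
add_one = ToyPrimitive("add_one", lambda x: x)

print(add_one.bind(((10,), "float32"), tracing=True))  # ((10,), 'float32')
try:
    add_one.bind([1.0, 2.0])
except Exception as err:
    print(err)  # add_one is a MAX primitive, cannot be evaluated by JAX.
```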

For instance, to use our add_one kernel, one could write:

using_our_prim.py
from juju import Primitive, jit
from jax.core import ShapedArray
import jax.numpy as jnp
from max.graph import ops, TensorType

# Lowering rule to MAX, gets called by 
# juju's lowering interpreter.
def add_one_lowering(x, **params):
    return ops.custom(
        name="add_one", # needs to match your @compiler.register name
        values=[x],
        out_types=[TensorType(dtype=x.dtype, shape=x.tensor.shape)],
    )[0]

# Abstract evaluation rule for JAX, gets called
# by JAX when tracing a program to a Jaxpr.
def add_one_abstract(x, **params):
    return ShapedArray(x.shape, x.dtype)

# Register and coordinate everything, get a callable back.
add_one = Primitive(
    "add_one", # can be anything
    add_one_lowering, 
    add_one_abstract,
)

@jit
def jaxable_program(x):
    x = x * 2
    return add_one(x) # use the callable

# Execute your program using MAX.
print(jaxable_program(jnp.ones(10)).to_numpy())
[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]

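In short, juju.Primitive is convenient glue between JAX and MAX: a single call registers the abstract evaluation rule with JAX and the lowering rule with juju, and returns a function you can call inside JIT'd code.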