Compile-time computation

Introduction

Programming language designers have explored several avenues for exposing compile-time computation to users. Every time this ability crops up, the vocal opposition seems as numerous as the proponents who enjoy enforcing new invariants or applying custom optimizations.

Hirrolot (2022) provides an excellent discussion of the negatives of mainstream designs for compile-time computation.

On the other hand, languages that "do it right" are seemingly praised endlessly, Lisp being a canonical example.

I believe Lisp does it right because it doesn't attempt to coerce a builtin statics system (a type system) into the system for compile-time computation. Languages (at least, the ones which are mainstream now) typically maintain a clear separation between the type system and the runtime system - so when you coerce the type system into your system for compile-time computation, you've exposed a more restricted, less ergonomic language to advanced users1.

What Lisp gets wrong

While Lisp allows arbitrary code execution at compile time via macros, there are aspects of Lisp which leave much to be desired. For my purposes, the main ones which come to mind are the lack of a well-documented suite of numerical libraries, and the difficulty of implementing performance optimizations, including the ability to target accelerators. There's a separate issue related to the benefits of static typing.

The first is a social/library-author issue - and one which, I suspect, would eventually be solved with enough time and funding.

The second is more damning. While I don't doubt that there are efforts to address performance optimization in Lisp implementations, there seems to be compelling evidence that aggressive optimization requires changes to the language itself.

From my reading, most large-scale attempts to improve code generation and compilation for Lisp-like languages (Julia, for example - and I'm sure I'm missing other projects closer to Lisp) address the issue by adding static systems2.

However, adding these static systems becomes dangerous - especially when they provide ways to interact with the compiler which are ad hoc compared to the language itself. In Julia's case, generated functions are well-meaning - but early in Julia's development, a decision was made to allow users to create and return CodeInfo (one of Julia's AST/IR forms) from the generated-function staging phase. This allowed the creation of absolute chaos, as well as the void itself, in the form of increasingly speculative compiler packages. Most of Julia's existing problems with automatic differentiation implementations can be traced back to this point - exposing a pre-optimization staging capability which advanced users abused to construct and then market exotic "compiler-like" systems. I was one of these users, but I've since abandoned the dark for the light (sort of).

The story of generated functions in Julia is one example of how a seemingly well-meaning capability (allowing users to perform some computation and specialize their method code based on types) can have unintended consequences.

On the other hand, I think the direction Julia has stepped in is the correct one. Ultimately, the problem with Lisp (and duck/untyped dynamic languages) is the performance ceiling. Julia's solution is elegant - the language has dynamic semantics by default, but adds a static system which facilitates performance optimizations. I believe the problem with Julia is that the step is not big enough - it's just large enough to place Julia in the middle of the statics/dynamics biformity tar pit3. And once you take a step and reach v1.0.0, adding additional static systems is difficult - because most static systems are global by default (this may not always be true! In fact, I'm sure it isn't in certain dependently typed languages).

What can we learn from hybrid systems like JAX?

Let's consider a compiler-system counterpoint to Julia: JAX. Here, the comparison will not truly be strict, because Julia allows dynamism and flexible control flow, while JAX requires a constrained programming discipline in exchange for native compilation and optimization support. In JAX, if you want to compile, you can't opt out of the restrictions.

JAX is a unique system whose features I appreciate immensely. It strikes a compelling balance between practicality and programming language esoterics.

On the one hand, Python-based deep learning frameworks have almost always erred on the side of pragmatics - elegant PL ideas have taken root in other languages (where experimental AD systems are more likely to find researchers). JAX is a counterpoint to this trend: it offers a concise PL core based around program tracing and composable interpreter idioms, while also exposing a best-in-class AD system and an array programming model based on NumPy idioms.

The "meta-language" of JAX is Python. Pure, numerical Python programs can be interpreted by JAX, and staged out to a first order, strongly typed array-based IR (called a Jaxpr). By embedding a staging/compilation system into Python, JAX shares much in common with lightweight modular staging, which I discuss below4.

Python can also be used to write interpreters which operate on JAX's IR. These interpreters can also be staged back to the IR! This enables a beautiful way to transform numerical programs - just write an interpreter for the program, and then stage it out.
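As a hedged sketch of this idiom (assuming jax.core.eval_jaxpr, which re-evaluates a staged Jaxpr equation by equation; a real transformation would walk the equations and rewrite them):

```python
import jax
import jax.numpy as jnp
from jax import core

def f(x):
    return jnp.sin(x) * 2.0

closed = jax.make_jaxpr(f)(1.0)

# A trivial interpreter over the staged program: just re-evaluate it.
def interpret(x):
    return core.eval_jaxpr(closed.jaxpr, closed.consts, x)[0]

# The interpreter is itself traceable Python, so it stages back out:
print(jax.make_jaxpr(interpret)(1.0))
```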

Now, this idiom is not nearly as simple for object languages which include higher-order concepts. I think this is partially why this problem has not been solved for Julia yet.

If you wish to compile and execute your program natively, the cost of JAX is that you have to program in a restricted model5. It imposes restrictions on the allowed control flow, as well as on the value types which can appear in lowered (object-level) programs. These restrictions may feel unacceptable to someone accustomed to free use of iteration, branching control flow where the arity and types of values flowing into and out of each branch can differ, and unbounded recursion.
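A small sketch of the kind of restriction I mean: data-dependent Python branching fails under tracing, and the sanctioned alternative (jax.lax.cond) requires both branches to agree on the types and shapes of their results:

```python
import jax
import jax.numpy as jnp

@jax.jit
def bad(x):
    if x > 0:      # data-dependent Python branch: raises a tracer
        return x   # error (TracerBoolConversionError) when called
    return -x

@jax.jit
def good(x):
    # Both branches must return values of identical type and shape.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)

print(good(jnp.float32(-3.0)))  # 3.0
```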

If you boil these restrictions down and consider possible extensions (including fancy usage of Python to meta-program structured features, like JAX Pytree-compatible sum types) - the core requirement is that JAX has to infer the type and shape of every array before lowering to XLA. XLA needs this information to perform aggressive memory layout optimizations. XLA does expose a primitive for unbounded recursion - but it does not allow native dynamic allocation in recursive calls. Getting around this issue would require constructing a custom, device-compatible memory solution - reminiscent of a garbage collector or paging system - for JAX programs.
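To see the shape-inference requirement in action, here's a minimal sketch using jax.eval_shape, which computes output types and shapes abstractly, without running the program:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.concatenate([x, x])

# Shapes and dtypes are inferred before any lowering or execution.
abstract_in = jax.ShapeDtypeStruct((4,), jnp.float32)
print(jax.eval_shape(f, abstract_in))
# ShapeDtypeStruct(shape=(8,), dtype=float32)
```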

We can summarize our findings:

To me, JAX provides a big win in the interpreter/staging idiom. It seems easier to express program transformations by writing interpreters and staging them out than by writing a direct transformation. JAX also wins on engineering: the object language is restricted, so it's easier to ensure that AD works as intended - and the JAX team has put in the hours to make that a reality. There are also a bunch of miscellaneous engineering points which I really appreciate: the Pytree interface between Python datatypes and JAX is genius - and enables automatic struct-of-arrays representation.
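A minimal sketch of that last point (Particle is a hypothetical user type; the APIs assumed are jax.tree_util.register_pytree_node_class and jax.vmap):

```python
import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class Particle:
    def __init__(self, pos, vel):
        self.pos, self.vel = pos, vel

    def tree_flatten(self):
        # Children (traced leaves) and static auxiliary data.
        return (self.pos, self.vel), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

def step(p):
    return Particle(p.pos + p.vel, p.vel)

# vmap maps over the leaves, so a batch of Particles is one Particle
# whose fields are batched arrays: struct-of-arrays, for free.
batch = Particle(jnp.zeros((8, 3)), jnp.ones((8, 3)))
print(jax.vmap(step)(batch).pos.shape)  # (8, 3)
```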

However, again - we run into negatives! Metaprogramming with Python is roughly the same as programming with Python - you can still make stupid mistakes. Now, though, your stupid mistakes create gigantic stack traces which unravel across both the Python interpreter and JAX's internals.

Want to know whether your code is compat with different modes of AD? The only way to know for sure is to try to trace it with JAX. The same is true for checking whether your code is compat with JIT compilation. The way that JAX implements these checks is suitable ... for now. I can't help but wonder if a clever system based upon co-effect or effect reasoning could improve the ergonomics here. In general, I often wonder if typing a function based upon what capabilities it exposes would allow us to say things like "this function is reverse mode compat" or "this function is GPU compat".
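For instance, the only "check" available today looks something like the following (jit_compatible is a hypothetical helper of my own, not a JAX API):

```python
import jax
import jax.numpy as jnp

def jit_compatible(f, *example_args):
    """Hypothetical helper: attempt to trace f and report whether
    tracing succeeds. There is no static check - only trial and error."""
    try:
        jax.make_jaxpr(f)(*example_args)
        return True
    except Exception:  # e.g. a tracer error on Python-level branching
        return False

print(jit_compatible(lambda x: jnp.abs(x), 1.0))          # True
print(jit_compatible(lambda x: x if x > 0 else -x, 1.0))  # False
```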

In addition, JAX does not avoid the statics/dynamics biformity issue. If anything, it throws its hands up at it. There are clearly two languages here: there's the wild west world of Python, with all its dynamism - and then there's the static, marble world of JAX's IR. If we want compilation support, the Python world is forbidden from crossing over into the static world - which can prove frustrating if one wants to express something like dynamic allocation in recursive calls.

What about lightweight modular staging?

Does Zig comptime fit into all of this?

Confession: I don't yet know how to solve this problem

In general, I barely know what I'm talking about - each day I attempt to know a little more6 - but especially in this case: I do not know how to solve this problem for mainstream programming languages.

I have a sort of hunch which I'd like to follow in research. I think the reason most languages end up in the pit is that they don't step far enough towards dependent types. They take a step towards the middle of the pit, introduce some reflection / type system features (macros, polymorphism, generics, traits), and introduce a separation - then, later on, use cases arise which require static computation that doesn't easily fit into the interaction between the meta and object languages of the original design. Attempts to address these use cases create wart systems, increasing the complexity of the language as a whole.

The static meta-languages of most mainstream systems are not designed to support or interpret a full-featured language at compile time. Advanced users will ignore this, and then create razor-sharp hybrid compiler packages which can only be maintained by the language authors or the advanced users themselves, are not elegant (because you have to program in a compile-time language amalgamation7), and can't be checked with the same tools that are used to check runtime behavior.

What sort of language design prevents these questions and warts from arising? Again, my research hunch is that it's some sort of dependently typed language. But dependently typed languages for practical work are quite new and niche - I'm thinking of Lean and F*, even Idris, here. It seems like most dependently typed languages focus on the typechecker, and leave execution or binary emission to be solved by another compiler (e.g. emitting code which GHC will compile). The dream language in my mind is focused on the entire package - with the engineering ergonomics of a cleanly designed modern language like Rust (I just want cargo in every language I use).

Correlated with the youth of these efforts: my perspective is that dependently typed languages suffer from a lack of notable applications. The group of people who care about dependent types is small to begin with. Then, the group of people who care about verification using Idris (for example) is even smaller. So you have systems like Coq, Agda, Idris, Lean, F*, etc. - whose major user groups are barely larger than the developers of the systems themselves. They're all co-developing a language system along with a major verification project (or they've co-developed along a few verification projects). Or they're doing really cool work, but have to wiggle around the economics of developing these projects by relying on e.g. the hype over crypto.

I'm not equipped to say whether these efforts are ultimately useful - but I will say that if a dependently typed language researcher exposed a language with claims like "verified AD transforms, and staging optimizations written in library space that produce code which rivals Lantern, and it hooks into MLIR" - they'd have no problem finding funding, or users, or excited scientific computing practitioners.

Coincidentally, there are research languages with dependent types (András Kovács 2022, eashanhatti 2022) which attempt to provide scaffolding for compile-time computation while maintaining full dependent types (and consistent languages) across stages. If you poke around the r/ProgrammingLanguages Discord, you'll find that research on this frontier is very much alive! The above two researchers (András and Eashan) are two notables from my chats.

My suspicion is that researchers in programming languages of this form aren't likely to have encountered the use cases in machine learning which scientific computing practitioners take for granted. On the other hand, these use cases are compelling!

My research hunch is that there's an appropriately sized hole for a dependently typed language focused on numerical program transformations - one which follows in the JAX tradition, but alleviates some of JAX's pain points. I thought this language was going to be Dex, but it seems like Dex is missing JAX's critical composable staged-interpreter idioms (which make writing transformations so easy in JAX).


1

Advanced users are paradoxically the most likely to write opaque, difficult-to-decipher compile-time code. They are also the most likely to just shut up, not complain about it, and get work done with it.

2

In Julia, an abstract-interpretation-based type system which facilitates inlining and method specialization.

3

Once you're in the tar pit, advanced users will complain because they can't do the compiler-like things they want to do, and basic users will complain because advanced users are creating buggy packages.

4

It's actually quite interesting to read Tiark Rompf's early work on LMS (circa 2012) - and then consider the capabilities which JAX exposes from 2018 to present. Tiark is eerily prescient about "the right way" to express systems which rely on program transformations.

5

Roughly: programs which can be lowered to XLA. XLA is a combination of IR dialects (using MLIR) and a compiler implementing high-performance optimizations for numerical kernels.

6

I'm a hybrid machine learning / programming languages researcher - I care about speed, I care about parallelism, I care about AD, and I care about program transformations to support probabilistic programming.

7

@TypingTechnicalInterview is a lovely example.