Running a JAX Program from Dart Using C++ FFI
Learn how to build machine learning apps using JAX, Python, and Dart with efficient XLA compilation and FFI integration.
1. Building machine learning enabled applications
When choosing the programming language and main frameworks for building an application, one usually weighs factors such as performance, ease of development, and the ability to deploy across different operating systems and devices. The most popular and well-established options tend to have most of the desired properties, or offer attractive trade-offs between them. For example, writing your application purely in C++ might make it run faster, but it will slow down development at the same time. Dart with Flutter, on the other hand, is slower, but it abstracts away much of the memory management and has great support for building local applications across different operating systems.
One thing that most well-established frameworks are missing is robust native support for machine learning (ML) downstream tasks. This is not particularly surprising, since they were all designed before the recent AI surge, but it poses a difficult question for the ecosystem: how do we integrate the ML stack into an application efficiently?
There are currently some ways of achieving this, although none of them is without drawbacks. If the application is deployed locally, one of the most popular tools for running an ML model within it is ONNX Runtime. Thanks to its universal serialization-deserialization scheme, it enables ML Engineers to export models written in popular ML frameworks to a format that can then be loaded into the runtime and called from the application under development. However, one of the downsides of ONNX Runtime is that the produced code representation doesn't get optimized to the extent that it could be, especially when running on a CPU. Additionally, what if we'd also like to include more general algorithms written by our Data Scientists that do not conform to the standard deep learning model archetype? That is quite cumbersome to achieve with this tool.
In this blog we’ll be looking at the JAX library for Python, which can help us resolve those problems. It enables the team to write heavily optimized code in Python that can express general algorithms, essentially anything that can be written in pure Python and NumPy.
2. What is JAX?
JAX is a Python library, very similar to NumPy, that has been getting more and more popular amongst machine learning practitioners over the past couple of years. Developed by Google, it enables programmers to write platform-agnostic code that can be run on CPUs, GPUs, and TPUs without any additional effort. Additionally, JAX has built-in autograd, JIT compilation, and parallelization support to enable efficient implementation of any SOTA (state-of-the-art) machine learning models and algorithms. JAX code can be executed particularly fast thanks to the XLA compiler, which is invoked after tracing the user-written program and translating it to StableHLO.
Ultimately, JAX can be seen as a way to give the programmer a degree of control over how exactly the model is run that other popular (deep) machine learning frameworks like PyTorch don’t offer, by being a lot more “low-level”. Combined with being far more accessible to the average ML scientist (after all, it’s “just” supercharged NumPy), this has made it very easy to speed up training and inference of machine learning models using JAX.
To illustrate this, below are two functions, using NumPy and JAX respectively, that perform a simple value assignment to a vector.
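The original snippets are not reproduced here, but a minimal sketch of what such a pair of functions could look like (the function names and the use of jax.jit are assumptions) is:

```python
import numpy as np
import jax
import jax.numpy as jnp

def set_value_np(x: np.ndarray, idx: int, value: float) -> np.ndarray:
    # NumPy arrays are mutable, so we can write in place.
    x[idx] = value
    return x

@jax.jit
def set_value_jax(x: jnp.ndarray, idx: int, value: float) -> jnp.ndarray:
    # JAX arrays are immutable; .at[...].set(...) returns an updated array,
    # and under jit XLA can turn this into an in-place update.
    return x.at[idx].set(value)
```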
Using a Google Colab environment we can evaluate the execution speed of these two functions like so:
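For instance, a timing harness along these lines could be used, reusing the sketched functions from above (the array size and the warm-up call that excludes compilation time are assumptions):

```python
import numpy as np
import jax.numpy as jnp

x_np = np.zeros(1_000_000, dtype=np.float32)
x_jax = jnp.zeros(1_000_000, dtype=jnp.float32)

# Run the JAX version once first so that compilation time is not measured.
set_value_jax(x_jax, 0, 1.0).block_until_ready()

%timeit set_value_np(x_np, 0, 1.0)
%timeit set_value_jax(x_jax, 0, 1.0).block_until_ready()
```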
Note that we will use .block_until_ready() when running the code on a GPU/TPU to make sure we wait for the results from the accelerator, for a fair speed comparison.
On different platforms we get:
It is important to note that no changes to the code needed to be made for it to run on these different platforms. This is not the case for common deep learning frameworks, which require the programmer to manually move the model and data onto a device.
We can observe that even for a single simple operation, JAX turns out to be faster when run on a CPU and a GPU. Notably, the average execution time on the GPU and TPU includes moving data to and from the accelerator, which contributes substantially to the overall time. In non-trivial settings, where the executed program is longer, the speed-up can be even more noticeable. If we look at the TPU time, we see that the over-the-wire cost dwarfs everything else; the takeaway is that a TPU only becomes useful when the model is sufficiently large in the first place.
3. Incorporating JAX code into prod
We’ve established that using JAX can significantly speed up our training and inference time. How could we utilize this in a real-world application? One way would be to create a containerized microservice running a Python interpreter that can execute our code. That’s all well and good, and certainly a valid option for applications that use cloud-hosted compute.
But what if we needed to include and ship our algorithms in a locally run application? It is not common practice to deploy a Python interpreter within a product. To this end, we can make use of the fact that JAX compiles down to XLA, a backend that is also available from other programming languages. The plan is to write our code comfortably in Python using JAX, and then deploy and run it from within a compiled C++ library that can be easily incorporated into any project.
4. How to generate an HLO proto for XLA
The first thing we’ll need to do (after writing our algorithm in Python using JAX, naturally) is generate a specification of our JAX program that can be used as an input to an XLA compiler. Importantly, we should make sure that the top-level JAX function we’ll be calling can be transformed with jax.jit or any other transformation that forces compilation, such as jax.pmap.
This is the example code we’ll be using for this demonstration.
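The exact code is not reproduced here; as a stand-in, any jit-compatible function with fixed input shapes will do, for example something along these lines (this particular fn is an assumption):

```python
# jax_example/prog.py
import jax.numpy as jnp

def fn(x, y, z):
    # Placeholder computation; any jit-compatible (traceable) function
    # with fixed input shapes can be exported the same way.
    return jnp.dot(x, y) / z
```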
Make sure that the code is accessible as a package from Python. For this I have put it in a file called prog.py, created a simple library called jax_example for myself and installed it, such that I can now import it with `from jax_example.prog import fn`. Information on how to do this can be found online, as we will not be going over creating Python libraries in this blog.
Normally when we’re just running everything from a Python script what happens is that:
Whenever the compilation is triggered (either manually ahead-of-time by calling .lower() on our transformed code, or just-in-time by simply running data through it, i.e. calling the function) the function is traced and an optimized specification is produced. Note that at this stage the shapes and types of the function arguments get baked in. A sketch of these stages in plain Python follows the list.
The specification is used to compile a program with XLA for a given platform. Whenever we rerun our code later, this compiled program gets reused.
The compiled code gets called and we get the results back. Perhaps slightly confusingly, print statements in our original code are not persisted in the trace, so they are not present in the compiled version of the code. This means we won’t see any prints from our code after the initial one-time tracing.
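Made explicit, those stages look roughly like this in plain Python (the input shapes are assumptions):

```python
import jax
import jax.numpy as jnp
from jax_example.prog import fn

x = jnp.ones((2, 2))
y = jnp.ones((2, 2))
z = jnp.ones((2, 2))

lowered = jax.jit(fn).lower(x, y, z)   # 1. trace: shapes and dtypes get baked in
compiled = lowered.compile()           # 2. compile with XLA for the current backend
result = compiled(x, y, z)             # 3. run the compiled executable
```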
What we’re doing here will split this process across the different parts of our flow:
Staging out the specification is going to happen once, ahead of time, during development. The resulting trace will become an asset that gets shipped with the application.
The compilation happens on the target device just-in-time, within the lifecycle of the application process itself. This could theoretically be done ahead of time if we knew the exact specification of the target machine, which is not realistic in this scenario. This step will be handled by the C++ library.
The compiled code, running within the C++ library, gets called through FFI from Dart.
To achieve the first goal, namely creating the specification, we can use a tool script provided by JAX itself. We can clone the official JAX repository and inspect jax/tools/jax_to_ir.py. We will need to run:
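The command below is a reconstruction based on the tool’s documented flags; the input shapes and output paths are assumptions chosen to match the rest of this walkthrough:

```bash
python jax/tools/jax_to_ir.py \
  --fn jax_example.prog.fn \
  --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2]"), ("z", "f32[2,2]")]' \
  --ir_format HLO \
  --ir_human_dest /tmp/fn_hlo.txt \
  --ir_dest /tmp/fn_hlo.pb
```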
This will create the code specification files. For simplicity’s sake we’ve put them in a globally accessible path. Ideally those files would be placed in the application’s assets folder, as we’ll need to ship them with the app. We can inspect the generated fn_hlo.txt file to see the human-readable HLO of our function.
5. Creating a C++ dylib that uses TensorFlow XLA backend
We will now create a C++ dylib that will use the files generated earlier and compile them with XLA. For this we can continue to utilize the JAX repository we’ve cloned earlier. We will want to head over to jax/examples/jax_cpp and make a few modifications.
We will want to add a header file called `main.h`:
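The header itself is not reproduced here. A minimal sketch of what it might declare (the exported name `bar` and its int-in/int-out signature are assumptions, chosen to match the simple integer interface used from Dart later on) is:

```cpp
// examples/jax_cpp/main.h
#ifndef JAX_EXAMPLES_JAX_CPP_MAIN_H_
#define JAX_EXAMPLES_JAX_CPP_MAIN_H_

// C linkage so the symbol can be looked up by name over FFI.
extern "C" {
// Hypothetical entry point: loads the shipped HLO files, compiles them
// with XLA, runs the program and prints the result.
int bar(int x);
}

#endif  // JAX_EXAMPLES_JAX_CPP_MAIN_H_
```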
Next we will want to modify the BUILD file such that it builds a dynamic library for us:
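The exact BUILD changes are not shown here; one plausible shape for them is the sketch below. The rule attributes are assumptions, and the dependency list should stay as in the original example target:

```python
# examples/jax_cpp/BUILD (sketch)
cc_binary(
    name = "jax",
    srcs = ["main.cc", "main.h"],
    # Produce a shared library instead of an executable. Depending on the
    # Bazel version you may need to name the target "libjax.dylib" (or use
    # cc_shared_library) to get the expected output file name.
    linkshared = True,
    deps = [
        # ... keep the XLA/PjRt dependencies of the original example here ...
    ],
)
```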
Lastly, we’ll want to add these two imports to our `main.cc`:
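The two includes are not reproduced; as a guess (these exact lines are assumptions, not the original ones), they would plausibly be the new header plus whatever the printing code needs:

```cpp
// Additions at the top of examples/jax_cpp/main.cc (hypothetical)
#include <iostream>

#include "examples/jax_cpp/main.h"
```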
We’ll need to make sure that Bazel is installed on our system. It’s a very convenient build manager that handles fetching dependencies and recompiling only the sources affected by code changes, without rebuilding the entire solution. It can be installed by following the official guides.
After that is done, from the root of the JAX repository we can run:
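The command is not reproduced in full; given the target name referenced below, it would be along the lines of:

```bash
bazel build //examples/jax_cpp:jax
```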
This will build our target called `jax` and produce artifacts in jax/bazel-bin/examples/jax_cpp. You should see a file called libjax.dylib in that directory.
6. Using Dart FFI to communicate with the C++ code
Now that we have the C++ library ready we can load it in and use it from any environment we want. In this example, however, we will be using Dart’s FFI.
Inside our project we will create a file called `jax.dart` and copy the code just below.
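The original file is not reproduced here. A minimal stand-in using dart:ffi directly could look like the sketch below; note that the post loads the library through a helper package instead, and that the `bar` symbol with an int-in/int-out signature is an assumption:

```dart
// jax.dart - minimal sketch using dart:ffi directly.
import 'dart:ffi';

// Native and Dart signatures of the exported C function (assumed int -> int).
typedef _BarNative = Int32 Function(Int32);
typedef _BarDart = int Function(int);

class JAX {
  JAX() : _lib = DynamicLibrary.open('libjax.dylib');

  final DynamicLibrary _lib;

  /// Calls the native `bar` function, which compiles and runs the JAX program.
  int bar(int x) => _lib.lookupFunction<_BarNative, _BarDart>('bar')(x);
}
```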
The DynamicLibrary class comes from dart:ffi and will allow us to communicate with the C++ code. Importantly, the loadDynamicLibrary method will look for the library on the root path of your project by default, so we simply copy the compiled library into the same directory that holds our `pubspec.yaml`. The `.lookupFunction` method requires us to define the input and output types that will be used to communicate between Dart and C++. This is a simple case where the input and output are just integers.
Since we’re using the dynamic_libraries package to load in our dylib, we’ll also need to include it in our `pubspec.yaml`:
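For example (the version constraint is deliberately left open here):

```yaml
dependencies:
  dynamic_libraries: any  # pin to the version you actually use
```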
To test out our code we can create a unit test that calls the `.bar` method of the JAX class. Since we’re not actually returning anything to Dart at this point and only printing the results on the native side, the computed values should show up in the console when the test runs.
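A minimal test along these lines might look like the sketch below (the package import path is a placeholder):

```dart
import 'package:test/test.dart';

// Placeholder import path; point this at wherever jax.dart lives in your app.
import 'package:your_app/jax.dart';

void main() {
  test('runs the compiled JAX program through the C++ library', () {
    final jax = JAX();
    // The native side prints the computed values; here we only check
    // that the FFI call completes without throwing.
    jax.bar(0);
  });
}
```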
7. Next Steps
We’ve demonstrated that arbitrary (jittable) JAX code can be executed in C++ and called from Dart. What’s left is to write the more complicated algorithms in JAX that we want optimized and run on an accelerator. Apart from the obvious machine learning applications, other examples include, but are not limited to, search (e.g. A*), combinatorial optimization algorithms (scheduling, partitioning), or even classical image processing methods like finding edges with a Canny filter! An additional benefit of using JAX is that those algorithms can be written by Data Scientists and Machine Learning Engineers directly in Python, without the need to involve an experienced C/C++ programmer to rewrite them for production applications. This can save a lot of development time and headaches, since resolving C++ bugs and memory leaks can be very troublesome.