Running a deep learning workload with JAX on multinode multi-GPU clusters on OCI

January 11, 2023 | 6 minute read
Sanjay Basu PhD
Senior Director - Gen AI/GPU Cloud Engineering

JAX is a rapidly growing Python library for high-performance numerical computing and machine learning (ML) research. With applications in drug discovery, physics ML, reinforcement learning, and neural graphics, JAX has seen rapid adoption in the past few years. JAX offers numerous benefits for developers and researchers, including a familiar NumPy API and built-in automatic differentiation, based on an updated version of Autograd, for easy differentiation and optimization. JAX also supports distributed processing across multinode and multi-GPU systems in a few lines of code, with accelerated performance through XLA-optimized kernels on NVIDIA GPUs.
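As a minimal illustration of the NumPy-style API and built-in differentiation, the sketch below defines a toy least-squares loss with `jax.numpy`, differentiates it with `jax.grad`, and compiles the gradient with `jax.jit`. The loss function and data are made up for the example, and it runs on CPU as well as on GPU.

```python
import jax
import jax.numpy as jnp


# A toy least-squares loss written with the NumPy-style jax.numpy API.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# jax.grad builds a function that returns d(loss)/dw (gradient w.r.t. the
# first argument); jax.jit compiles it with XLA for the default backend.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)
print(g)  # gradient of the mean squared error at w = 0
```

The same pattern (write NumPy-style code, then wrap it with `jax.grad` and `jax.jit`) carries over unchanged when the arrays live on A100 GPUs.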

To help more customers take advantage of the power of JAX, Oracle and NVIDIA are working together to enable easy setup and use of JAX on multinode clusters in OCI. By combining Oracle Cloud Infrastructure (OCI), NVIDIA GPUs, and RDMA networking, you can take advantage of the scalability and performance benefits of JAX in a cloud environment.

In this guide, we walk through how to set up a multinode high-performance computing (HPC) cluster powered by NVIDIA A100 Tensor Core GPUs on OCI with built-in support for CUDA and SLURM. We also go over how to install JAX into that environment and get started with multinode JAX code.

Import a GPU + OFED image into OCI

First, import a machine image into your compartment that includes support for both GPU computation with CUDA and RDMA over Ethernet (RoCE) using OFED.

  1. In the Oracle Cloud Console, go to the menu and, in the Compute section, select Instances.

  2. In the side panel, select “Custom Images” and then “Import image.”

  3. Select “Import from object storage URL” and “OCI.”

  4. Use the following URL to import the image:

The image is accessible on request with current OCI subscriptions.

Create a multinode cluster with NVIDIA A100 GPUs in OCI

  1. Log in to the Oracle Cloud Console.

  2. In the Marketplace section of the menu, select “All applications.”

  3. Search for and select “HPC cluster.”

  4. Choose the latest version and your OCI compartment. Then click Launch stack.

  5. Optionally, configure your cluster name, description, and tags. Then, click Next.

  6. For your cluster configuration, use the following settings:

    • Upload your SSH public key. If you don’t have SSH keys, you can generate a public/private key pair by running ssh-keygen and following the prompts. For more information, such as generating keys on Windows, see Generate SSH Keys.

    • For the head node, select an availability domain, and leave the shape as default. Increase the disk size to 500 GB.

    • For the compute nodes, select an availability domain, choose the “BM.GPU4.8” shape (8x NVIDIA A100 40GB Tensor Core GPUs), uncheck hyperthreading, increase the boot disk to 500 GB, set “Initial cluster size” to 4 for four nodes, uncheck “use marketplace image,” and select the image you imported previously.

A screenshot of the Headnode Options screen.

A screenshot of the Compute Node Options screen with relevant sections outlined in red.

Leave the other options at their defaults and click Next. On the review page, check “Run apply” and then click Create.

  1. In the menu, go to “Resource Manager” and then “Stacks.”

  2. Select the “Stack detail” and then “Job details.”

  3. Wait for the apply job to finish and succeed. This process can take 15–45 minutes.

A screenshot of the Job Details page with the status outlined in red.

SSH into the head node and install Python and JAX

  1. In the menu, go to Compute and then Instances.

  2. Locate the public IP address of the head node.

  3. Use the IP address to SSH into the head node with the following command:

    ssh opc@[IP_ADDRESS]
  4. Install Miniconda and Python by running the following command:

    wget -O ~/ && bash ~/ -b -p ${HOME}/miniconda && ~/miniconda/bin/conda init bash
  5. Log out and log back in. Then install JAX with the following command:

    pip3 install --upgrade "jax[cuda]" -f
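Before submitting multinode jobs, it can help to confirm that JAX imports cleanly and sees the expected devices. The following is a minimal sketch; the helper name `describe_devices` is hypothetical. On a GPU compute node (for example, run through `srun -N 1`) the backend is typically reported as `gpu`, while on a machine without GPUs JAX falls back to the CPU backend.

```python
import jax


def describe_devices():
    """Hypothetical helper: summarize the devices JAX can see in this process."""
    devices = jax.devices()
    return {
        # Typically "gpu" on an A100 node and "cpu" on a node without GPUs.
        "backend": jax.default_backend(),
        "device_count": jax.device_count(),
        "local_device_count": jax.local_device_count(),
        # For example, the A100 model string on a GPU node.
        "device_kinds": sorted({d.device_kind for d in devices}),
    }


if __name__ == "__main__":
    print(describe_devices())
```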

Run your first JAX code with SLURM

OCI HPC clusters come preinstalled with SLURM, which simplifies running multiprocess jobs across multiple nodes. To run JAX with SLURM, create a file ~/ with the following content:

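import jax
import jax.numpy as jnp

# Join this process to the other processes in the job. Under SLURM, JAX
# detects the coordinator address, process count, and process ID from the
# environment, so no arguments are needed.
jax.distributed.initialize()

print(f"Total devices: {jax.device_count()}, "
      f"Devices per task: {jax.local_device_count()}")

x = jnp.ones(jax.local_device_count())

# Computes a reduction (sum) across all devices of x
# and broadcasts the result, in y, to all devices.
# If x=[1] on all devices and we have 32 devices,
# the result is y=[32] on all devices.
y = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(x)
print(y)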
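In the multinode run, each process contributes its local GPUs to one global `pmap` axis, so `psum` adds one value per GPU across all 32 devices. The same reduction can be tried on a single machine before using the cluster. The sketch below assumes a CPU-only test: it uses XLA's `--xla_force_host_platform_device_count` flag to simulate 8 devices in one process, so each device should end up holding 8 rather than 32.

```python
import os

# Ask XLA to expose 8 virtual CPU devices. This must be set before JAX is
# imported, and it only affects the CPU backend.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

cpu_devices = jax.devices("cpu")
n = len(cpu_devices)

# One element per device, mirroring x = jnp.ones(local_device_count()).
x = jnp.ones(n)

# psum over the named axis "i" sums x across all mapped devices and gives
# every device a copy of the result.
y = jax.pmap(lambda v: jax.lax.psum(v, "i"), axis_name="i", devices=cpu_devices)(x)
print(n, y)  # each of the n entries of y equals n
```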


From the head node (the login node), run the following command:

(base) [opc@helpful-python-bastion ~]$ srun -N 4 -n 32 --tasks-per-node=8 --gpus-per-node=8 bash -c "~/miniconda/bin/python3 ~/"

You should receive output similar to the following:

Total devices: 32, Devices per task: 1



Total devices: 32, Devices per task: 1
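The srun options request 4 nodes with 8 tasks per node, that is, 32 processes with one GPU each, which is consistent with the “Devices per task: 1” lines. JAX’s SLURM auto-detection relies on per-task environment variables such as SLURM_NTASKS, SLURM_PROCID, and SLURM_LOCALID. The plain-Python sketch below only illustrates that bookkeeping; the helper `slurm_task_layout` is hypothetical and is not JAX’s own implementation.

```python
import os


def slurm_task_layout(env=os.environ, gpus_per_node=8):
    """Hypothetical helper: summarize where this SLURM task sits in the job.

    Assumes one SLURM task per GPU, as requested by the srun command above.
    """
    num_processes = int(env["SLURM_NTASKS"])  # total tasks in the job step
    process_id = int(env["SLURM_PROCID"])     # global rank, 0..num_processes-1
    node_index = int(env["SLURM_NODEID"])     # relative index of this task's node
    local_rank = int(env["SLURM_LOCALID"])    # rank of this task within its node
    if not 0 <= local_rank < gpus_per_node:
        raise ValueError(f"local rank {local_rank} has no matching GPU on this node")
    return {
        "num_processes": num_processes,
        "process_id": process_id,
        "node_index": node_index,
        "local_gpu": local_rank,  # one GPU per task, selected by local rank
    }


# Example: the last task on the last of four 8-GPU nodes (-N 4 -n 32).
example_env = {
    "SLURM_NTASKS": "32",
    "SLURM_PROCID": "31",
    "SLURM_NODEID": "3",
    "SLURM_LOCALID": "7",
}
print(slurm_task_layout(example_env))
```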


Congratulations! You have successfully run a JAX job on four nodes with 32 NVIDIA A100 GPUs and 32 processes in total. With this guide, you’re well-equipped to get started on your multinode JAX journey. For more information on how JAX multihost processing works, visit the JAX documentation. For more information and the latest updates for JAX GPU workloads, sign up for NVIDIA’s JAX Early Access program, where you can get an early look at how JAX can run faster and at larger scales on GPUs.
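Every one of the 32 processes runs the same script, which is why each one prints its own line. A common multihost idiom is to let only one process report results. The minimal sketch below uses `jax.process_index()` and `jax.process_count()`; the `log` helper is hypothetical. In a multinode job it would be called after `jax.distributed.initialize()`, and run as a single process it simply behaves as process 0 of 1.

```python
import jax


def log(message):
    """Hypothetical helper: print only from process 0 to avoid duplicate output."""
    if jax.process_index() == 0:
        print(f"[process 0 of {jax.process_count()}] {message}")
        return True
    return False


printed = log(f"Total devices: {jax.device_count()}")
```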

Getting started

The addition of the NVIDIA A100 80GB GPU complements the existing lineup of available NVIDIA GPUs on OCI, including the just-announced A10 GPU, opening a new era of accelerated computing for startups, enterprises, and governments around the world on OCI. Not all workloads are the same, and some might require customization to work optimally on the latest generation of GPU hardware. OCI offers technical support to get your workload up and running, so talk to your sales contact for more information.

For more information about BM.GPU.GM4.8 instances including specifications, pricing, and availability, see our documentation. To learn more about Oracle Cloud Infrastructure’s capabilities, explore the following resources:


Leopold Cambier, software engineer

Leopold Cambier received his Ph.D. in Computational and Mathematical Engineering from Stanford University in 2021. During his Ph.D., he focused on fast solvers for very large sparse linear systems, from both a theoretical and parallel-computing perspective. He also interned at NVIDIA in 2016 and 2017, when he worked on cuDNN. Since joining NVIDIA full-time in January 2021, Leopold has been working on cuFFT and cuFFTMp in the CUDA Math Libraries team, and more recently on JAX.

Neal Vaidya, technical marketing engineer

Neal Vaidya is a technical marketing engineer for deep learning software at NVIDIA. He is responsible for developing and presenting developer-focused content on deep learning frameworks and inference solutions. He holds a bachelor’s degree in statistics from Duke University.

Sanjay Basu PhD

Senior Director - Gen AI/GPU Cloud Engineering

Sanjay focuses on advanced services like Generative AI, Machine Learning, GPU Engineering, Blockchain, Microservices, Industrial IoT, and 5G core, along with Cloud Security and Compliance. He holds two master’s degrees, in Computer Science and Systems Design. His PhD was in Organizational Behaviour and Applied Neuroscience. He is currently pursuing his second PhD, in AI, and his research focuses on Retentive Networks.
