DrJAX: Scalable and Differentiable MapReduce Primitives in JAX
Abstract
We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations. DrJAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. DrJAX embeds building blocks for MapReduce computations as primitives in JAX. This enables three key benefits. First, DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms. Second, DrJAX computations are fully differentiable. Last, DrJAX computations can be interpreted for execution on existing batch-processing compute systems, including traditional MapReduce systems like Apache Beam and cross-device compute systems like those powering federated learning applications. We show that DrJAX provides an easily programmable, performant, and scalable framework for parallelized algorithm development. DrJAX is available at https://github.com/google-research/google-research/tree/master/drjax.
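The MapReduce pattern the abstract describes (broadcast a value to every partition, map a function over the partitions, reduce the results, and differentiate through the whole computation) can be sketched in plain JAX. This is an illustrative sketch only and does not use DrJAX's API: the helper names `broadcast`, `local_loss`, and `global_loss` and the toy data are assumptions made for this example, and `jax.vmap` stands in for the library's map primitive.

```python
import jax
import jax.numpy as jnp


def broadcast(x, n):
    # Hypothetical helper: replicate x once per partition (leading axis of size n).
    return jnp.broadcast_to(x, (n,) + x.shape)


def local_loss(w, data):
    # Per-partition work: mean squared distance between the data and the scalar w.
    return jnp.mean((data - w) ** 2)


def global_loss(w, partitioned_data):
    n = partitioned_data.shape[0]
    ws = broadcast(w, n)                                  # broadcast
    losses = jax.vmap(local_loss)(ws, partitioned_data)   # map over partitions
    return jnp.mean(losses)                               # reduce


# Toy data: 3 partitions holding 4 values each.
partitioned_data = jnp.arange(12.0).reshape(3, 4)
w = jnp.array(0.0)

# Because every step is an ordinary JAX operation, the gradient flows
# through the broadcast, the map, and the reduction.
loss = global_loss(w, partitioned_data)
grad_w = jax.grad(global_loss)(w, partitioned_data)
```

Wrapping `global_loss` in `jax.jit` would compile the same broadcast, map, and reduce steps through XLA, which is the general mechanism behind the abstract's first claimed benefit.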
- Publication: arXiv e-prints
- Pub Date: March 2024
- DOI: 10.48550/arXiv.2403.07128
- arXiv: arXiv:2403.07128
- Bibcode: 2024arXiv240307128R
- Keywords: Computer Science - Distributed, Parallel, and Cluster Computing; Computer Science - Machine Learning