The chronicles of training FBPINNs
Right now I am training FBPINNs on IITK's multicore, multiprocessor CPU and GPU clusters, and I am frustrated by the fact that JAX does not automatically parallelize operations across CPU cores unless you explicitly use parallelism constructs like pmap or shard_map.
On top of that, JAX's CPU backend is not tuned for maximum CPU utilization. Together, these two factors mean that while the CPU backend is faster than NumPy, it is slower than single-core C++ and significantly slower than parallel C++, which makes it currently unsuitable for high-performance use cases.
There is some intra-op parallelism: certain operations (such as BLAS/LAPACK matrix operations) can use multiple threads internally. However, not all operations are internally multi-threaded; the FFT operations, for example, run single-threaded. JAX also does not do inter-op parallelism (running different operations in parallel across cores) automatically.
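Out of curiosity, intra-op threading can also be pinned down, which makes single-core comparisons against C++ fairer. The snippet below is a minimal sketch based on flags I have seen in the JAX FAQ and issue tracker (--xla_cpu_multi_thread_eigen and intra_op_parallelism_threads); treat the exact flag names as something to verify against your JAX version.
import os
# Assumed XLA flags (check against your JAX version): disable Eigen's internal
# multi-threading and pin intra-op parallelism to a single thread.
# Must be set before JAX is imported.
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
)
import jax
import jax.numpy as jnp
# A matmul that would normally use multiple BLAS/Eigen threads now runs on
# one thread.
x = jnp.ones((2048, 2048))
print((x @ x).block_until_ready().shape)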
JAX relies on underlying BLAS libraries (such as OpenBLAS, MKL, or Eigen) for linear algebra operations. These libraries have their own threading mechanisms, which can be controlled via environment variables:
# These control BLAS threading
export OMP_NUM_THREADS=32
export MKL_NUM_THREADS=32
export OPENBLAS_NUM_THREADS=32
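A related gotcha (my own observation, not from the docs): if you set these from Python instead of the shell, do it before importing numpy or jax, since the BLAS library reads them when it is first loaded. A quick sketch:
import os
# Set BLAS thread counts before any BLAS-linked library is imported
os.environ["OMP_NUM_THREADS"] = "32"
os.environ["MKL_NUM_THREADS"] = "32"
os.environ["OPENBLAS_NUM_THREADS"] = "32"
import numpy as np
import jax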
Explicit parallelism via pmap:
import os
# Must be set before importing JAX so XLA exposes the CPU as 32 devices
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=32"
import jax
import jax.numpy as jnp
from jax import pmap
# pmap runs the function once per device; the size of the input's leading
# axis must not exceed the number of devices (here it is exactly 32)
@pmap
def parallel_computation(x):
    # placeholder for the real per-core computation
    return jnp.sin(x) * x
batched = jnp.arange(32.0).reshape(32, 1)  # one slice per virtual CPU device
print(parallel_computation(batched).shape)  # (32, 1)
The pmap CPU multithreading issue explains that you first need to split the CPU into multiple apparent devices using the --xla_force_host_platform_device_count flag, exactly as in the snippet above.
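Since I mentioned shard_map above as the other explicit parallelism construct, here is a minimal sketch of the same idea over the virtual CPU devices. This is illustrative only: the mesh size, the axis name "i", and the toy computation are my own placeholders, and the import path of shard_map may differ between JAX versions.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
print(jax.device_count())  # should report 8 virtual CPU devices
# 1-D mesh over the virtual CPU devices, with a single axis named "i"
mesh = Mesh(np.array(jax.devices()), axis_names=("i",))
# Each device receives one shard of the leading axis of x and works on it locally
@functools.partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
def f(x):
    return jnp.sin(x) * x  # placeholder per-shard computation
x = jnp.arange(32.0)   # 32 elements split into 8 shards of 4
print(f(x).shape)      # (32,)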
I will have to dig deeper to figure out how to improve CPU/GPU parallelism with JAX. Still, it's fun to train such large ML models on many GPUs.
I am finding ML Infrastructure and overall Software Engineering much harder than the actual ML itself. What a joke!