Measuring the effects of data parallelism on neural network training. A great example of careful science in machine learning.
Matthew Johnson
2,009 posts
Joined July 2010
- JAX+NVIDIA at #GTC22! w/ @mjsMLP nvidia.com/gtc/session-ca… New to JAX? This talk gets you up to speed. Already a JAXpert? Check out the new parallelization features at the end of the demo. And hear about how NVIDIA is making JAX faster and more scalable than ever on GPUs!
- Replying to @cHHillee: They continued, "If I wanted hard-to-install packages and general amateurism, I'd use JAX."
- Replying to @soumithchintala and @JeffDean: Yeah! That was one of the first things people (I think @sharadvikram) tried with the JAX codebase, as in Fig. 2 of the tech report: storage.googleapis.com/deepmind-media…
- Replying to @fchollet: I enjoyed your tweet and the discussion! It shone with good faith and curiosity. I am eager to see different perspectives, especially yours. (And let's not rule out the possibility that JAX is deficient here... JAX has many weaknesses. We love to hear about them and to try to improve!)
- Replying to @ankurhandos: In JAX you can set this to be a warning or an error: jax.readthedocs.io/en/latest/rank… It can be done with a context manager too. We weren't able to turn it on by default, though.
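Assuming the truncated `rank…` link above points to JAX's rank-promotion documentation, here is a minimal sketch of the two switches the reply mentions: the `jax_numpy_rank_promotion` config option (values `"allow"`, `"warn"`, `"raise"`) and the `jax.numpy_rank_promotion` context manager. The helper function names are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def add_shape_allowing_promotion():
    x = jnp.ones(3)       # rank 1, shape (3,)
    m = jnp.ones((2, 3))  # rank 2, shape (2, 3)
    # "allow" (the default) silently treats x as shape (1, 3) and broadcasts.
    with jax.numpy_rank_promotion("allow"):
        return (x + m).shape

def add_raises_under_raise_mode():
    x = jnp.ones(4)       # rank 1, shape (4,)
    m = jnp.ones((2, 4))  # rank 2, shape (2, 4)
    # "raise" turns implicit rank promotion into a ValueError, scoped to this block.
    with jax.numpy_rank_promotion("raise"):
        try:
            x + m
        except ValueError:
            return True
    return False

# The same switch can also be set process-wide instead of per block:
# jax.config.update("jax_numpy_rank_promotion", "warn")

print(add_shape_allowing_promotion())  # (2, 3)
print(add_raises_under_raise_mode())   # True
```

Using the context manager keeps the stricter check local to code you own, which is one way to opt in without changing the global default.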
- Replying to @deliprao, @RandomlyWalking, and @cohenrap: I recommend watching Skye's talk at NeurIPS (starts at 44:26): slideslive.com/38922046/progr… If DL performance alone is too ho-hum, since JAX is about all numerical computing (not just DL), maybe you'd find the NumPyro benchmarks from UberAI interesting: openreview.net/forum?id=H1g1n…
- Replying to @latentjasper and @carlesgelada: See Section 1.2.2 for a discussion of exactly this complaint about prior specification. Lots of later work too, like the recent paper “Functional Variational Bayesian Neural Networks” from @RogerGrosse's group.
- Replying to @jonkhler and @PatrickKidger: Yes! github.com/jax-ml/jax-tri… In that repo there's also an experimental new way to write Triton kernels, called Pallas. It uses JAX's tracing machinery for a more convenient embedding and some transformability. @sharadvikram
- Replying to @gabrielpeyre: They can be generalized to R[ε]/(ε^{k+1}) to model higher-order autodiff that is faster than nested dual numbers (often exponentially so). See Ch. 13 of Griewank and Walther, and jax.experimental.jet in the JAX GitHub repo!
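To make the truncated-polynomial idea concrete, here is a minimal pure-Python sketch of arithmetic in R[ε]/(ε^{k+1}). It is not the API of jax.experimental.jet; the class name `Truncated` is hypothetical, and only `+` and `*` are implemented. Seeding x = x0 + ε and evaluating a polynomial f in one pass yields the Taylor coefficients f^(n)(x0)/n! for n = 0..k. Each multiply costs O(k²), whereas nesting k levels of dual numbers carries 2^k components.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Truncated:
    """Hypothetical element of R[eps]/(eps^{k+1}): c[i] is the coefficient of eps^i."""
    c: List[float]

    def _lift(self, x):
        # Embed a plain scalar as x + 0*eps + ... + 0*eps^k.
        if isinstance(x, Truncated):
            return x
        return Truncated([float(x)] + [0.0] * (len(self.c) - 1))

    def __add__(self, other):
        other = self._lift(other)
        return Truncated([a + b for a, b in zip(self.c, other.c)])
    __radd__ = __add__

    def __mul__(self, other):
        other = self._lift(other)
        k = len(self.c) - 1
        out = [0.0] * (k + 1)
        # Truncated Cauchy product: any eps^(i+j) with i + j > k is zero.
        for i, a in enumerate(self.c):
            for j in range(k + 1 - i):
                out[i + j] += a * other.c[j]
        return Truncated(out)
    __rmul__ = __mul__  # multiplication is commutative

# f(x) = x^3 + 2x at x0 = 2, with k = 3: seed x = 2 + eps.
x = Truncated([2.0, 1.0, 0.0, 0.0])
y = x * x * x + 2 * x
print(y.c)  # [12.0, 14.0, 6.0, 1.0] = [f, f', f''/2!, f'''/3!] at x0 = 2
```

Here f(2) = 12, f'(2) = 3·2² + 2 = 14, f''(2)/2 = 12/2 = 6, and f'''(2)/6 = 1, all read off a single evaluation.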
- It doesn't have to be either-or! Let's do more JAX _and_ PyTorch, together! For example, they can do zero-copy handoff of GPU buffers via DLPack, and AD can be integrated, e.g.: gist.github.com/mattjj/e8b5107… There's much more to be done here. With interop, users win!
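A minimal sketch of the DLPack handoff mentioned above, not the linked gist (whose AD integration is not reproduced here). It assumes recent JAX and PyTorch versions in which `jax.dlpack.from_dlpack` accepts any object implementing `__dlpack__` and `torch.from_dlpack` accepts JAX arrays. It uses CPU tensors so it runs without a GPU; whether a copy is actually avoided depends on device and memory layout. The helper names `torch_to_jax` and `jax_to_torch` are hypothetical.

```python
import jax
import jax.dlpack
import jax.numpy as jnp
import torch

def torch_to_jax(t: torch.Tensor) -> jax.Array:
    # Consume the tensor through the DLPack protocol; same-device handoff can share the buffer.
    return jax.dlpack.from_dlpack(t)

def jax_to_torch(a: jax.Array) -> torch.Tensor:
    # PyTorch reads the JAX array's __dlpack__ / __dlpack_device__ methods.
    return torch.from_dlpack(a)

t = torch.arange(4, dtype=torch.float32)  # tensor([0., 1., 2., 3.])
a = torch_to_jax(t)                        # now a JAX array
b = jax_to_torch(jnp.multiply(a, 2.0))     # compute in JAX, hand the result back
```

JAX arrays are immutable, so a tensor obtained this way should be treated as read-only; writing into it in place could alias a buffer JAX assumes never changes.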
- "Functional languages *are* unnatural to use; but so are knives and forks, diplomatic protocols, double-entry bookkeeping, and a host of other things modern civilization has found useful." - James H. Morris bitsavers.org/pdf/xerox/parc…
- Replying to @jm_alexia: If you have time to share, we would love to hear feedback on our issue tracker and/or in GitHub Discussions!
- My sister is a postdoc in biophysics and biochemistry, and she's writing open-source software for open and reproducible science. "If you're looking for fast, transparent, customizable software for analyzing large single-molecule FRET/fluorescence data sets, my software package Traces is now compatible with the latest version of MATLAB: github.com/stephlj/Traces. Check it out!"




