I'm a big JAX fan, and as others have pointed out, you can solve your problem with vmap. Beyond that, I think one of the problems with tensors is that they are simply hard to get your head around. Even if you draw multi-headed attention on paper with diagrams (keeping it clear which dimension does what), it still doesn't look easy. My solution is to write a comment at each step describing what each dimension represents.
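Here is a minimal sketch of that comment-per-step style in JAX. Attention is written for a single head, then lifted over the heads axis with jax.vmap. The weight layout ([heads, d_model, d_head]) and every function and variable name here are illustrative assumptions, not anything from the thread.

```python
import jax
import jax.numpy as jnp

def attention_one_head(q, k, v):
    # q: [seq_q, d_head]   one query vector per query position
    # k: [seq_k, d_head]   one key vector per key position
    # v: [seq_k, d_head]   one value vector per key position
    scores = q @ k.T / jnp.sqrt(q.shape[-1])   # [seq_q, seq_k] scaled query/key similarity
    weights = jax.nn.softmax(scores, axis=-1)  # [seq_q, seq_k] each row sums to 1 over keys
    return weights @ v                         # [seq_q, d_head] weighted mix of values

def multi_head_attention(x, w_q, w_k, w_v, w_o):
    # x:             [seq, d_model]           one token embedding per position
    # w_q, w_k, w_v: [heads, d_model, d_head] per-head projections (assumed layout)
    # w_o:           [heads, d_head, d_model] per-head output projection (assumed layout)
    q = jnp.einsum("sm,hmd->hsd", x, w_q)  # [heads, seq, d_head]
    k = jnp.einsum("sm,hmd->hsd", x, w_k)  # [heads, seq, d_head]
    v = jnp.einsum("sm,hmd->hsd", x, w_v)  # [heads, seq, d_head]
    # vmap maps attention_one_head over the leading "heads" axis of q, k and v
    per_head = jax.vmap(attention_one_head)(q, k, v)  # [heads, seq, d_head]
    # project each head back to d_model, then sum over heads
    return jnp.einsum("hsd,hdm->sm", per_head, w_o)   # [seq, d_model]

key = jax.random.PRNGKey(0)
k_x, k_q, k_k, k_v, k_o = jax.random.split(key, 5)
seq, d_model, heads, d_head = 5, 8, 2, 4
x = jax.random.normal(k_x, (seq, d_model))
w_q = jax.random.normal(k_q, (heads, d_model, d_head))
w_k = jax.random.normal(k_k, (heads, d_model, d_head))
w_v = jax.random.normal(k_v, (heads, d_model, d_head))
w_o = jax.random.normal(k_o, (heads, d_head, d_model))

out = multi_head_attention(x, w_q, w_k, w_v, w_o)
print(out.shape)  # (5, 8)
```

The point of the design is that attention_one_head never mentions the heads axis at all; vmap adds it, so the shape comments inside the function stay short.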
jax.vmap definitely helps, but I still find it hard (too much axis= stuff). It's the basis for my new thing, though!
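A sketch of the axis bookkeeping that remains even with vmap, assuming "axis= stuff" refers to arguments like vmap's in_axes/out_axes and softmax's axis. It computes the same batched attention weights twice: once with vmap over a non-leading batch axis, and once by hand with einsum. The time-major [seq, batch, d_model] layout and all names are hypothetical.

```python
import jax
import jax.numpy as jnp

def self_attention_weights(x, w_q, w_k):
    # x:        [seq, d_model]   a single example, no batch axis
    # w_q, w_k: [d_model, d_head] projections shared across the batch (hypothetical)
    q = x @ w_q                               # [seq, d_head]
    k = x @ w_k                               # [seq, d_head]
    scores = q @ k.T / jnp.sqrt(q.shape[-1])  # [seq_q, seq_k]
    return jax.nn.softmax(scores, axis=-1)    # normalise over keys (the last axis)

# Data stored time-major: [seq, batch, d_model], so the batch axis is axis 1.
# None means "do not map this argument" (the weights are shared);
# out_axes=0 puts the batch axis first in the result.
batched = jax.vmap(self_attention_weights, in_axes=(1, None, None), out_axes=0)

key = jax.random.PRNGKey(0)
k_x, k_q, k_k = jax.random.split(key, 3)
seq, batch, d_model, d_head = 6, 3, 8, 4
x = jax.random.normal(k_x, (seq, batch, d_model))
w_q = jax.random.normal(k_q, (d_model, d_head))
w_k = jax.random.normal(k_k, (d_model, d_head))

attn = batched(x, w_q, w_k)  # [batch, seq_q, seq_k]
print(attn.shape)  # (3, 6, 6)

# The same computation without vmap: every axis is spelled out by hand.
q = jnp.einsum("sbm,md->bsd", x, w_q)                         # [batch, seq, d_head]
k = jnp.einsum("sbm,md->bsd", x, w_k)                         # [batch, seq, d_head]
scores = jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(d_head)  # [batch, seq_q, seq_k]
manual = jax.nn.softmax(scores, axis=-1)                      # [batch, seq_q, seq_k]
```

Both versions still force you to track which axis is which; vmap just moves that bookkeeping into in_axes and out_axes.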