Agreed, JAX is specialized for GPU computation. I'd really like similar capabilities with more permissive constructs, maybe even co-effect tagging of pieces of code (which part goes on the GPU, which part goes on the CPU).
I've thought about extending JAX with custom primitives and a custom lowering process to support constructs that work on CPU but not on GPU. But if I did that and still wanted a nice programmable substrate, I'd need to define my own version of abstract tracing, because permissive CPU constructs would likely require more permissive array types (dynamic shapes, for example), whereas JAX's tracing assumes shapes are known statically.
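To make the dynamic-shape point concrete, here is a minimal sketch of where JAX's abstract tracing gets in the way. The function names (`positive_indices`, `positive_indices_padded`) are hypothetical and only for illustration; the behavior shown is that of `jnp.nonzero` under `jax.jit`, as far as I understand it.

```python
import jax
import jax.numpy as jnp

def positive_indices(x):
    # The output length depends on the *values* in x, so its shape is dynamic.
    return jnp.nonzero(x > 0)[0]

x = jnp.array([1.0, -2.0, 3.0, -4.0])

# Eager execution works: the shape is computed from concrete values.
eager = positive_indices(x)

# Under jit, x is an abstract tracer (shape and dtype only), so the output
# shape cannot be determined and tracing fails. JAX reports this as a
# TypeError subclass (typically ConcretizationTypeError).
try:
    jax.jit(positive_indices)(x)
    traced_ok = True
except TypeError:
    traced_ok = False

def positive_indices_padded(x):
    # The usual workaround: fix the output size statically and pad with -1.
    return jnp.nonzero(x > 0, size=x.shape[0], fill_value=-1)[0]

padded = jax.jit(positive_indices_padded)(x)
```

The padded version traces fine, but the caller now has to carry a sentinel value around, which is the kind of restriction a more permissive CPU-side construct would avoid.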
At that point you start heading towards something that looks like Julia. The problem with Julia (for my work) is that it doesn't support composable transformations the way JAX does.
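For anyone unfamiliar with what "composable transformations" means here, a small sketch: `jax.grad`, `jax.vmap`, and `jax.jit` each take a function and return a function, so they can be stacked freely. The `loss` function and the data are made up for illustration.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Toy scalar loss for a single example x.
    return jnp.sum((w * x) ** 2)

# Compose three transformations: differentiate with respect to w, map over a
# batch of x (w is shared, hence in_axes=(None, 0)), then compile the result.
per_example_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0],
                [0.0, 1.0],
                [1.0, 1.0]])

grads = per_example_grad(w, xs)  # one gradient per row of xs, shape (3, 2)
```

Each gradient row equals `2 * w * x**2` for the corresponding example, and the order of composition can be changed (for instance `grad` of a `vmap`-ped function) without rewriting `loss`.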
Julia + JAX might be offered as a solution -- but it's quite unsatisfying to me.