I think another easy improvement to this diffusion model would be to let the logprobs also affect the chance of a token being turned back into a mask: higher-confidence tokens would be less likely to be remasked, which should speed up convergence. Something like the sketch below. I wonder if backprop would be able to exploit that during training. (I'm not an ML engineer.)
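
Here's a rough sketch of what I have in mind, purely as an illustration: the function name, the `mask_id` argument, and the softmax-over-negative-logprobs weighting are all my own assumptions, not anything from the actual model.

    import torch

    def confidence_weighted_remask(token_ids, logprobs, mask_id,
                                   remask_frac=0.3, temperature=1.0):
        # token_ids: (seq_len,) current token ids at this denoising step
        # logprobs:  (seq_len,) model log-prob of each chosen token
        # mask_id:   id of the [MASK] token
        seq_len = token_ids.shape[0]
        n_remask = max(1, int(remask_frac * seq_len))
        # Lower logprob -> higher remask weight; temperature controls how
        # strongly confidence protects a token from being remasked.
        weights = torch.softmax(-logprobs / temperature, dim=0)
        idx = torch.multinomial(weights, n_remask, replacement=False)
        out = token_ids.clone()
        out[idx] = mask_id
        return out

As temperature grows this degenerates into uniform random remasking; lowering it makes high-confidence tokens nearly immune to remasking, which is the faster-convergence behavior I was guessing at.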