Generating Cats using binned gaussian splats
About
This is a follow-up to the previous post. This time I explore using binned gaussian splats instead of learned lookup tables to generate 8x8 image patches.
Post
The splat kernel is not precisely a gaussian but a cosine-modulated, gaussian-like/sinc-like function that oscillates and goes below zero, which helps reduce blur; it is sort of a wavelet. The function is defined as:

The idea is that a vision transformer consumes 8x8 patches as tokens, and for each patch it outputs parameters for a mixture of 16 splats. Each splat has a mean position within the patch, an RGB color, a precision matrix defining its shape, and a depth value that controls its blending order. The final patch is rendered by evaluating the splat functions at each pixel and summing them up.
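A minimal NumPy sketch of rendering one 8x8 patch from 16 splats follows. The post's exact kernel formula is not reproduced here: `kernel` assumes a gaussian envelope of the squared Mahalanobis distance multiplied by `cos(omega * distance)`, with a hypothetical `omega`. Depth-ordered blending is left out, so the splats are simply summed; `kernel` and `render_patch` are illustrative names.

```python
import numpy as np

PATCH = 8      # patch side length in pixels
N_SPLATS = 16  # splats per patch

def kernel(d2, omega=1.0):
    """Assumed cosine-modulated gaussian of the squared distance d2.

    The envelope exp(-d2/2) is multiplied by cos(omega * d), so the profile
    dips below zero away from the center (a wavelet-like shape).
    """
    d = np.sqrt(np.maximum(d2, 0.0))
    return np.exp(-0.5 * d2) * np.cos(omega * d)

def render_patch(means, colors, precisions):
    """Sum all splats at the pixel centers of one PATCH x PATCH patch.

    means:      (N, 2) splat centers (x, y) in patch coordinates [0, PATCH)
    colors:     (N, 3) RGB colors
    precisions: (N, 2, 2) symmetric positive-definite precision matrices
    """
    c = np.arange(PATCH) + 0.5                                  # pixel centers
    ys, xs = np.meshgrid(c, c, indexing="ij")
    pts = np.stack([xs, ys], axis=-1).reshape(-1, 2)            # (P, 2)
    diff = pts[:, None, :] - means[None, :, :]                  # (P, N, 2)
    d2 = np.einsum("pni,nij,pnj->pn", diff, precisions, diff)   # squared Mahalanobis distance
    img = kernel(d2) @ colors                                   # (P, N) @ (N, 3)
    return img.reshape(PATCH, PATCH, 3)

# Example: random parameters for one patch, standing in for the network's output.
rng = np.random.default_rng(0)
means = rng.uniform(0, PATCH, size=(N_SPLATS, 2))
colors = rng.uniform(0, 1, size=(N_SPLATS, 3))
A = 0.5 * rng.normal(size=(N_SPLATS, 2, 2))
precisions = A @ A.transpose(0, 2, 1) + 0.1 * np.eye(2)  # A A^T + eps I is SPD
patch = render_patch(means, colors, precisions)
print(patch.shape)  # (8, 8, 3)
```

Every operation here is differentiable in the splat parameters, which is what lets the training loss flow back through the renderer into the transformer.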
Since the splats are differentiable with respect to their parameters, we can train the model end-to-end using the same lerp-to-noise objective as before. The model learns to output splat parameters that reconstruct the original image patches from noisy inputs.
This acts as a denoising target that we un-lerp iteratively from noise, similar to the previous posts. The model architecture is the same as before, a patch transformer with 16 self-attention blocks operating on 64 tokens per image.
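The lerp-to-noise objective and the iterative un-lerp can be sketched as follows. This assumes the model predicts the clean patch x0 (here, the splat-rendered output) and is trained with a plain MSE at a uniformly sampled t. The noise schedule, the loss weighting, and the names `predict_x0`, `training_loss`, and `sample` are assumptions, not taken from the post.

```python
import numpy as np

def lerp_to_noise(x0, noise, t):
    """Forward corruption: t=0 is the clean image, t=1 is pure noise."""
    return (1.0 - t) * x0 + t * noise

def training_loss(predict_x0, x0, rng):
    """One training sample: corrupt x0 at a random t and ask the model to recover it.

    predict_x0 stands in for 'transformer -> splat params -> rendered patches'.
    """
    t = rng.uniform(0.0, 1.0)
    noise = rng.standard_normal(x0.shape)
    x_t = lerp_to_noise(x0, noise, t)
    x0_hat = predict_x0(x_t, t)
    return np.mean((x0_hat - x0) ** 2)

def sample(predict_x0, shape, steps, rng):
    """Iteratively un-lerp from pure noise back to an image."""
    ts = np.linspace(1.0, 0.0, steps + 1)
    x = rng.standard_normal(shape)                       # t=1: pure noise
    for t, t_next in zip(ts[:-1], ts[1:]):               # t never reaches 0 inside the loop
        x0_hat = predict_x0(x, t)
        noise_hat = (x - (1.0 - t) * x0_hat) / t         # invert the lerp to estimate the noise
        x = lerp_to_noise(x0_hat, noise_hat, t_next)     # re-lerp at the lower noise level
    return x
```

With an oracle that always returns the target, the loop lands exactly on it, because the final step re-lerps at t = 0 and returns the prediction itself.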
Here are some rollouts:



It is actually quite surprising that this trains well and doesn’t hallucinate too much.
The advantage of this approach is that it can be rendered at arbitrary resolutions by simply scaling up the final pixel grid used to render the gaussians, which is not possible with the learned LUTs. This could be useful for generating high-resolution images.
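Resolution independence comes from the splats being continuous functions of patch coordinates: rendering at a higher resolution just means sampling more pixel centers inside the same patch extent. A small sketch, using a plain gaussian splat for brevity; `pixel_grid` and `eval_gaussian` are hypothetical helpers.

```python
import numpy as np

PATCH = 8

def pixel_grid(scale):
    """(x, y) pixel-center coordinates, in patch units, for a PATCH*scale square render."""
    res = PATCH * scale
    c = (np.arange(res) + 0.5) / scale
    ys, xs = np.meshgrid(c, c, indexing="ij")
    return np.stack([xs, ys], axis=-1)                  # (res, res, 2)

def eval_gaussian(grid, mean, precision):
    """Evaluate one gaussian splat at every grid point."""
    d = grid - mean
    d2 = np.einsum("...i,ij,...j->...", d, precision, d)
    return np.exp(-0.5 * d2)

mean = np.array([3.0, 5.0])
precision = np.array([[1.0, 0.3], [0.3, 0.5]])          # symmetric positive-definite
lo = eval_gaussian(pixel_grid(1), mean, precision)      # (8, 8), native resolution
hi = eval_gaussian(pixel_grid(3), mean, precision)      # (24, 24), same splat, 3x denser
```

With an odd scale factor such as 3, every third high-resolution pixel center coincides with a native one, so `hi[1::3, 1::3]` reproduces the 8x8 render up to floating-point rounding; the extra samples fill in detail between them instead of interpolating pixels.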

The problem with this approach is that the model has to learn to output a reasonable splat mixture for each patch separately. As a result, the output has visible seams on the patch boundaries, especially when rendered at higher resolution. Increasing the number of splats per patch or training on larger images might help, but I didn’t explore that further.
Here’s an example rendered with splats normalized to be uniform blobs:

As we can see, the model tries to arrange the splats so that they match across the boundaries, with one splat dominating at the center.
We can also experiment with the kernel to make it more artistic or sharper, as well as with the depth falloff (which must not have negative lobes for proper blending). Here’s an example with a sharper kernel (a somewhat ugly piecewise function):

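Why the depth falloff must stay non-negative can be seen from a normalized blending rule. This sketch assumes each splat's blending weight at a pixel is its spatial falloff times `exp(depth)`, normalized over splats; that rule, the 1D two-splat setup, and the names `blend` and `falloff_1d` are illustrative assumptions, not the post's exact formulation.

```python
import numpy as np

def blend(colors, depths, falloff):
    """Normalized depth-weighted blending.

    colors:  (N, 3) splat colors
    depths:  (N,) depth logits; a larger depth wins more of the pixel
    falloff: (P, N) spatial falloff of each splat at each pixel
    Returns blended colors (P, 3) and the normalized weights (P, N).
    """
    w = falloff * np.exp(depths - depths.max())[None, :]  # max-shift cancels after normalizing
    w = w / w.sum(axis=1, keepdims=True)
    return w @ colors, w

def falloff_1d(xs, centers, signed):
    """Gaussian falloff, optionally cosine-modulated (which adds negative lobes)."""
    d = np.abs(xs[:, None] - centers[None, :])
    g = np.exp(-0.5 * d ** 2)
    return g * np.cos(2.0 * d) if signed else g

xs = np.linspace(0.0, 2.0, 5)                  # pixel positions along a line
centers = np.array([0.0, 2.0])                 # two splats
colors = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
depths = np.array([0.0, 0.0])

_, w_gauss = blend(colors, depths, falloff_1d(xs, centers, signed=False))
_, w_signed = blend(colors, depths, falloff_1d(xs, centers, signed=True))
print(w_gauss.min() >= 0, w_signed.min() < 0)  # True True
```

With the gaussian falloff every weight is a proper convex-combination weight in [0, 1]. With the signed falloff, a neighbor's negative lobe pulls the normalizer toward zero, so some weights go negative and others exceed 1, which overshoots the colors instead of blending them.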
Results:








This looks quite nice: the sharper kernel helps reduce blur and makes details pop out more. Overall I’m quite happy with how this turned out; it’s a neat way to generate cats. I just need to work on reducing the seams a bit more. Naively generating a global set of splats for the whole image seems to converge too slowly.
I was able to almost completely eliminate the seams by evaluating the splats over a 3x3 region of patches, so that each patch shares splats with its neighbors. This increases computation almost 10x but makes the result much better. It needed a bit of retraining to make sure everything is balanced.
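A sketch of the 3x3 sharing: each patch evaluates its own splats plus those of its eight neighbors, shifted into its local coordinates, so a splat near a boundary also paints the adjacent patch. That is 9 splat sets per patch instead of 1, consistent with the roughly 10x cost. A plain gaussian stands in for the actual kernel, and `render_image` with its arguments is hypothetical.

```python
import numpy as np

PATCH = 8

def render_image(means, colors, precisions, share_neighbors=True):
    """Render a grid of patches from per-patch splats.

    means:      (GH, GW, N, 2) splat centers (x, y) in each patch's local coordinates
    colors:     (GH, GW, N, 3)
    precisions: (GH, GW, N, 2, 2)
    With share_neighbors, each patch also evaluates the splats of its 3x3
    neighborhood, shifted into its own coordinates (9x the splat evaluations).
    """
    GH, GW, N, _ = means.shape
    c = np.arange(PATCH) + 0.5
    ys, xs = np.meshgrid(c, c, indexing="ij")
    pts = np.stack([xs, ys], axis=-1).reshape(-1, 2)   # (P, 2) local pixel centers
    if share_neighbors:
        offsets = [(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1)]
    else:
        offsets = [(0, 0)]
    img = np.zeros((GH * PATCH, GW * PATCH, 3))
    for gy in range(GH):
        for gx in range(GW):
            acc = np.zeros((PATCH * PATCH, 3))
            for dy, dx in offsets:
                ny, nx = gy + dy, gx + dx
                if not (0 <= ny < GH and 0 <= nx < GW):
                    continue  # no neighbor past the image border
                # neighbor's splat centers expressed in this patch's coordinates
                mu = means[ny, nx] + np.array([dx, dy]) * PATCH
                diff = pts[:, None, :] - mu[None, :, :]
                d2 = np.einsum("pni,nij,pnj->pn", diff, precisions[ny, nx], diff)
                acc += np.exp(-0.5 * d2) @ colors[ny, nx]
            img[gy * PATCH:(gy + 1) * PATCH, gx * PATCH:(gx + 1) * PATCH] = acc.reshape(PATCH, PATCH, 3)
    return img
```

In a two-patch image with a single red splat near the right edge of the left patch, the right patch stays black when rendered alone but picks up the splat's tail when neighbors are shared, which is what smooths the seam.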







