Saturday, June 8, 2024

A postmortem on failing to adapt FSE to GPU compute

I think failures are educational and worth documenting, so here's the story of how FSE-on-GPU failed.

One of the things I'm working on right now is Gstd, an attempt to make a Zstandard variant that runs on GPU compute.

Zstandard uses a newer entropy coding method called FSE, which is a variant of tabled asymmetric numeral systems (tANS).  Using this scheme requires two steps, both of which are challenging to implement on GPU compute: decoding the probabilities and converting those probabilities into a decode table.

GPUs are designed for highly parallel instruction execution, especially executing the same operation on many different inputs at once (a "vector" of values).  Like Brotli-G and GDEFLATE, Gstd uses multiple bitstreams simultaneously to allow values to be parsed from all of them at once.

One of the reasons this type of scheme is so efficient is that GPUs have instructions that compute a running total across the inputs.  If you create a vector containing 0 for lanes that don't need to be refilled and 1 for lanes that do, and then do a prefix-sum op, you get a vector in which each lane contains the offset from the current read position that it needs to be refilled from.
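A minimal Python emulation of that refill step, assuming one refill unit per lane and an 8-lane wave for readability (the function name and layout are hypothetical, not Gstd's actual code).  In HLSL the exclusive running total would come from a single WavePrefixSum call; here `itertools.accumulate` stands in for it.

```python
from itertools import accumulate


def refill_offsets(needs_refill):
    """Emulate an exclusive prefix sum across one wave of lanes.

    needs_refill[i] is 1 if lane i must load new data, else 0.  Each lane
    gets the number of refilling lanes before it, which is its offset (in
    refill units) from the shared read position.  Also returns the total
    number of units the wave consumes.
    """
    inclusive = list(accumulate(needs_refill))
    exclusive = [total - flag for total, flag in zip(inclusive, needs_refill)]
    return exclusive, (inclusive[-1] if inclusive else 0)


flags = [1, 0, 1, 1, 0, 0, 1, 0]
offsets, consumed = refill_offsets(flags)
read_pos = 100  # hypothetical current read position, in refill units
for lane, (flag, off) in enumerate(zip(flags, offsets)):
    if flag:
        print(f"lane {lane} refills from {read_pos + off}")
# lane 0 refills from 100
# lane 2 refills from 101
# lane 3 refills from 102
# lane 6 refills from 103
read_pos += consumed  # advance past everything the wave consumed
```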

Doing things in parallel this way, however, requires minimizing dependencies between sequential values.  Running totals ARE dependent on preceding values, but they can be computed quickly because the GPU has an operation for them.

This creates a problem for decoding FSE because it has multiple steps that are sequentially dependent.

Here's a quick summary of the FSE table decoding and generation algorithm:

  • Determine how many table slots have yet to be assigned.
  • Determine the maximum possible value that the next value to read can have: it is equal to the number of remaining slots + 1.
  • Load enough bits to encode that value.
  • If the value is small enough (how small is determined by the number of remaining slots), return 1 bit to the bitstream (or, depending on how you phrase it, don't load it in the first place).  This step prevents any value from being encoded that would be larger than the number of remaining slots.
  • Subtract 1 from the decoded number.
  • If the resulting number is -1, then it is a special "less than one" probability with a slot usage of 1.  Otherwise the slot usage is the decoded value.
  • Repeat the preceding steps until there are no slots remaining.
  • Once all symbols are decoded, position all "less than one" probabilities at the end.
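  • For each symbol, scatter its slots by adding a fixed step to the last slot index and taking that sum modulo the table size.  (The specification's step is (tableSize >> 1) + (tableSize >> 3) + 3, which is odd and therefore coprime with the power-of-2 table size, so it visits every slot; it is often, but not always, prime.)  If the new position overlaps a less-than-one entry, repeat this process until an empty slot is found.
  • For each symbol, sort its table positions in ascending order.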
  • Assign each slot a baseline and bit usage.  Basically, when a value is decoded, the slot corresponding to a "state" value is read from the table.  That slot indicates a number of bits to load and a baseline value that it must be added to.
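The table-generation half of these steps can be sketched in Python as below.  It follows the spreading step and the Number_of_Bits/Baseline formulas from the Zstandard format specification and uses -1 for "less than one" counts as in the steps above; the function names are hypothetical.  Walking the cells in ascending position order is what takes care of the sort step.

```python
def spread_symbols(norm_counts, table_log):
    """Assign a symbol to every table cell using the spec's spreading rule.

    norm_counts[s] is symbol s's slot count; -1 marks a "less than one"
    probability, which occupies exactly one cell at the end of the table.
    """
    table_size = 1 << table_log
    mask = table_size - 1
    cells = [None] * table_size
    high_threshold = table_size - 1
    for s, c in enumerate(norm_counts):
        if c == -1:  # "less than one" symbols fill the table from the top down
            cells[high_threshold] = s
            high_threshold -= 1
    step = (table_size >> 1) + (table_size >> 3) + 3
    pos = 0
    for s, c in enumerate(norm_counts):
        for _ in range(max(c, 0)):
            cells[pos] = s
            pos = (pos + step) & mask
            while pos > high_threshold:  # skip cells taken by "less than one" symbols
                pos = (pos + step) & mask
    assert pos == 0, "counts do not sum to the table size"
    return cells


def build_decode_table(norm_counts, table_log):
    """Return (symbol, nb_bits, baseline) for every cell, in state order."""
    table_size = 1 << table_log
    # Per-symbol state counter: starts at the slot count, one per cell seen.
    next_state = [1 if c == -1 else c for c in norm_counts]
    table = []
    for s in spread_symbols(norm_counts, table_log):  # ascending positions
        n = next_state[s]
        next_state[s] += 1
        nb_bits = table_log - (n.bit_length() - 1)
        baseline = (n << nb_bits) - table_size
        table.append((s, nb_bits, baseline))
    return table


table = build_decode_table([5, 100, 21, -1, -1], 7)
print([(state, nb, base) for state, (s, nb, base) in enumerate(table) if s == 0])
# [(0, 5, 32), (38, 5, 64), (76, 5, 96), (83, 4, 0), (121, 4, 16)]
```

Symbol 0 here has 5 slots in a 128-state table, and its Number_of_Bits and Baseline values come out the same as in the specification's example that follows.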

The example given in the specification, a symbol with a probability of 5 in a table of 128 states, looks like this:

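| state order    | 0     | 1     | 2      | 3    | 4     |
|----------------|-------|-------|--------|------|-------|
| state value    | 1     | 39    | 77     | 84   | 122   |
| width          | 32    | 32    | 32     | 16   | 16    |
| Number_of_Bits | 5     | 5     | 5      | 4    | 4     |
| range number   | 2     | 4     | 6      | 0    | 1     |
| Baseline       | 32    | 64    | 96     | 0    | 16    |
| range          | 32-63 | 64-95 | 96-127 | 0-15 | 16-31 |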

Note that the baselines start at zero on the slots with low bit usage and then wrap around to the slots with high bit usage, which have lower state values and come first.  Each slot's range covers 2^Number_of_Bits values, and together a symbol's ranges cover every value from 0 to the table size minus 1 exactly once, which ensures there is always a slot that will decode to any given state value.  Like all ANS variants, values are encoded in the reverse of the order in which they are decoded, so there always needs to be a slot for the symbol that, combined with the bits, equals the desired state value.

There are several serial dependencies here:

  • The low-bit-usage cutoff of each probability depends on the slot usage remaining after the previous symbol.
  • The table index depends on how many less-than-one probabilities were skipped by preceding symbols (otherwise it is just a multiple of the slot index in the series).
  • The bit ranges depend on how many slots of the same symbol precede it in numerical order.
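The first of these dependencies shows up clearly in a sketch of the count encoding from the steps above.  This Python version is a simplification: it omits the 4-bit Accuracy_Log field that precedes the counts and the 2-bit repeat flags Zstandard writes after a zero count, so apart from that header field it only matches the real format for distributions without zero counts.  The function names are hypothetical.

```python
def field_params(slots_left):
    """Field layout for the next count, given the slots still unassigned."""
    remaining = slots_left + 1                     # values range over 0..remaining
    threshold = 1 << (remaining.bit_length() - 1)  # largest power of 2 <= remaining
    nb_bits = threshold.bit_length()               # enough bits for 0..2*threshold-1
    max_short = (2 * threshold - 1) - remaining    # values below this use nb_bits - 1 bits
    return threshold, nb_bits, max_short


def encode_value(value, slots_left):
    threshold, nb_bits, max_short = field_params(slots_left)
    if value < max_short:
        return value, nb_bits - 1
    if value >= threshold:
        value += max_short  # shift large values past the unused codes
    return value, nb_bits


def decode_value(stream, slots_left):
    """stream holds the upcoming bits, least-significant bit first."""
    threshold, nb_bits, max_short = field_params(slots_left)
    low = stream & (threshold - 1)
    if low < max_short:
        return low, nb_bits - 1  # short code: the top bit is not consumed
    value = stream & (2 * threshold - 1)
    if value >= threshold:
        value -= max_short
    return value, nb_bits


def write_counts(counts, table_log):
    stream, n_bits, slots_left = 0, 0, 1 << table_log
    for c in counts:
        bits, width = encode_value(c + 1, slots_left)  # stored as count + 1
        stream |= bits << n_bits
        n_bits += width
        slots_left -= abs(c)
    return stream, n_bits


def read_counts(stream, table_log):
    counts, slots_left = [], 1 << table_log
    while slots_left > 0:
        # This field's width depends on slots_left, i.e. on every earlier count.
        value, width = decode_value(stream, slots_left)
        stream >>= width
        count = value - 1
        counts.append(count)
        slots_left -= abs(count)
    return counts
```

Because `field_params` takes `slots_left`, no lane can know where its field starts or how wide it is until every earlier count has been decoded.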

The last one has a few implications.  There are basically two ways to do it: run a loop over the post-placement slots and figure out which occurrence of the symbol each entry is, or run a loop over the pre-placement slots for each symbol to sort them.  Both are seriously problematic due to low utilization.  What if we could just assign the baselines to pre-scatter values instead of post-scatter ones?

This goes against the theory, but if it didn't hurt the compression ratio too much, it might be workable.  It was a clever idea that turned out to be a big mistake, although it's actually pretty nice in terms of implementation possibilities:

First, fill an empty table with zeroes, then pack the baseline, bits, and symbol into a single value and put it into the table.  On hardware with WaveMatch and equivalent ops, it's possible to create a bitfield containing 0 for lanes holding zero and 1 for lanes holding a non-zero value, which can be used together with masking and high-bit search ops to quickly find the nearest preceding lane with a non-zero value.  On hardware without it, it's possible to do the same thing in logarithmic time by packing and transforming the values so that their desired order is numerically ascending and repeatedly using WaveMax.
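A Python sketch of the bitfield approach, with a Python int standing in for the per-lane bitfield (on HLSL this would presumably come from WaveActiveBallot or WaveMatch, with firstbithigh as the high-bit search).  For each lane it finds the nearest lane at or before it whose table entry is non-zero; the function name and the sample values are hypothetical.

```python
def preceding_nonzero_lane(values):
    """For each lane, the nearest lane at or before it holding a non-zero value.

    Builds a ballot-style bitfield (bit i set where values[i] != 0), keeps
    only the bits up to and including the lane, and takes the highest set
    bit.  Returns -1 for lanes with no non-zero lane at or before them.
    """
    ballot = 0
    for lane, v in enumerate(values):
        if v != 0:
            ballot |= 1 << lane
    result = []
    for lane in range(len(values)):
        visible = ballot & ((2 << lane) - 1)     # keep bits 0..lane
        result.append(visible.bit_length() - 1)  # highest set bit, or -1
    return result


print(preceding_nonzero_lane([7, 0, 0, 9, 0, 0, 0, 3]))  # [0, 0, 0, 3, 3, 3, 3, 7]
```

Each lane does a constant amount of work, which is why this is attractive compared with a serial scan; on real hardware, waves wider than 32 lanes would need the multi-word ballot result.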

The baselines ascend but wrap around rather than increasing monotonically, so this basically involved negating the bit usage and encoding a value that could be subtracted from the slot index and then multiplied by 2^BitUsage to get the actual baseline.
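The subtraction form of the baseline can be checked against the specification's formula.  Assuming "slot index" means the per-symbol state counter n that runs from the symbol's count p up to 2p - 1, the spec's baseline (n << Number_of_Bits) - tableSize equals (n - 2^floor(log2 n)) << Number_of_Bits, which has the "subtract, then multiply by 2^BitUsage" shape.  The function names are hypothetical.

```python
def baseline_spec(n, table_log):
    """Baseline as the specification computes it."""
    nb_bits = table_log - (n.bit_length() - 1)
    return (n << nb_bits) - (1 << table_log)


def baseline_by_subtraction(n, table_log):
    """Same value: subtract the power of 2 below n, then scale by 2^nb_bits."""
    k = n.bit_length() - 1  # floor(log2(n))
    nb_bits = table_log - k
    return (n - (1 << k)) << nb_bits


# Probability 5 in a 128-state table: states 5..9 give the example's baselines.
print([baseline_by_subtraction(n, 7) for n in range(5, 10)])  # [32, 64, 96, 0, 16]
```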

Once that's all done, loop over that table, multiply each index by the distribution step, and take the result modulo the table size to get the final position.  Great!
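Computing the final position as index × step mod table size gives each entry its position independently, with no serial walk.  A minimal sketch using the specification's step, assuming there are no less-than-one symbols (with them, positions that land on reserved cells would still need the skip adjustment); the function names are hypothetical.

```python
def scatter_positions(table_log):
    """Final position of the i-th pre-scatter entry, computed independently per entry.

    Assumes no "less than one" symbols, so no cells are skipped.
    """
    table_size = 1 << table_log
    step = (table_size >> 1) + (table_size >> 3) + 3  # odd, so i -> i*step is a permutation
    return [(i * step) % table_size for i in range(table_size)]


def sequential_positions(table_log):
    """The same positions produced by the serial walk (pos += step)."""
    table_size = 1 << table_log
    step = (table_size >> 1) + (table_size >> 3) + 3
    pos, out = 0, []
    for _ in range(table_size):
        out.append(pos)
        pos = (pos + step) & (table_size - 1)
    return out


print(scatter_positions(5)[:6])  # [0, 23, 14, 5, 28, 19]
```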


So, I tested this out with some sample data from OpenArena.  The results were coming out pretty favorably, about 15% ahead of GDEFLATE, and changing from post-scatter to pre-scatter was not having a big impact on the compression ratio.

However, FSE is still difficult to work with on GPU: it requires a big table, and because of its size, that table has to be indexed from groupshared memory unless it can somehow be crammed into registers... but there's a lot of data that would have to be crammed into registers!

For 15%, it would be worth it though!  Unfortunately, I hit a painful realization while working on the decoder: the encoder was dropping the contents of incompressible blocks.  Once that was fixed, the savings were only about 3%.

3% was not great to begin with, and it was especially odd because GDEFLATE is not functionally different from Deflate, and Zstandard usually outperforms Deflate by a measurable margin.  Running tests showed a good number of blocks where Zstandard was compressing to about 10% larger than GDEFLATE.  My best guess was, and still is, that GDEFLATE uses libdeflate as its compressor, and libdeflate is quite good... so, to test this possibility, I needed to write some code that could convert a Deflate block into a Zstandard block!

Now, in order to do that, I needed to generate FSE probability tables for the match length, literal length, and offset codes, and then rescale them to total a power of 2 as the scheme requires.  How do you do that?

Well, normally for an arithmetic or range coder, the cost per occurrence of a symbol is -log(Probability/SumOfProbabilities)/log(2).  As probability increases, the bit usage declines rapidly at first and then more slowly.  Through some math, this can be simplified into smaller calculations, but there's one thing that's weird about FSE: the reduction in bit cost seems linear, doesn't it?  Maybe the bit usage is just approximately logarithmic but not actually?  Close enough?

Suppose the successor state for a symbol is uniformly random (and they do behave pretty randomly).  Have a look at the table above.  If you bump the slot usage for a symbol by 1, the effect is to split one of those 5-bit ranges into two 4-bit ranges.  Those 4-bit ranges cover the same successor states as the 5-bit range they replaced, so the cost for that fraction of successor states drops by 1 bit.  The problem is that all of those 5-bit ranges are the same size, so each slot you add reduces the bits-per-symbol by the same amount, not by an amount that changes along a logarithmic curve.  This continues until the slot count hits a power of 2, at which point it has to start splitting 4-bit ranges into 3-bit ranges.
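The argument can be checked numerically.  This sketch assumes, as the paragraph does, that the successor state is uniform over the table, so each slot of the symbol is chosen with probability proportional to its range width 2^Number_of_Bits; it compares the resulting average cost with -log2 of the probability for a 32-state table.  The function name is hypothetical.

```python
import math


def fse_cost_uniform(count, table_log):
    """Average bits per symbol if the successor state is uniform over the table.

    Each slot covers 2^nb_bits successor states and costs nb_bits, so the
    average is sum(nb_bits * 2^nb_bits) / table_size.
    """
    table_size = 1 << table_log
    total = 0
    for n in range(count, 2 * count):  # the symbol's per-slot state counters
        nb_bits = table_log - (n.bit_length() - 1)
        total += nb_bits << nb_bits
    return total / table_size


for count in range(4, 9):
    print(count, fse_cost_uniform(count, 5), round(-math.log2(count / 32), 3))
# 4 3.0 3.0
# 5 2.75 2.678
# 6 2.5 2.415
# 7 2.25 2.193
# 8 2.0 2.0
```

Between 4 and 8 slots, the modeled cost falls by exactly 0.25 bits per added slot, while the ideal cost follows the log curve and sits below the straight line in between.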


But that means the best symbol to give an extra slot to stays the same until its slot usage hits a power of 2, resulting in a table that contains only power-of-2 slot counts except for at most one probability; and all of those powers of 2 have integral bit usage, which is just Huffman.

This seemed very problematic, so I ran a few tests, and the answer was that FSE doesn't have this problem: given uniform successor states, FSE produces predecessor states that are disproportionately low values.  (This is why the less-than-one distribution works in the first place.)  It's somewhat easy to see why: lower-numbered states have wider ranges, so they take up more of the numeric range.

Assigning baselines before scattering breaks that property, though: it makes the predecessor states uniformly distributed.


Some further tests confirmed my suspicions: If you use a properly-scaled probability distribution with cell properties determined before scattering, then the result is worse than Huffman.  It's only better than Huffman if the cell properties are determined AFTER scattering.

That means that doing this right requires sorting the entries, and I don't think there is any particularly good algorithm for that in this setting.

But tests showed it was only a small amount of overhead?

Yeah, but Huffman is only a small amount of overhead too.  FSE and arithmetic coding make it possible to get fractional bit usage, but the benefit of that depends a lot on the distribution of values.  The fewer bits used by the most-probable symbols, the more likely it is to help, because a fractional bit of accuracy is a bigger portion of the cost of each value.


The silver lining

Ultimately, this realization is pretty lethal to using FSE on GPU, since the remaining options for table construction are quite bad, and the table is a problem in the first place.

The good news is that rANS should still work.  In fact, the way GDEFLATE decodes Huffman codes is by using binary search within a vector register, and rANS symbol lookup would work the same way.
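A minimal rANS sketch with binary-search symbol lookup, for comparison.  Python's `bisect` stands in for the in-register binary search, unbounded Python integers sidestep the renormalization a real GPU decoder would need, and the frequency table and names are hypothetical.

```python
from bisect import bisect_right

PROB_BITS = 5               # hypothetical: frequencies sum to 2^5
FREQ = [16, 8, 5, 2, 1]     # hypothetical symbol frequencies (sum 32)
CUM = [0]
for f in FREQ:
    CUM.append(CUM[-1] + f)  # CUM[s] = start of symbol s's slot range
assert CUM[-1] == 1 << PROB_BITS


def lookup(slot):
    """Binary search for the symbol s with CUM[s] <= slot < CUM[s + 1]."""
    return bisect_right(CUM, slot) - 1


def encode(symbols):
    x = 1
    for s in reversed(symbols):  # ANS encodes in reverse decode order
        f, c = FREQ[s], CUM[s]
        x = ((x // f) << PROB_BITS) + (x % f) + c
    return x


def decode(x, count):
    out = []
    for _ in range(count):
        slot = x & ((1 << PROB_BITS) - 1)
        s = lookup(slot)
        out.append(s)
        x = FREQ[s] * (x >> PROB_BITS) + slot - CUM[s]
    return out


msg = [0, 1, 0, 2, 4, 3, 0, 0, 1]
print(decode(encode(msg), len(msg)) == msg)  # True
```

Unlike FSE, the only table the decoder needs here is the cumulative frequency list, which is small enough to search in registers and needs no spreading or sorting step.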

So, Gstd isn't dead yet (it will be dead if it turns out that the larger blocks are due to format inefficiencies instead of libdeflate being really good), but it will not be using FSE, and I think the work of getting FSE to work on GPU compute was unfortunately wasted.
