Article·Oct 27, 2025

Accelerating Streaming STT Inference Through Custom Kernels

With Flux, we wanted to offer low-latency streaming STT suitable for voice agents, without sacrificing accuracy or concurrency. This led us to a tricky problem in cache management. Here's our nifty solution! Warning: In-depth technical details ahead.


By Josh Gevirtz


Latency, Concurrency and True Streaming

For deep learning models that process sequential data, the ability to reuse cached results from prior computation is critical to achieving low latency. Attention-based models, for instance, use key-value caching to avoid recomputing large matrix multiplications as new queries are processed. Achieving high concurrency, meanwhile, depends on the ability to batch multiple requests together, i.e., combining features or states associated with each request into a single tensor that can be processed in parallel by the model. While latency and concurrency are generally at odds, they are perhaps most so in a streaming setting, where it is essentially impossible to maintain a coherent batch over an extended period. Even if you can batch together a group of requests at a certain time, there is no guarantee that you will continue to receive data from all of them at the same rate and that they will remain synchronized.

One popular solution to this problem employed for many state-of-the-art STT models, such as the various “streaming Whisper” variants, is to replace “true streaming” inference with what one might call “fast batching.” Instead of continually processing audio as it streams in, you accumulate audio until you have an approximately independent “chunk” delineated by, e.g., brief pauses in speech as detected by a lightweight Voice Activity Detector (VAD). Chunks from different streams can be batched and inferenced together, supporting reasonable concurrency. Then, since you assume chunks are independent, any intermediate results for the entire batch can be discarded. As such, this workflow is more similar to a typical batch workflow, but carried out at high frequency.

Fast batching is well suited to highly accurate, transformer-based autoregressive encoder-decoder models like Whisper, which are trained to transcribe finite chunks rather than to stream audio in and tokens out. However, it has major drawbacks. The first is latency: since the entire computation is held off until a suitable break in the audio, customers must wait for a chunk to be completed before inference can even start. The second is limited context: to avoid the large latency of inferencing a very long chunk of audio, fast batching typically slices audio into fairly short chunks, with no information shared between chunks. This makes it hard to perform well on difficult tasks that require more context, such as end-of-turn (EoT) detection and consistent punctuation. Correspondingly, combining “fast batch” STT with EoT detection leads to complex, expensive pipelines involving VADs, careless Whispers, and separate turn detection models. Finally, even with short chunks, fast batching provides relatively infrequent updates, depriving users of, e.g., ongoing visual cues indicating how their speech is being processed.

With Flux, we wanted to offer customers a low-latency, “true streaming” product that was still capable of achieving the high accuracy and concurrency expected of today’s state-of-the-art models, such as Nova-3. However, that meant we needed to solve the de-synchronization that necessarily occurs between batch elements during streaming. Put another way, we needed a way to carry out inference over batches that are both variable length (i.e., not all elements advance at the same cadence) and partial (i.e., not all batch element states are updated at the same time).

Caching with Incomplete Batches in Flux

In standard inference workflows where all batch elements are synchronized, caches can be updated by simply appending the results of new computation to an existing cache. For instance, for a typical cache of dimension B x T x D (B = batch, T = timestep, D = hidden state dimension), the update operation could be as simple as

cache = torch.cat([cache, new_state], dim=1)

To enable updating batch elements independently, such that computation can proceed as soon as new data comes in, Flux instead uses pre-allocated caches, corresponding to the overall batch size, that can be updated using advanced indexing in torch. For instance, as opposed to the more typical “concatenation”-based update above, we update only a subset of indices based on their current lengths,

cache[idxes, lengths - 1] = new_state 
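To make the contrast concrete, here is a minimal sketch of a pre-allocated cache whose slots advance independently. It assumes a single B x T x D tensor per cache plus a per-slot lengths counter. The class name PreallocatedCache and its append method are hypothetical illustrations, not Flux's code. Unlike torch.cat, which needs a new state for every batch element at once, the indexed write touches only the streams that actually received new data.

```python
import torch


class PreallocatedCache:
    """Hypothetical sketch: a fixed-capacity B x T x D cache whose slots advance independently."""

    def __init__(self, capacity: int, max_steps: int, dim: int):
        self.data = torch.zeros(capacity, max_steps, dim)
        # Number of valid timesteps currently stored in each slot.
        self.lengths = torch.zeros(capacity, dtype=torch.long)

    def append(self, idxes: torch.Tensor, new_state: torch.Tensor) -> None:
        """Write one timestep for the slots in `idxes` (assumed unique); other slots are untouched."""
        if (self.lengths[idxes] >= self.data.size(1)).any():
            raise ValueError("cache slot is full")
        self.lengths[idxes] += 1
        lengths = self.lengths[idxes]
        # Advanced indexing: slot idxes[i] receives new_state[i] at position lengths[i] - 1.
        self.data[idxes, lengths - 1] = new_state


cache = PreallocatedCache(capacity=4, max_steps=8, dim=2)
cache.append(torch.tensor([0, 2]), torch.ones(2, 2))      # only streams 0 and 2 have new audio
cache.append(torch.tensor([2]), torch.full((1, 2), 5.0))  # stream 2 advances again on its own
print(cache.lengths.tolist())  # [1, 0, 2, 0]
```

The price of this flexibility is that each slot's valid region now ends at a different position, which is what makes reading the cache back harder.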

Unfortunately, while updating the cache is (a bit of an indexing headache but otherwise) straightforward, accessing and using a subset of the cache efficiently is not. Using existing torch kernels, we need to index into the cache and extract the required rows, and then mask any computation results for cache entries that are invalid to account for the variable lengths. For an attention-based layer, this looks like

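k_batch, v_batch = k_cache[idxes], v_cache[idxes]  # copies the selected rows out of the caches
attn_mask = torch.arange(k_cache.size(1), device=lengths.device).view(1, -1) < lengths.view(-1, 1)
# unsqueeze to (b, 1, S) so the mask broadcasts over the query timesteps of q (b, T, D)
x = torch.nn.functional.scaled_dot_product_attention(q, k_batch, v_batch, attn_mask=attn_mask.unsqueeze(1))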
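For reference, a runnable version of this gather-then-mask baseline might look like the following. It is a sketch that assumes a single attention head, 3-D (batch, time, dim) tensors, and at least one valid timestep per active stream. An all-False mask row would produce NaNs. The function name native_masked_sdpa is hypothetical.

```python
import torch
import torch.nn.functional as F


def native_masked_sdpa(q, k_cache, v_cache, idxes, lengths):
    """Gather-then-mask baseline built only from stock torch ops (hypothetical helper name).

    q:       (b, T_q, D) queries for the b active streams
    k_cache: (N, S, D) pre-allocated key cache with N slots and S max timesteps
    v_cache: (N, S, D) pre-allocated value cache
    idxes:   (b,) cache slots of the active streams
    lengths: (b,) number of valid timesteps per active stream (assumed >= 1)
    """
    # Advanced indexing copies b full rows of length S out of each cache.
    k_batch, v_batch = k_cache[idxes], v_cache[idxes]
    # True marks valid positions; shape (b, 1, S) broadcasts over the T_q query rows.
    positions = torch.arange(k_cache.size(1), device=lengths.device)
    attn_mask = (positions.view(1, -1) < lengths.view(-1, 1)).unsqueeze(1)
    # The kernel still sees all S keys; invalid ones are excluded only via the mask.
    return F.scaled_dot_product_attention(q, k_batch, v_batch, attn_mask=attn_mask)


torch.manual_seed(0)
k_cache, v_cache = torch.randn(8, 16, 4), torch.randn(8, 16, 4)
idxes, lengths = torch.tensor([1, 5, 6]), torch.tensor([3, 16, 7])
q = torch.randn(3, 1, 4)  # one decoding query per active stream
out = native_masked_sdpa(q, k_cache, v_cache, idxes, lengths)
print(out.shape)  # torch.Size([3, 1, 4])
```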

This workflow is unnecessarily expensive. We extract {k,v}_batch first because the torch scaled_dot_product_attention API expects its inputs to have aligned batch dimensions. But when profiling the above with large pre-allocated caches, we found that each of the two cache-indexing operations took almost as long as the SDPA calculation itself! This cost is avoidable: we do not actually need the copies produced in the first line, since idxes and lengths already tell us which cache elements the attention calculation needs. So it is both feasible and preferable to access them directly. A less obvious cost hides in the last line. While the torch scaled_dot_product_attention kernel is highly efficient, in the interest of generality it applies the attention mask after computing the attention weights, with no special handling for masks like ours, which have the form [True...True False...False]. This means that, by running this kernel over the large pre-allocated k_cache, we frequently compute many unnecessary q-k dot products, particularly for batch elements with length << cache_size.
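A quick count shows how much of that work can be wasted. The sketch below compares the q-k dot products that a full-cache SDPA call performs with the ones that actually contribute to the output. The helper name qk_dot_products and the example lengths are hypothetical.

```python
import torch


def qk_dot_products(lengths: torch.Tensor, cache_size: int, q_len: int = 1):
    """Return (computed, needed) q-k dot products for one full-cache SDPA call."""
    computed = lengths.numel() * q_len * cache_size  # each query scores every cache position
    needed = int(lengths.sum()) * q_len              # only positions < length affect the output
    return computed, needed


# Hypothetical example: 4 active streams, a 1024-step cache, one decoding query each.
computed, needed = qk_dot_products(torch.tensor([100, 250, 40, 900]), cache_size=1024)
print(computed, needed, f"{1 - needed / computed:.0%} wasted")  # 4096 1290 69% wasted
```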

This approach was sufficiently costly that, if we were to rely on native torch kernels, we would have to sacrifice concurrency or model capacity to reach our target “realtime” latency. But we were unwilling to do either; with Flux, we wanted to offer Deepgram customers an improvement in both streaming latency and concurrency while still maintaining the high accuracy for which our models are known. So, we put on our adult pants and delved deeper into the world of CUDA optimization.

Custom Kernels to the Rescue

How could we avoid the unnecessary costs described above? Custom kernels! A general strategy for developing custom kernels is to fuse specific operations. Instead of carrying out operations one at a time, and potentially paying memory, transfer, or launch costs multiple times, you carry out multiple related operations at once so that you can take advantage of, e.g., particular tensors already being loaded into GPU shared memory. This is exactly the strategy we employed: we developed kernels that fuse the aforementioned cache-indexing operations with the SDPA calculation.

Specifically, instead of first extracting the required rows and then performing the calculation, we access the required cache elements and use them directly in the calculation, avoiding an intermediate copy. In addition, instead of accessing the whole “row” associated with an index, we only access positions up to the one specified by the length.
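In torch terms, the distinction is between advanced indexing, which materializes a new tensor, and basic slicing, which returns a view into existing storage. The toy example below uses arbitrary shapes to illustrate why the intermediate copy is avoidable. A CUDA kernel gets the same effect by computing addresses from the index and length directly. This Python snippet only approximates that and is not the Flux kernel.

```python
import torch

k_cache = torch.zeros(4, 8, 2)   # (N slots, S timesteps, D), toy sizes
idxes = torch.tensor([1, 3])

gathered = k_cache[idxes]        # advanced indexing: allocates a new (2, 8, 2) tensor
row_prefix = k_cache[3, :5]      # basic slicing: a (5, 2) view of slot 3's first 5 steps

gathered.fill_(1.0)              # modifies only the copy
row_prefix.fill_(2.0)            # writes straight through to k_cache's storage

print(k_cache[1].sum().item(), k_cache[3].sum().item())  # 0.0 20.0
```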

We were also able to exploit one optimization that our problem permits but the “general” torch kernel cannot. Our inference workflow generally does not need to support arbitrary shapes of q, only the shapes that appear in our specific inference workflow (e.g., if q.size() is (B, T, D), a typical autoregressive decoding step might always have T = 1). This enabled us to deploy GPU resources to maximal effect for the required shapes. (We also have a simplicity advantage: by designing kernels only for inference, we do not need to address all the possible conditions associated with training and autograd.)

So, compared to the above, the API for our version of the SDPA kernel would look something like

x = fused_sdpa_indexing_kernel(q, k_cache, v_cache, idxes, lengths)
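The kernel itself is not shown here. Its intended semantics can still be pinned down with a slow pure-torch reference, which is also handy as a correctness oracle when testing a custom kernel. This sketch assumes a single attention head and 3-D (batch, time, dim) tensors. fused_sdpa_indexing_reference is a hypothetical name, and because it loops over streams in Python, it says nothing about the fused kernel's performance.

```python
import torch
import torch.nn.functional as F


def fused_sdpa_indexing_reference(q, k_cache, v_cache, idxes, lengths):
    """Hypothetical slow reference for what fused_sdpa_indexing_kernel computes.

    For each active stream, read only cache[idx, :length] (a view, no gather copy)
    and attend over exactly those keys, so no mask is needed and no position
    >= length is ever scored.
    """
    outputs = []
    for i in range(idxes.numel()):
        slot, length = int(idxes[i]), int(lengths[i])
        k = k_cache[slot, :length]  # (length, D) view into the cache
        v = v_cache[slot, :length]
        outputs.append(F.scaled_dot_product_attention(q[i : i + 1], k.unsqueeze(0), v.unsqueeze(0)))
    return torch.cat(outputs, dim=0)  # (b, T_q, D)


torch.manual_seed(0)
k_cache, v_cache = torch.randn(8, 16, 4), torch.randn(8, 16, 4)
idxes, lengths = torch.tensor([1, 5, 6]), torch.tensor([3, 16, 7])
q = torch.randn(3, 1, 4)  # T_q = 1, as in a typical decoding step
out = fused_sdpa_indexing_reference(q, k_cache, v_cache, idxes, lengths)
print(out.shape)  # torch.Size([3, 1, 4])
```

On a GPU, the Python loop would presumably become parallel work over streams (and heads), each iterating only over its first `length` keys. That mapping is our assumption about how such a kernel could be organized, not a description of Flux's implementation.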

The diagram below helps visualize the difference between our kernel and the workflow using native torch APIs.

Whereas the native torch approach (i) extracts cache rows as a first, separate step, and (ii) feeds entire rows into the attention calculation, with masking applied after the fact, our kernel feeds the required cache elements directly into the calculation, and only the ones that are required.

Results

Compared to the torch approach described above, we gain our most significant speedup from avoiding unnecessary indexing operations, which cuts compute time for attention-based layers by roughly two-thirds. But truncating the attention calculation based on lengths helps substantially as well. The plot below schematically represents the performance of our kernel relative to torch scaled_dot_product_attention as a function of length for a fixed cache size:

Since the torch kernel always operates on the entire key tensor, including irrelevant indices ≥ length, its latency stays roughly constant. By comparison, our kernel achieves better latency across a broad range of lengths. That said, hats off to everyone constantly working on improving and optimizing torch: in the limit where almost all of the cache is relevant (length ~ cache_size), scaled_dot_product_attention is marginally faster.

Overall, relative to the “native torch” approach, our custom kernel allowed us to improve concurrency by over 80% at the same latency. Worth squeezing into those adult pants!