arXiv:2604.15408v2 Announce Type: replace-cross
Abstract: Token pruning methods for Vision Transformers (ViTs) promise quadratic reductions in attention FLOPs by dropping uninformative patches. Yet standard variable-length attention APIs, including FlashAttention-2's varlen and PyTorch's NestedTensor SDPA, fail to translate these savings into proportional wall-clock gains at the short post-pruning sequence lengths typical of ViTs ($\leq$197 tokens). We identify a dispatch-overhead bottleneck: at these lengths, host-side kernel dispatch consumes $\sim$50\,$\mu$s regardless of workload, exceeding the actual GPU compute time at moderate-to-high pruning rates. We present a lightweight bidirectional Triton attention kernel whose dispatch floor is $\sim$24\,$\mu$s (roughly 2.17$\times$ lower than FlashAttention-2 varlen), allowing pruning savings to become visible in wall-clock time. Integrated into a complete pack-attend-unpack pipeline and evaluated on an NVIDIA RTX 4000 Ada Generation GPU, our system achieves 1.88$\times$ end-to-end throughput over padded PyTorch SDPA at standard 224$\times$224 inputs, scaling to 2.51$\times$ at 384$\times$384. Against FlashAttention-2 varlen, the strongest baseline, our kernel delivers 9–12% higher throughput at serving batch sizes (BS = 1–4) and 2.17$\times$ lower kernel latency at 80% token pruning. Numerical correctness is verified with a maximum absolute logit difference $<$0.004 and bit-exact top-1 predictions.
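The pack-attend-unpack flow named in the abstract can be illustrated with a small plain-Python sketch. It is a hypothetical illustration of the data layout and semantics only, not the authors' Triton kernel: it assumes single-head attention, a per-image boolean keep mask produced by some pruning step, and the cumulative-offset (`cu_seqlens`) layout that FlashAttention-2's varlen interface uses. The helper names `vit_tokens`, `pack`, `varlen_attention`, and `unpack` are invented for this example, and `vit_tokens` assumes 16×16 patches plus one class token, which is how a 224×224 input yields the 197-token ceiling quoted above.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def vit_tokens(image_size: int, patch: int = 16, cls_token: bool = True) -> int:
    """Token count for a square ViT input (defaults assume 16x16 patches + class token)."""
    return (image_size // patch) ** 2 + int(cls_token)


def pack(batch: Sequence[Sequence[Vector]],
         keep: Sequence[Sequence[bool]]) -> Tuple[List[Vector], List[int]]:
    """Concatenate each image's kept tokens into one flat sequence.

    Image i occupies packed[cu_seqlens[i]:cu_seqlens[i + 1]].
    """
    packed: List[Vector] = []
    cu_seqlens = [0]
    for tokens, mask in zip(batch, keep):
        packed.extend(tok for tok, kept in zip(tokens, mask) if kept)
        cu_seqlens.append(len(packed))
    return packed, cu_seqlens


def varlen_attention(q: List[Vector], k: List[Vector], v: List[Vector],
                     cu_seqlens: List[int]) -> List[Vector]:
    """Bidirectional (non-causal) softmax attention restricted to each segment."""
    out: List[Vector] = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        for i in range(start, end):
            scale = 1.0 / math.sqrt(len(q[i]))
            # Each query attends to every key of its own image and to no other image.
            scores = [scale * sum(a * b for a, b in zip(q[i], k[j]))
                      for j in range(start, end)]
            peak = max(scores)  # subtract the max for a numerically stable softmax
            weights = [math.exp(s - peak) for s in scores]
            total = sum(weights)
            dim = len(v[start])
            out.append([sum(w * v[start + j][d] for j, w in enumerate(weights)) / total
                        for d in range(dim)])
    return out


def unpack(packed_out: List[Vector], keep: Sequence[Sequence[bool]],
           dim: int) -> List[List[Vector]]:
    """Scatter packed outputs back to the padded layout; pruned slots become zeros."""
    restored: List[List[Vector]] = []
    cursor = 0
    for mask in keep:
        rows: List[Vector] = []
        for kept in mask:
            if kept:
                rows.append(packed_out[cursor])
                cursor += 1
            else:
                rows.append([0.0] * dim)
        restored.append(rows)
    return restored


# Two toy "images" of three 2-d tokens each; image 0 has its middle token pruned.
batch = [
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    [[2.0, 0.0], [0.0, 2.0], [1.0, -1.0]],
]
keep = [[True, False, True], [True, True, True]]

packed, cu_seqlens = pack(batch, keep)
attended = varlen_attention(packed, packed, packed, cu_seqlens)
restored = unpack(attended, keep, dim=2)

print(vit_tokens(224), vit_tokens(384))  # 197 577
print(cu_seqlens)  # [0, 2, 5]
```

The per-segment loop in `varlen_attention` stands in for the work a fused GPU kernel would perform in a single launch. The abstract's argument is that, at these short lengths, the fixed host-side cost of that launch rather than the attention arithmetic can set the wall-clock floor, which is why shrinking each segment through pruning does not automatically shorten end-to-end latency.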
