SpanAttention
SpanAttention implements high-performance decode-phase attention over a paged KV cache for LLM inference on CUDA-enabled devices.
SpanAttention supports general grouped-query attention (GQA), which also covers multi-head attention (MHA) and multi-query attention (MQA) as special cases. For input and output, SpanAttention supports the commonly used data types FP32 (float), FP16 (half), and BF16 (bfloat16). For the KV cache, SpanAttention supports optional 8-bit and 4-bit integer quantization.
To achieve high performance, SpanAttention supports batched execution of multiple queries and uses Tensor Cores to further accelerate the GQA computation.
Usage
For API documentation, refer to the interface header file (spanattn/span_attn.h).
Basic example
The following example shows a basic use of SpanAttention with a single kernel launch.
The supported input and output data types include span::DataType::FP32, span::DataType::FP16, and span::DataType::BF16.
For supported KV cache quantization modes, see KV cache quantization.
```cpp
#include <cstdlib>
#include <stdexcept>
#include <cuda_runtime.h>
#include <spanattn/span_attn.h>

// check SpanAttention status code, throw on failure
#define SA_CHECK(expr)                                             \
  do {                                                             \
    span::SaStatus sa_status = (expr);                             \
    if (sa_status != span::SaStatus::SUCCESS) {                    \
      throw std::runtime_error(span::GetErrorString(sa_status));   \
    }                                                              \
  } while (0)

// check CUDA runtime status code, throw on failure
#define CUDA_CHECK(expr)                                           \
  do {                                                             \
    cudaError_t cuda_status = (expr);                              \
    if (cuda_status != cudaSuccess) {                              \
      throw std::runtime_error(cudaGetErrorString(cuda_status));   \
    }                                                              \
  } while (0)

int main() {
  // suppose input/output data type is BF16
  // suppose KV cache is not quantized
  span::DataType dtype = span::DataType::BF16;
  span::QuantMode qmode = span::QuantMode::NONE;

  // first, create a handle
  // ("..." stands for the remaining arguments, see the interface header)
  span::SpanAttnHandle_t handle{nullptr};
  SA_CHECK(span::CreateHandle(&handle, dtype, qmode, ...));
  if (handle == nullptr) {
    throw std::runtime_error("Failed to create SpanAttn handle");
  }

  // next, allocate workspace for the SpanAttention kernel
  void* device_workspace{nullptr};
  void* host_workspace{nullptr};
  size_t device_ws_size{0};
  size_t host_ws_size{0};
  SA_CHECK(span::GetDeviceWorkspaceSize(&device_ws_size, handle));
  SA_CHECK(span::GetHostWorkspaceSize(&host_ws_size, handle));
  CUDA_CHECK(cudaMalloc(&device_workspace, device_ws_size));
  host_workspace = std::malloc(host_ws_size);
  if (host_workspace == nullptr && host_ws_size != 0) {
    throw std::runtime_error("Failed to allocate host workspace");
  }

  // then, run the kernel
  SA_CHECK(span::Run(...));

  // finally, destroy the handle and free workspace
  SA_CHECK(span::DestroyHandle(handle));
  CUDA_CHECK(cudaFree(device_workspace));
  std::free(host_workspace);
  return 0;
}
```
Asynchronous kernel launch and performance considerations
In real-world LLM inference, the SpanAttention kernel is expected to be launched asynchronously, like other CUDA kernels. In that setting, allocating workspace synchronously before every launch (as the basic example does) should be avoided.
We therefore recommend allocating the workspace once during initialization, using a handle created with the maximum possible parameter values for your use case. Then, in the execution loop, reuse this workspace and launch kernels fully asynchronously.
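The allocate-once, reuse-many pattern can be sketched in plain C++ as follows. This is an illustration under stated assumptions, not part of the SpanAttention API: ReusableWorkspace is a hypothetical helper, std::malloc/std::free stand in for cudaMalloc/cudaFree so the snippet runs without a GPU, and the byte counts stand in for the sizes that span::GetDeviceWorkspaceSize would report for the maximum-sized handle and for each per-step handle.

```cpp
#include <cstddef>
#include <cstdlib>
#include <new>
#include <stdexcept>

// Hypothetical helper (not a SpanAttention API): owns one buffer that is
// allocated once at initialization and reused by every kernel launch.
// std::malloc/std::free stand in for cudaMalloc/cudaFree in this sketch.
class ReusableWorkspace {
 public:
  // Allocate once, sized for the largest configuration, e.g. the size
  // reported for a handle created with maximum parameter values.
  explicit ReusableWorkspace(std::size_t max_bytes)
      : ptr_(max_bytes != 0 ? std::malloc(max_bytes) : nullptr),
        capacity_(max_bytes) {
    if (max_bytes != 0 && ptr_ == nullptr) throw std::bad_alloc();
  }
  ~ReusableWorkspace() { std::free(ptr_); }
  ReusableWorkspace(const ReusableWorkspace&) = delete;
  ReusableWorkspace& operator=(const ReusableWorkspace&) = delete;

  // Called in the execution loop: never allocates, only checks that the
  // per-launch requirement fits into the preallocated buffer.
  void* Acquire(std::size_t required_bytes) const {
    if (required_bytes > capacity_) {
      throw std::length_error("workspace too small; re-initialize with larger maximums");
    }
    return ptr_;
  }

  std::size_t capacity() const { return capacity_; }

 private:
  void* ptr_;
  std::size_t capacity_;
};
```

Failing loudly when a step needs more than the preallocated capacity keeps allocation out of the hot loop entirely; the remedy is to raise the maximum values used at initialization.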
KV cache quantization
SpanAttention supports KV cache quantization.
For now, the 8-bit integer (span::QuantMode::I8) and 4-bit unsigned integer (span::QuantMode::U4) quantization modes are supported.
With span::QuantMode::NONE, the KV cache is not quantized and has the same data type as input and output.
KV cache spans
SpanAttention assumes a paged KV cache, which is useful for LLM inference with long contexts. A KV cache page is called a span in SpanAttention. SpanAttention does not currently provide APIs for span management, so users must manage spans themselves.
This involves two tasks: allocating spans whose size is calculated as described below, and passing the spans as arguments to the SpanAttention API.
Span size
Let spanLen be the number of tokens in a span.
For a query of length queryLen (i.e., the number of tokens in the cache), the number of spans for this query is ceil(queryLen / spanLen).
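Because the last span of a query may be only partially filled, the span count rounds up. A minimal C++ sketch of this ceiling division in integer arithmetic follows; NumSpans is a hypothetical helper name, not a SpanAttention API, and spanLen = 64 in the usage note is only an illustrative value.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical helper (not a SpanAttention API): number of spans needed to
// hold queryLen cached tokens when each span holds spanLen tokens,
// i.e. ceil(queryLen / spanLen) computed with integer arithmetic.
inline std::int64_t NumSpans(std::int64_t queryLen, std::int64_t spanLen) {
  if (spanLen <= 0 || queryLen < 0) {
    throw std::invalid_argument("spanLen must be positive and queryLen non-negative");
  }
  return (queryLen + spanLen - 1) / spanLen;
}
```

For example, with spanLen = 64, a query with 1000 cached tokens needs 16 spans, and a query with 1025 tokens needs 17.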
Let FT be the type of input and output (can be float, half, or bfloat16).
The size of a span in bytes is computed as follows:
Not quantized: each span takes up spanLen * #kvHeads * headSize * sizeof(FT) bytes.
Quantized: let QT be the type of the quantized KV cache (8-bit or 4-bit integer); each span takes up spanLen * #kvHeads * headSize * sizeof_bits(QT) / 8 + 2 * spanLen * #kvHeads * sizeof(float) bytes.
Here, #kvHeads is the number of KV heads in GQA, and headSize is the size of each head.
For now, headSize is fixed at 128, which is the de facto default for most LLMs.
Spans should be allocated in device memory.
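The span size formulas above can be evaluated with a small C++ sketch. The function names are hypothetical (not SpanAttention APIs), ftBytes is the caller-supplied sizeof(FT), and qtBits is the caller-supplied bit width of QT (8 for I8, 4 for U4); the sketch only restates the formulas and makes no claim about how SpanAttention lays out bytes inside a span.

```cpp
#include <cstddef>

// Hypothetical helpers (not SpanAttention APIs) that evaluate the span size
// formulas. headSize is currently 128 in SpanAttention.

// Not quantized: spanLen * kvHeads * headSize * sizeof(FT) bytes,
// where ftBytes = sizeof(FT): 4 for float, 2 for half or bfloat16.
constexpr std::size_t SpanBytesUnquantized(std::size_t spanLen, std::size_t kvHeads,
                                           std::size_t headSize, std::size_t ftBytes) {
  return spanLen * kvHeads * headSize * ftBytes;
}

// Quantized: packed QT data plus two floats per token per KV head,
// spanLen * kvHeads * headSize * qtBits / 8 + 2 * spanLen * kvHeads * sizeof(float)
// bytes, where qtBits = 8 for I8 and 4 for U4.
constexpr std::size_t SpanBytesQuantized(std::size_t spanLen, std::size_t kvHeads,
                                         std::size_t headSize, std::size_t qtBits) {
  return spanLen * kvHeads * headSize * qtBits / 8 +
         2 * spanLen * kvHeads * sizeof(float);
}
```

For example, with spanLen = 64 (illustrative), 8 KV heads, and headSize = 128, an unquantized BF16 span takes 131072 bytes, an I8 span 69632 bytes, and a U4 span 36864 bytes. The K cache and the V cache each have their own spans.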
Span pointer array
Given a batch of N queries, the spans of the K cache and of the V cache are each passed to the APIs as a 2D array of pointers, where the first dimension is the batch index and the second dimension is the span index within each query.
If the longest query has maxQueryLen tokens, the second dimension of the array should be at least ceil(maxQueryLen / spanLen); this value is passed as nSpansPerRequest to span::CreateHandle.
Thus, the kSpanArray and vSpanArray arguments passed to span::Run are both expected to have shape [N, nSpansPerRequest].
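One way to build such a [N, nSpansPerRequest] pointer array is sketched below, under assumptions that this section does not confirm: the array is stored contiguously in row-major order, and unused trailing slots of shorter queries are left as nullptr. BuildSpanArray is a hypothetical helper, not a SpanAttention API, and whether span::Run expects this table in host or device memory should be checked in the interface header.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not a SpanAttention API): flattens per-request span
// lists into a row-major [N, nSpansPerRequest] table of span pointers.
// Entry (b, s) holds the s-th span of request b; unused trailing slots
// stay nullptr (an assumption of this sketch).
std::vector<void*> BuildSpanArray(const std::vector<std::vector<void*>>& spansPerRequest,
                                  std::size_t nSpansPerRequest) {
  std::vector<void*> table(spansPerRequest.size() * nSpansPerRequest, nullptr);
  for (std::size_t b = 0; b < spansPerRequest.size(); ++b) {
    const std::vector<void*>& spans = spansPerRequest[b];
    if (spans.size() > nSpansPerRequest) {
      throw std::invalid_argument("request uses more spans than nSpansPerRequest");
    }
    for (std::size_t s = 0; s < spans.size(); ++s) {
      table[b * nSpansPerRequest + s] = spans[s];
    }
  }
  return table;
}
```

In practice the pointers would be device addresses of spans, and the same construction is done separately for kSpanArray and vSpanArray.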
Installation and integration
Requirements
CMake 3.21 or above
C++ compiler with C++17 support
CUDA 12.0 or above
Ninja is recommended
CUTLASS 3.5 or above if using external CUTLASS
Build and install
First configure the project:
```shell
cd span-attention
cmake -S . -B build [-G <generator>] [options]
```
Tested generators are:
Unix Makefiles: the default generator
Ninja: recommended
Options include:
-DSPANATTN_CUDA_ARCHS="arch1;arch2;...": target CUDA architecture compute capabilities (SM versions); default is "75;80;90a".
-DSPANATTN_ENABLE_TEST=ON|OFF: build tests; default is ON.
-DSPANATTN_ENABLE_FP16=ON|OFF: enable FP16 (half) support; default is ON.
-DSPANATTN_ENABLE_BF16=ON|OFF: enable BF16 (bfloat16) support; default is ON.
-DSPANATTN_STATIC_CUDART=ON|OFF: statically link the CUDA runtime; default is OFF.
-DSPANATTN_EXTERNAL_CUTLASS=ON|OFF: use an external CUTLASS; default is OFF.
-DCMAKE_BUILD_TYPE=type: build type; default is Release (use Debug for debugging).
-DCMAKE_INSTALL_PREFIX=directory: installation directory.
Then build and install:
```shell
cmake --build build [--target install]
```
When using Ninja, the build runs in parallel by default.
Otherwise, you can specify the number of parallel jobs with the -j option.
For detailed output, use the -v option.
Integrate with CMake
Instead of installing, you can integrate SpanAttention with your project using CMake ExternalProject.
An example is provided below.
This example assumes that CUTLASS is already integrated into your project as the target project_cutlass and installed at the path CUTLASS_INSTALL.
```cmake
# suppose your project specifies CUDA archs with CMAKE_CUDA_ARCHITECTURES
set(SPANATTN_CUDA_ARCHS ${CMAKE_CUDA_ARCHITECTURES})
set(SPANATTN_EXTERNAL_CUTLASS ON)
set(SPANATTN_ENABLE_TEST OFF)

# set your own paths
set(SPANATTN_SOURCE_DIR /path/to/source/codes/of/span-attention)
set(SPANATTN_INSTALL /path/to/install/span-attention/in/your/project)

include(ExternalProject)
ExternalProject_Add(project_spanattn
  PREFIX ${CMAKE_CURRENT_BINARY_DIR}/span-attention
  SOURCE_DIR ${SPANATTN_SOURCE_DIR}
  DEPENDS project_cutlass
  CMAKE_GENERATOR "Ninja"
  BUILD_COMMAND ${CMAKE_COMMAND} --build . -v
  # declare the library file so Ninja-based parent builds know which step produces it
  BUILD_BYPRODUCTS ${SPANATTN_INSTALL}/lib/libspanattn.a
  CMAKE_CACHE_ARGS
    -DSPANATTN_CUDA_ARCHS:STRING=${SPANATTN_CUDA_ARCHS}
  CMAKE_ARGS
    -DSPANATTN_EXTERNAL_CUTLASS=${SPANATTN_EXTERNAL_CUTLASS}
    -DSPANATTN_ENABLE_TEST=${SPANATTN_ENABLE_TEST}
    -DCMAKE_INSTALL_PREFIX=${SPANATTN_INSTALL}
    -DCUTLASS_INSTALL_PATH=${CUTLASS_INSTALL}
)

unset(SPANATTN_CUDA_ARCHS)
unset(SPANATTN_EXTERNAL_CUTLASS)
unset(SPANATTN_ENABLE_TEST)
```
To use SpanAttention as a static library in your project, import it as spanattn::spanattn_static:
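```cmake
# the include directory must exist at configure time,
# because it is listed in INTERFACE_INCLUDE_DIRECTORIES below
file(MAKE_DIRECTORY ${SPANATTN_INSTALL}/include)

add_library(spanattn::spanattn_static STATIC IMPORTED)
add_dependencies(spanattn::spanattn_static project_spanattn)
set_target_properties(spanattn::spanattn_static
  PROPERTIES
    IMPORTED_LOCATION ${SPANATTN_INSTALL}/lib/libspanattn.a
    INTERFACE_INCLUDE_DIRECTORIES ${SPANATTN_INSTALL}/include
)

# then link it to your own target, e.g.:
# target_link_libraries(your_target PRIVATE spanattn::spanattn_static)
```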
License
The SpanAttention source code is licensed under the Apache 2.0 license, and you can find the full text of the license in the root of the repository.