Optimization for CPU Causal Flash Attention (integrated into Qwen3) (#3254)
add flash attn to qwen3
feature flash-attn
flash generative true
no mask in cpu flash
add causal loop-bound optimization to cpu_flash_attention
attempt at hybrid
working tiling
working but poor performance tiled flash
factorized attention cpu flash
logging
causal cpu flash
deprecation warning
deprecation warning specific
formatting
interleaved will live in cpu_flash
clippy resolve
AttnMask: remove unnecessary lifetime
fix: flash-attn KV cache masking during decode
single batch CPU flash + integrate CPU varlen
fail on B>1 in my cpu flash
benchmark qwen3
clean up
varlen force
dispatch logic smolLM3
interleaved cpu
interleaved smollm3
memory leak
back to exact exp
simplify
cargo fmt
clippy
errors; fmt
cpu flash comment
remove bench script
merge main + resolve; default CPU flash (note that quantized varlen CPU flash is not implemented) and remove --use-flash-attn CLI (note there is no path to standard attention for qwen3 and quantized qwen3)
debug cfg lazylock only
remove unused cli
programmatic commentary address
proposed comments for internal use
Co-authored-by: michaelfeil <me@michaelfeil.eu>
Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
candle
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support) and ease of use. Try our online demos: whisper, LLaMA2, T5, yolo, Segment Anything.
Get started
Make sure that you have `candle-core` correctly installed as described in Installation.

Let's see how to run a simple matrix multiplication. Write the following to your `myapp/src/main.rs` file; `cargo run` should then display a tensor of shape `Tensor[[2, 4], f32]`.
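A minimal version of that program could look as follows (a sketch, assuming the crate is imported as `candle_core`; multiplying random matrices of shapes (2, 3) and (3, 4) yields the (2, 4) result mentioned above):

```rust
use candle_core::{Device, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;

    // Two random matrices: (2, 3) x (3, 4) -> (2, 4).
    let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
    let b = Tensor::randn(0f32, 1., (3, 4), &device)?;

    let c = a.matmul(&b)?;
    println!("{c}");
    Ok(())
}
```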
Having installed `candle` with CUDA support, simply define the `device` to be on the GPU:
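A sketch, assuming a build with the cuda feature enabled and a GPU at ordinal 0:

```rust
let device = Device::new_cuda(0)?;
```

For more advanced examples, please have a look at the following section.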
Check out our examples
These online demos run entirely in your browser:
We also provide some command line based examples using state of the art models:
Run them using commands like:
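For example (a sketch; `llama` is one of the examples shipped in the repository, and some examples need extra feature flags):

```bash
cargo run --example llama --release
```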
In order to use CUDA, add `--features cuda` to the example command line. If you have cuDNN installed, use `--features cudnn` for even more speedups.

There are also some wasm examples for whisper and llama2.c. You can either build them with `trunk` or try them online: whisper, llama2, T5, Phi-1.5 and Phi-2, Segment Anything Model.

For LLaMA2, run the following command to retrieve the weight files and start a test server:
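A sketch of those steps (the URLs assume the weights are still hosted on the lmz/candle-llama2 space):

```bash
cd candle-wasm-examples/llama2-c
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
trunk serve --release --port 8081
```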
And then head over to http://localhost:8081/.
Useful External Resources
- candle-tutorial: A very detailed tutorial showing how to convert a PyTorch model to Candle.
- candle-lora: Efficient and ergonomic LoRA implementation for Candle. `candle-lora` has out-of-the-box LoRA support for many models from Candle, which can be found here.
- candle-video: Rust library for text-to-video generation (LTX-Video and related models) built on Candle, focused on fast, Python-free inference.
- optimisers: A collection of optimisers including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- candle-vllm: Efficient platform for inference and serving local LLMs, including an OpenAI compatible API server.
- candle-ext: An extension library to Candle that provides PyTorch functions not currently available in Candle.
- candle-coursera-ml: Implementation of ML algorithms from Coursera's Machine Learning Specialization course.
- kalosm: A multi-modal meta-framework in Rust for interfacing with local pre-trained models, with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- candle-sampling: Sampling techniques for Candle.
- gpt-from-scratch-rs: A port of Andrej Karpathy's "Let's build GPT" YouTube tutorial, showcasing the Candle API on a toy problem.
- candle-einops: A pure Rust implementation of the Python einops library.
- atoma-infer: A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI API compatible.
- llms-from-scratch-rs: A comprehensive Rust translation of the code from Sebastian Raschka's Build an LLM from Scratch book.
- vllm.rs: A minimalist vLLM implementation in Rust based on Candle.

If you have an addition to this list, please submit a pull request.
Features
How to use
Cheatsheet:
| Using PyTorch | Using Candle |
|---|---|
| `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?` |
| `torch.zeros((2, 2))` | `Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?` |
| `tensor[:, :4]` | `tensor.i((.., ..4))?` |
| `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
| `a.matmul(b)` | `a.matmul(&b)?` |
| `a + b` | `&a + &b` |
| `tensor.to(device="cuda")` | `tensor.to_device(&Device::new_cuda(0)?)?` |
| `tensor.to(dtype=torch.float16)` | `tensor.to_dtype(&DType::F16)?` |
| `torch.save({"A": A}, "model.bin")` | `candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?` |
| `weights = torch.load("model.bin")` | `candle::safetensors::load("model.safetensors", &device)` |

Structure

Tensor struct definition

FAQ
Why should I use Candle?
Candle’s core goal is to make serverless inference possible. Full machine learning frameworks like PyTorch are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight binaries.
Secondly, Candle lets you remove Python from production workloads. Python overhead can seriously hurt performance, and the GIL is a notorious source of headaches.
Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like safetensors and tokenizers.
Other ML frameworks
dfdx is a formidable crate, with shapes included in the types. This prevents a lot of headaches by getting the compiler to complain about shape mismatches right off the bat. However, we found that some features still require nightly, and writing code can be a bit daunting for non-Rust experts.
We’re leveraging and contributing to other core crates for the runtime so hopefully both crates can benefit from each other.
burn is a general crate that can leverage multiple backends so you can choose the best engine for your workload.
tch-rs: Bindings to the torch library in Rust. Extremely versatile, but they bring the entire torch library into the runtime. The main contributor of `tch-rs` is also involved in the development of `candle`.

Common Errors
Missing symbols when compiling with the mkl feature.
If you get some missing symbols when compiling binaries/tests using the mkl or accelerate features, e.g. for mkl an undefined reference to a BLAS symbol such as `sgemm_`, this is likely due to a missing linker flag that is needed to enable the mkl library. You can try adding the following at the top of your binary, for mkl:
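```rust
// Force the MKL runtime to be linked (assumes intel-mkl-src is declared as a dependency).
extern crate intel_mkl_src;
```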
or for accelerate:
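```rust
// Likewise for Accelerate (assumes accelerate-src is declared as a dependency).
extern crate accelerate_src;
```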
Cannot run the LLaMA examples: access to source requires login credentials
This is likely because you’re not permissioned for the LLaMA-v2 model. To fix this, you have to register on the huggingface-hub, accept the LLaMA-v2 model conditions, and set up your authentication token. See issue #350 for more details.
Docker build
When building CUDA kernels inside a Dockerfile, `nvidia-smi` cannot be used to auto-detect the compute capability. You must explicitly set `CUDA_COMPUTE_CAP`, for example:
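A sketch of the relevant Dockerfile line (8.6 here is a placeholder; use your GPU's actual compute capability):

```dockerfile
# Target Ampere consumer GPUs (compute capability 8.6); adjust for your hardware.
ENV CUDA_COMPUTE_CAP=86
```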
Compiling with flash-attention fails
This is a bug in gcc-11 triggered by the CUDA compiler. To fix this, install a different, supported gcc version, for example gcc-10, and specify the path to the compiler in the `NVCC_CCBIN` environment variable.
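For example (the gcc-10 path below is typical for Debian/Ubuntu; adjust for your distribution):

```bash
env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo build --features flash-attn
```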
Linking error on windows when running rustdoc or mdbook tests
Make sure you link all native libraries that might be located outside a project target, e.g., to run mdbook tests, you should run:
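A sketch of such an invocation on Windows (the native library paths are placeholders; they depend on your cargo registry layout and crate versions):

```powershell
mdbook test candle-book -L .\target\debug\deps\ `
  -L native=$env:USERPROFILE\.cargo\registry\src\<registry>\windows_x86_64_msvc-<version>\lib
```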
Extremely slow model load time with WSL
This may be caused by the models being loaded from `/mnt/c`; more details on stackoverflow.

Tracking down errors
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle error is generated.

CudaRC error
If you encounter an error like this one on Windows:

`called Result::unwrap() on an Err value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }`

copy and rename these three files, making sure they are in your path (the exact paths depend on your CUDA version):

- `c:\Windows\System32\nvcuda.dll` -> `cuda.dll`
- `c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll` -> `cublas.dll`
- `c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll` -> `curand.dll`