Tiny CUDA Neural Networks
This is a small, self-contained framework for training and querying neural networks. Most notably, it contains a lightning fast “fully fused” multi-layer perceptron (technical paper), a versatile multiresolution hash encoding (technical paper), as well as support for various other input encodings, losses, and optimizers.
Performance
Fully fused networks vs. TensorFlow v2.5.0 w/ XLA, measured on multi-layer perceptrons that are 64 (solid line) and 128 (dashed line) neurons wide on an RTX 3090. Generated by benchmarks/bench_ours.cu and benchmarks/bench_tensorflow.py using data/config_oneblob.json.
Usage
Tiny CUDA neural networks have a simple C++/CUDA API:
#include <tiny-cuda-nn/common.h>
// Configure the model
nlohmann::json config = {
{"loss", {
{"otype", "L2"}
}},
{"optimizer", {
{"otype", "Adam"},
{"learning_rate", 1e-3},
}},
{"encoding", {
{"otype", "HashGrid"},
{"n_levels", 16},
{"n_features_per_level", 2},
{"log2_hashmap_size", 19},
{"base_resolution", 16},
{"per_level_scale", 2.0},
}},
{"network", {
{"otype", "FullyFusedMLP"},
{"activation", "ReLU"},
{"output_activation", "None"},
{"n_neurons", 64},
{"n_hidden_layers", 2},
}},
};
using namespace tcnn;
auto model = create_from_config(n_input_dims, n_output_dims, config);
model.trainer->set_jit_fusion(supports_jit_fusion()); // Optional: accelerate with JIT fusion
// Train the model (batch_size must be a multiple of tcnn::BATCH_SIZE_GRANULARITY)
GPUMatrix<float> training_batch_inputs(n_input_dims, batch_size);
GPUMatrix<float> training_batch_targets(n_output_dims, batch_size);
for (int i = 0; i < n_training_steps; ++i) {
generate_training_batch(&training_batch_inputs, &training_batch_targets); // <-- your code
float loss;
model.trainer->training_step(training_batch_inputs, training_batch_targets, &loss);
std::cout << "iteration=" << i << " loss=" << loss << std::endl;
}
// Use the model
GPUMatrix<float> inference_inputs(n_input_dims, batch_size);
generate_inputs(&inference_inputs); // <-- your code
GPUMatrix<float> inference_outputs(n_output_dims, batch_size);
model.network->inference(inference_inputs, inference_outputs);
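As noted in the training comment above, batch sizes must be a multiple of tcnn::BATCH_SIZE_GRANULARITY. The following minimal sketch rounds a hypothetical desired element count up to a valid batch size:
// Hypothetical desired element count that may not respect the granularity requirement.
uint32_t desired_size = 50000;
// Round up to the next multiple of tcnn::BATCH_SIZE_GRANULARITY to obtain a valid batch size.
uint32_t batch_size = (desired_size + tcnn::BATCH_SIZE_GRANULARITY - 1) / tcnn::BATCH_SIZE_GRANULARITY * tcnn::BATCH_SIZE_GRANULARITY;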
JIT fusion
JIT fusion is a new, optional feature of tiny-cuda-nn v2.0 and later.
Enabling automatic JIT fusion is almost always recommended; it yields a performance boost of 1.5x to 2.5x, depending on the model and GPU.
Newer GPUs exhibit larger speedups.
If your model has very large hash grids (roughly 20 million or more parameters) or wide MLPs (layers larger than 128 neurons), or if your GPU is an RTX 3000 series or earlier, JIT fusion can slow down training and, more rarely, inference.
In this case, it is recommended to enable JIT fusion separately for training and inference and to measure which configuration is faster.
Please open an issue if you encounter a slowdown in a different situation or other problems with JIT fusion enabled.
Automatic JIT fusion
To enable JIT fusion, set the jit_fusion property of your model to true.
All future uses of the model, whether inference or training, will then use JIT mode.
Note that if there is an error during JIT compilation, a warning will be emitted and JIT compilation mode automatically turned off.
Your code will still run using the tiny-cuda-nn 1.X code path.
auto model = tcnn::create_from_config(...);
model.trainer->set_jit_fusion(tcnn::supports_jit_fusion()); // Enable JIT if the system supports it
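To follow the advice above and measure whether JIT fusion helps your workload, you can time a few training steps with it toggled on and off. Below is a minimal sketch, assuming the model, training data, and set_jit_fusion call from the examples above, plus standard CUDA/C++ timing; the step count of 100 is arbitrary:
#include <chrono>
#include <cuda_runtime.h>
// Times 100 training steps with JIT fusion enabled or disabled.
auto time_training = [&](bool jit) {
	model.trainer->set_jit_fusion(jit && tcnn::supports_jit_fusion());
	cudaDeviceSynchronize();
	auto begin = std::chrono::steady_clock::now();
	for (int i = 0; i < 100; ++i) {
		float loss;
		model.trainer->training_step(training_batch_inputs, training_batch_targets, &loss);
	}
	cudaDeviceSynchronize(); // Training runs asynchronously on the GPU; wait before stopping the clock.
	return std::chrono::duration<double>(std::chrono::steady_clock::now() - begin).count();
};
double seconds_with_jit = time_training(true);
double seconds_without_jit = time_training(false);
// Keep whichever setting is faster; an analogous loop around model.network->inference measures inference.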
JIT fusion can also be enabled via the PyTorch bindings but the speed-up will be lower, particularly during training.
This is because the JIT compiler does not have access to the whole compute graph and can therefore fuse and optimize less.
import tinycudann as tcnn
model = tcnn.NetworkWithInputEncoding(...) # Or any other tcnn model
model.jit_fusion = tcnn.supports_jit_fusion() # Enable JIT if the system supports it
Manual JIT fusion
Even larger speed-ups are possible when applications integrate more tightly with JIT fusion.
For example, Instant NGP achieves a 5x speedup by fusing the entire NeRF ray marcher into a single kernel.
JIT fusion works by converting a given tiny-cuda-nn model to a CUDA device function and then compiling it into a kernel using CUDA’s runtime compilation (RTC) feature.
To integrate a tiny-cuda-nn model with a larger kernel in your app, you need to:
1. turn your kernel into a string,
2. prepend the tiny-cuda-nn model’s device function, and
3. pass the result to tiny-cuda-nn’s runtime compilation API.
Here is an example that implements a minimal kernel using a tiny-cuda-nn model with 32 input dimensions and 16 output dimensions:
#include <tiny-cuda-nn/rtc_kernel.h>
auto model = tcnn::create_from_config(32 /* input dims */, 16 /* output dims */, ...);
auto fused_kernel = tcnn::CudaRtcKernel(
	"your_kernel",
	fmt::format(R"(
		{MODEL_DEVICE_FUNCTION}
		__global__ void your_kernel(...) {{
			// Get the input to the model from either registers or memory.
			tcnn::hvec<32> input = ...;
			// Call the tiny-cuda-nn model. All 32 threads of the warp must be active here.
			// `params` refers to the model's parameters, e.g. passed as an argument of your_kernel.
			tcnn::hvec<16> output = model_fun(input, params);
			// Do something with the model output.
		}}
	)", fmt::arg("MODEL_DEVICE_FUNCTION", model.network->generate_device_function("model_fun")))
);
uint32_t blocks = 1;
uint32_t threads = 128; // Must be multiple of 32 for neural networks to work.
uint32_t shmem_size = 0; // Can be any size that your_kernel needs.
cudaStream_t stream = nullptr; // Can be any stream.
fused_kernel.launch(blocks, threads, shmem_size, stream, ... /* params of your_kernel */);
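When the kernel processes a full batch of inputs (one per thread), the number of blocks is typically derived from the element count rather than hard-coded. A brief sketch with a hypothetical n_elements:
uint32_t n_elements = 1u << 20;                          // Hypothetical number of inputs to process.
uint32_t threads = 128;                                  // Still a multiple of 32.
uint32_t blocks = (n_elements + threads - 1) / threads;  // Round up so every input is covered.
fused_kernel.launch(blocks, threads, shmem_size, stream, ... /* params of your_kernel */);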
Instant NGP’s NeRF integration with the JIT compiler serves as a further reference.
Example: learning a 2D image
We provide a sample application in which an image function (x,y) -> (R,G,B) is learned. It can be run via the bundled mlp_learning_an_image sample, producing an image every couple of training steps. Each 1000 steps should take a bit over 1 second with the default configuration on an RTX 4090.
Requirements
The fully fused MLP component of this framework requires a very large amount of shared memory in its default configuration. It will likely only work on an RTX 3090, an RTX 2080 Ti, or higher-end GPUs. Lower-end cards must reduce the n_neurons parameter or use the CutlassMLP (better compatibility but slower) instead.
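For such lower-end cards, only the network portion of the JSON configuration from the usage example above needs to change. A sketch reusing the same keys (values are illustrative):
// Alternative "network" block for lower-end GPUs: CutlassMLP trades speed for compatibility.
nlohmann::json network_config = {
	{"otype", "CutlassMLP"},
	{"activation", "ReLU"},
	{"output_activation", "None"},
	{"n_neurons", 64},
	{"n_hidden_layers", 2},
};
// Alternatively, keep "FullyFusedMLP" and lower "n_neurons" to reduce its shared memory usage.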
If you are using Linux, install the following packages
sudo apt-get install build-essential git
We also recommend installing CUDA in /usr/local/ and adding the CUDA installation to your PATH.
For example, if you have CUDA 12.6.3, add the corresponding export lines to your ~/.bashrc.
Compilation (Windows & Linux)
Begin by cloning this repository and all its submodules. Then, use CMake to build the project (on Windows, this must be done in a developer command prompt).
If compilation fails inexplicably or takes longer than an hour, you might be running out of memory. Try running the build without the -j flag in that case.
PyTorch extension
tiny-cuda-nn comes with a PyTorch extension that allows using the fast MLPs and input encodings from within a Python context.
These bindings can be significantly faster than full Python implementations, in particular for the multiresolution hash encoding.
The overheads of Python/PyTorch can nonetheless be extensive if the batch size is small.
For example, with a batch size of 64k, the bundled mlp_learning_an_image example is ~2x slower through PyTorch than native CUDA.
With a batch size of 256k and higher (default), the performance is much closer.
Begin by setting up a Python 3.X environment with a recent, CUDA-enabled version of PyTorch. The extension can then be installed directly via pip or, alternatively, from a local clone of tiny-cuda-nn:
tiny-cuda-nn$ cd bindings/torch
tiny-cuda-nn/bindings/torch$ python setup.py install
By default, the extension automatically enables half precision (FP16) on GPUs with good support (Volta, Turing, Ampere, etc.) and disables it on older architectures or those with slow FP16 (e.g., Pascal/GTX 10-series).
If you wish to override this behavior (e.g., to force FP16 on unsupported hardware or disable it for debugging), set the TCNN_HALF_PRECISION environment variable before installation: 0 disables FP16 and 1 enables it.
Upon success, you can use tiny-cuda-nn models from within Python; see samples/mlp_learning_an_image_pytorch.py for an example.
Components
Following is a summary of the components of this framework. The JSON documentation lists configuration options.
Networks: src/fully_fused_mlp.cu, src/cutlass_mlp.cu. Encodings (include/tiny-cuda-nn/encodings/): composite, frequency, grid, identity, oneblob, spherical_harmonics, triangle_wave. Losses (include/tiny-cuda-nn/losses/): l1, l2, relative_l2, relative_l2_luminance, mape, smape, cross_entropy, variance_is. Optimizers (include/tiny-cuda-nn/optimizers/): adam, sgd, shampoo, average, batched, composite, exponential_decay, lookahead.
Relative L2 Luminance
include/tiny-cuda-nn/losses/relative_l2_luminance.h
Relative L2 loss, normalized by the luminance of the network prediction. Only applicable when the network prediction is RGB. Used in Neural Radiance Caching [Müller et al. 2021].
Cross Entropy
include/tiny-cuda-nn/losses/cross_entropy.h
Standard cross entropy loss. Only applicable when the network prediction is a PDF.
Variance
include/tiny-cuda-nn/losses/variance_is.h
Standard variance loss. Only applicable when the network prediction is a PDF.
Average
include/tiny-cuda-nn/optimizers/average.h
Wraps another optimizer and computes a linear average of the weights over the last N iterations. The average is used for inference only (it does not feed back into training).
Batched
include/tiny-cuda-nn/optimizers/batched.h
Wraps another optimizer, invoking the nested optimizer once every N steps on the averaged gradient. This has the same effect as increasing the batch size but requires only a constant amount of memory.
Composite
include/tiny-cuda-nn/optimizers/composite.h
Allows using several optimizers on different parameters.
EMA
include/tiny-cuda-nn/optimizers/average.h
Wraps another optimizer and computes an exponential moving average of the weights. The average is used for inference only (it does not feed back into training).
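The wrapper optimizers above (Average, Batched, Composite, EMA) nest another optimizer's configuration inside their own. Below is a hypothetical sketch in the style of the usage example's JSON config; the key names "nested" and "decay" are assumptions to be checked against the JSON documentation:
// Hypothetical optimizer config: EMA of the weights, wrapping an Adam optimizer that does the training.
nlohmann::json optimizer_config = {
	{"otype", "EMA"},
	{"decay", 0.99}, // Assumed name of the averaging-decay parameter.
	{"nested", {     // Assumed key for the wrapped optimizer.
		{"otype", "Adam"},
		{"learning_rate", 1e-3},
	}},
};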
License and Citation
This framework is licensed under the BSD 3-clause license. Please see LICENSE.txt for details. If you use it in your research, we would appreciate a citation.
For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing
Publications & Software
Among others, this framework powers a number of publications and software projects. Please feel free to make a pull request if your publication or software is not listed.
Acknowledgments
Special thanks go to the NRC authors for helpful discussions and to Nikolaus Binder for providing part of the infrastructure of this framework, as well as for help with utilizing TensorCores from within CUDA.