TL; DR
Let's delve into the Instant Neural Graphics Primitive with a Multi-Resolution Hash Encoding, and re-implement this with PyTorch!
1. Introduction
Neural Radiance Fields (NeRF) are a powerful method for 3D scene reconstruction, but they come with significant drawbacks, primarily in terms of slow training and rendering speeds. To address these issues, various studies have explored voxel-based approaches. While these methods can reduce computation time, they often suffer from limited speed improvements or performance trade-offs.
Instant Neural Graphics Primitives (Instant-NGP) offers a breakthrough by utilizing multi-resolution decomposition and hashing, achieving state-of-the-art performance with remarkable speed.
In this article, I review the key ideas behind Instant-NGP and provide a PyTorch implementation of its core components.
- project: link
2. Background
Positional Encoding
For high-fidelity scene reconstruction, NeRF typically uses sinusoidal positional encoding: $$ \gamma(p) = \big (\sin(2^0 \pi p), \cos(2^0 \pi p), \dots, \sin(2^{L-1} \pi p), \cos(2^{L-1} \pi p) \big) $$
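As a concrete illustration of the formula above, the sketch below evaluates $\gamma(p)$ for a single scalar $p$ in plain Python. The function name `positional_encoding` and the choice of four frequencies are hypothetical, picked only for this example.

```python
import math

def positional_encoding(p, num_freqs=4):
    """Return [sin(2^0 pi p), cos(2^0 pi p), ..., sin(2^(L-1) pi p), cos(2^(L-1) pi p)]."""
    out = []
    for k in range(num_freqs):
        angle = (2.0 ** k) * math.pi * p   # frequency doubles at every step
        out.append(math.sin(angle))
        out.append(math.cos(angle))
    return out

enc = positional_encoding(0.25, num_freqs=4)
print(len(enc))  # 2 * num_freqs values for one scalar coordinate
```

For a 3D point, the same function would be applied to each coordinate and the results concatenated.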
There are alternative encodings, such as Integrated Positional Encoding (IPE) in Mip-NeRF, but the fundamental principle remains the same: the information is encoded according to different frequencies.
However, NeRF must evaluate an MLP (typically 8 layers with 256 or 512 hidden units) for every sampled point during rendering. This is one of the primary reasons for NeRF's slow speed.
Voxel-based Methods
One of the main approaches to address these drawbacks is to reduce the computational burden of inference and training by pre-computing and storing data at a few key locations.
This involves:
- Learning a parametric encoding for the vertices of a 3D voxel grid by introducing learnable parameters, rather than using a fixed positional encoding.
- Using linear interpolation to approximate points between vertices, thereby improving speed (as shown in Plenoxels (CVPR 2022)).
However, voxel-based methods have the disadvantage of requiring significantly more memory compared to NeRF, and they often involve complex training processes, including various regularization techniques.
3. Method
Overview
Instant-NGP uses a similar approach to existing voxel-based methods by mapping parametric encodings to the vertices of a voxel. However, it introduces several key differences:
- Multi-level Decomposition
The scene is divided into multiple levels, with each level storing information that focuses on different parts of the scene geometry.
- Hash Function
As the resolution of voxels increases, the number of points that need to be stored grows cubically. Instead of storing all points on a one-to-one basis, a hash function is used to reduce the memory required.
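To make the cubic growth concrete, the following back-of-the-envelope calculation compares a dense vertex grid against a fixed-size table. The resolutions, the table size $2^{19}$, and the feature dimension 2 are illustrative values assumed for this example, not settings taken from this article.

```python
def dense_entries(resolution):
    # A grid with `resolution` cells per axis has (resolution + 1) vertices per axis.
    return (resolution + 1) ** 3

feat_dim = 2            # assumed feature dimension F
table_size = 2 ** 19    # assumed hash table size T

for res in (16, 128, 512):
    dense = dense_entries(res)
    stored = min(dense, table_size)   # hashing caps the table at T entries
    print(res, dense * feat_dim, stored * feat_dim)
```

At low resolution the dense grid is smaller than the table, so nothing is gained by hashing; at high resolution the dense grid is orders of magnitude larger than the capped table.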
The following figure visualizes the forward process of Multi-Resolution Hash Encoding:
- Each voxel's vertex at different resolutions (red and blue) is stored in a table with a learnable feature vector of dimension $F$. The table-to-vertex mapping is defined through a hash function over the vertex coordinates.
- For any point in space, its encoding is determined by a linear interpolation between the features of all corner vertices of the hypercube to which the point belongs.
- This interpolated value is combined with the view-direction encoding and used as input to the decoding network $m(\mathbf{y}; \phi)$.
Instant-NGP maximizes the capabilities of parametric encoding and multi-level decomposition, allowing for an extremely shallow decoding network—typically a 2-layer network with 64 hidden dimensions. This leads to much faster point-wise inference and convergence compared to other NeRF models while still achieving SOTA performance.
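A minimal sketch of such a shallow decoder is shown below. It is a simplification under stated assumptions: the class name `TinyDecoder`, the single-network layout, the 4-channel output (RGB plus density), and the dimensions used in the usage lines are all hypothetical, and "2-layer" is interpreted here as two hidden layers of width 64.

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Hypothetical shallow decoder: two hidden layers of width 64."""

    def __init__(self, enc_dim, dir_dim, hidden_dim=64, out_dim=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(enc_dim + dir_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),  # assumed output: RGB + density
        )

    def forward(self, encoding, dir_encoding):
        # Concatenate the hash encoding with the view-direction encoding.
        return self.net(torch.cat([encoding, dir_encoding], dim=-1))

# Assumed sizes: L = 16 levels, F = 2 features, 16-dim direction encoding.
decoder = TinyDecoder(enc_dim=16 * 2, dir_dim=16)
out = decoder(torch.rand(1024, 32), torch.rand(1024, 16))
print(out.shape)  # torch.Size([1024, 4])
```

Because most of the scene information lives in the learnable feature tables, the network itself can stay this small.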
The next step involves volume rendering using ray casting, similar to other NeRF-like models.
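For reference, the standard NeRF-style compositing along a single ray can be sketched in plain Python as follows. The function name `composite_ray` and the sample values are hypothetical; a real implementation would batch this over many rays with tensors.

```python
import math

def composite_ray(sigmas, colors, deltas):
    """Alpha-composite samples along one ray, front to back."""
    transmittance = 1.0
    rgb = [0.0, 0.0, 0.0]
    for sigma, color, delta in zip(sigmas, colors, deltas):
        alpha = 1.0 - math.exp(-sigma * delta)   # opacity of this segment
        weight = transmittance * alpha           # contribution of this sample
        rgb = [acc + weight * c for acc, c in zip(rgb, color)]
        transmittance *= 1.0 - alpha             # light remaining after this segment
    return rgb, transmittance

rgb, t = composite_ray(
    sigmas=[0.0, 5.0, 50.0],
    colors=[(1, 0, 0), (0, 1, 0), (0, 0, 1)],
    deltas=[0.1, 0.1, 0.1],
)
print(rgb, t)
```

The first sample has zero density and therefore contributes nothing, while the dense third sample absorbs most of the remaining light.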
3.1. Multi-Level Decomposition
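For a total of $L$ levels, the resolution $N_{l}$ of the voxel grid at level $l$ is chosen between $N_{\text{min}}$ and $N_{\text{max}}$ by geometric growth, as defined in the Instant-NGP paper: $$ N_l = \lfloor N_{\text{min}} \cdot b^{l} \rfloor, \qquad b = \exp\left(\frac{\ln N_{\text{max}} - \ln N_{\text{min}}}{L - 1}\right) $$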
To optimize memory usage, rather than declaring a feature table that directly corresponds to each voxel resolution $N_l$, a fixed-size feature table of size $T$ is declared. If the grid size is smaller than $T$, a feature table matching the voxel size is declared to maintain a one-to-one correspondence.
In the following PyTorch custom implementation, the per-level scale $b$ is calculated using the formula above, and the feature tables are initialized accordingly based on whether the voxel size is smaller than $T$ or not.
self.one2one = []
self.units = []
hash_size = 2 ** self.log2_hashmap_size  # T in Eqn
for i in range(self.n_levels):
    # grid resolution N_l at level i
    self.units.append(int(np.round(self.N_min * (self.per_level_scale ** i))))
    grid_size = (self.units[i] + 1) ** 3  # number of vertices at this level
    # use a dense one-to-one table when the grid fits within the hash table
    self.one2one.append(grid_size < hash_size)
    table_size = grid_size if self.one2one[i] else hash_size
    torch_hash = nn.Embedding(int(table_size), self.feat_dim)  # self.feat_dim : F in Eqn
    nn.init.uniform_(torch_hash.weight, -self.init_std, self.init_std)
    setattr(self, f'torch_hash_{i}', torch_hash)
- The self.one2one array indicates which levels have a one-to-one correspondence. self.units stores the voxel size per level.
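As a worked example, the sketch below lists the per-level resolutions and marks which levels would be stored densely. It assumes the geometric growth rule from the Instant-NGP paper, $b = \exp\big((\ln N_{\text{max}} - \ln N_{\text{min}}) / (L - 1)\big)$; the function name `level_resolutions` and the values $N_{\text{min}} = 16$, $N_{\text{max}} = 512$, $L = 16$, $T = 2^{19}$ are assumptions made for illustration.

```python
import math

def level_resolutions(n_min, n_max, n_levels):
    """Per-level grid resolutions growing geometrically from n_min to n_max."""
    b = math.exp((math.log(n_max) - math.log(n_min)) / (n_levels - 1))
    # Rounding mirrors the np.round call in the snippet above.
    return b, [int(round(n_min * b ** l)) for l in range(n_levels)]

b, units = level_resolutions(n_min=16, n_max=512, n_levels=16)
table_size = 2 ** 19  # assumed T

for l, n in enumerate(units):
    one2one = (n + 1) ** 3 < table_size   # same test as self.one2one
    print(l, n, "dense" if one2one else "hashed")
```

The coarse levels fit inside the table and are stored one-to-one; only the finer levels need hashing.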
3.2. Hash Grids Encoding
For encoding a point $\mathbf{x} \in \mathbb{R}^{d}$ at each level $l$, the point is first scaled by the level's resolution so that every grid cell becomes a hypercube of side length 1: $$ \mathbf{x}_l = \mathbf{x} \cdot N_l $$
This places the point within a hypercube defined by its diagonal vertices $\lfloor \mathbf{x}_{l} \rfloor$ and $\lceil \mathbf{x}_{l} \rceil$.
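Subsequently, each vertex of this hypercube is mapped to an entry in the feature table using a spatial hash function: $$ h(\mathbf{x}) = \left( \bigoplus_{i=1}^{d} x_i \pi_i \right) \bmod T $$
where $\oplus$ denotes bitwise XOR, $T$ is the table size, and $\pi_i$ are unique large prime numbers, with $\pi_1 = 1$ (e.g., $[1, 2\,654\,435\,761, 805\,459\,861]$).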
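The hash can be tried out on a single integer vertex with plain Python integers, as in the sketch below. The function name `spatial_hash` and the table size $2^{19}$ are assumptions for this example; the constants are the ones quoted above.

```python
PRIMES = (1, 2654435761, 805459861)

def spatial_hash(vertex, log2_table_size):
    """XOR the coordinate-prime products, then reduce modulo the table size."""
    h = 0
    for coord, prime in zip(vertex, PRIMES):
        h ^= coord * prime
    return h % (2 ** log2_table_size)

idx = spatial_hash((3, 7, 11), log2_table_size=19)
print(idx)
```

Since the first constant is 1, vertices that differ only in their first coordinate map to nearby table entries, while changes in the other coordinates scatter the index across the table.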
After the feature mapping for all $2^d$ vertices is completed, the relative positions within the hypercube are used to interpolate each vertex feature, resulting in the final encoding for level $l$.
Hash Grids & Tri-linear Interpolation
Assume the forward process of Instant-NGP receives $N$ points as input. For a typical NeRF dataset, these points are 3D, so the input shape will be $[N,\ 3]$.
Our goal is to compute:
- The $2^d$ level-wise corner vertex coordinates of the points $\mathbf{x}$ (i.e., $L \times 2^d$ vertices in total per point).
- The level-wise trilinear interpolation weights for these points.
Let's implement this step by step!
- First, for a given level $l$, distribute the points $\mathbf{x}$ over voxels with grid size $N_l$ and calculate corner vertices by adding offsets ($[0,0,0] \sim [1,1,1]$) to $\lfloor \mathbf{x}_{l} \rfloor$.
corners = []
N_level = self.units[l]  # N_min to N_max resolution
for i in range(2 ** x.size(-1)):  # for 2^3 corners
    x_l = torch.floor(N_level * x)
    offsets = [int(x) for x in list('{0:03b}'.format(i))]
    for c in range(x.size(-1)):
        x_l[..., c] = x_l[..., c] + offsets[c]  # 3-dim (x,y,z)
    corners.append(x_l)
corners = torch.stack(corners, dim=-2)
- Next, compute trilinear weights using the relative position differences between corners and $\mathbf{x}_l$.
# get trilinear weights
x_ = x.unsqueeze(1) * N_level
weights = (1 - torch.abs(x_ - corners)).prod(dim=-1, keepdim=True) + self.eps
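To see what these weights look like for one point, here is a small pure-Python sketch of the same computation. The function name `trilinear_weights` and the sample coordinates are hypothetical, and the small `eps` term from the snippet above is left out for clarity.

```python
import math
from itertools import product

def trilinear_weights(x_scaled):
    """Return the 8 corner vertices and weights for one point in grid coordinates."""
    base = [math.floor(v) for v in x_scaled]         # lower corner of the cell
    frac = [v - b for v, b in zip(x_scaled, base)]   # position inside the cell
    corners, weights = [], []
    for offset in product((0, 1), repeat=3):         # (0,0,0) ... (1,1,1)
        corners.append(tuple(b + o for b, o in zip(base, offset)))
        w = 1.0
        for f, o in zip(frac, offset):
            w *= f if o == 1 else 1.0 - f            # equals 1 - |x - corner| per axis
        weights.append(w)
    return corners, weights

corners, weights = trilinear_weights((2.25, 5.5, 0.75))
print(sum(weights))
```

The eight weights are non-negative and sum to one, so interpolating any function that is linear in the coordinates reproduces it exactly at the query point.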
These steps can be wrapped in the following function.
def hash_grids(self, x):
    # input: x [N, 3]
    # output:
    # level_wise_corners: [L, N, 8, 3]
    # level_wise_weights: [N, 8, L, 1]
    corners_all = []
    weights_all = []
    for l in range(self.n_levels):
        # get level-wise grid corners
        corners = []
        weights = []
        N_level = self.units[l]  # N_min to N_max resolution
        for i in range(2 ** x.size(-1)):  # 2^3 corners
            x_l = torch.floor(N_level * x)
            offsets = [int(x) for x in list('{0:03b}'.format(i))]
            for c in range(x.size(-1)):
                x_l[..., c] = x_l[..., c] + offsets[c]  # 3-dim (x,y,z)
            corners.append(x_l)
        corners = torch.stack(corners, dim=-2)  # [N, 8, 3]
        # get trilinear weights
        x_ = x.unsqueeze(1) * N_level  # [N, 1, 3]
        weights = (1 - torch.abs(x_ - corners)).prod(dim=-1, keepdim=True) + self.eps  # [N, 8, 1]
        corners_all.append(corners)
        weights_all.append(weights)
    corners_all = torch.stack(corners_all, dim=0)  # [L, N, 8, 3]
    weights_all = torch.stack(weights_all, dim=-2)  # [N, 8, L, 1]
    weights_all = weights_all / weights_all.sum(dim=-3, keepdim=True)
    return corners_all, weights_all
Hash Table Mapping
The method for table mapping varies depending on whether there is a one-to-one correspondence. Using self.one2one declared in 3.1, we handle the two cases:
- For one-to-one correspondence, the index is directly derived from the coordinates.
for l in range(self.n_levels):
    ids = []
    c_ = c[l].view(c[l].size(0) * c[l].size(1), c[l].size(2))
    c_ = c_.int()
    if self.one2one[l]:  # grid_size << hash_size
        ids = c_[:, 0] + (self.units[l] * c_[:, 1]) + ((self.units[l] ** 2) * c_[:, 2])
        ids %= (self.units[l] ** 3)
- Otherwise, the index is calculated using the hash function defined in 3.2.
    # cf. self.primes = [1, 2654435761, 805459861]
    else:
        ids = (c_[:, 0] * self.primes[0]) ^ (c_[:, 1] * self.primes[1]) ^ (c_[:, 2] * self.primes[2])
        ids %= (2 ** self.log2_hashmap_size)
The entire mapping process can also be wrapped into a single function.
def table_mapping(self, c):
    # input: 8 corners [L, N, 8, 3]
    # output: hash index [L, N * 8]
    ids_all = []
    with torch.no_grad():
        for l in range(self.n_levels):
            ids = []
            c_ = c[l].view(c[l].size(0) * c[l].size(1), c[l].size(2))
            c_ = c_.int()
            if self.one2one[l]:  # grid_size << hash_size
                ids = c_[:, 0] + (self.units[l] * c_[:, 1]) + ((self.units[l] ** 2) * c_[:, 2])
                ids %= (self.units[l] ** 3)
            else:
                ids = (c_[:, 0] * self.primes[0]) ^ (c_[:, 1] * self.primes[1]) ^ (c_[:, 2] * self.primes[2])
                ids %= (2 ** self.log2_hashmap_size)
            ids_all.append(ids)
    return ids_all  # [L * [N*8]]
3.3. Multi-Resolution Hash Encoding
We index the feature table to get the feature values for each level declared as nn.Embedding, perform trilinear interpolation, and then concatenate them by level to obtain the final encoding.
def hash_enc(self, corners, weights):
    # input: corners [L, N, 8, 3]
    #        weights [N, 8, L, 1]
    # output: interpolated embeddings [N, L*F]
    level_embedd_all = []
    ids_all = self.table_mapping(corners)  # [L * [N*8]]
    for l in range(self.n_levels):
        level_embedd = []
        hash_table = getattr(self, f'torch_hash_{l}')
        hash_table.to(corners.device)
        level_embedd = hash_table(ids_all[l])  # [N*8] -> [N*8, F]
        level_embedd = level_embedd.view(corners.size(1), corners.size(2), self.feat_dim)  # [N, 8, F]
        level_embedd_all.append(level_embedd)
    # Trilinear Interpolation
    # weights: [N, 8, L, 1]
    level_embedd_all = torch.stack(level_embedd_all, dim=-2)  # [N, 8, L, F]
    level_embedd_all = torch.sum(weights * level_embedd_all, dim=-3)  # [N, L, F]
    return level_embedd_all.reshape(weights.size(0), self.n_levels * self.feat_dim)
For input $\mathbf{x}$ of shape $[N,\ 3]$, we obtain the multi-resolution hash encoding result.
corners_all, weights_all = self.hash_grids(x)
encodings = self.hash_enc(corners_all, weights_all)
Closing
The implementation above shows that, as long as the decoder's input dimension matches the size of the encoding ($L \times F$), this code can be paired with the decoding network of any NeRF-like model.
This flexibility allows us to combine other NeRF models with Multi-Resolution Hash Encoding using this code easily.
However, the implementation may not be as fast as the original Instant-NGP due to several reasons:
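- The PyTorch implementation loops over levels and corners in Python and relies on generic tensor operations, whereas the original runs as dedicated CUDA/C++ kernels, so it incurs additional execution time.
- Instant-NGP utilizes the tiny-cuda-nn (tcnn) library for the decoding network, further optimizing inference speed.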
- There are additional implementation details, such as pruning hypercubes without opaque particles to improve inference efficiency.