TL;DR
Let's delve into Instant Neural Graphics Primitives with a Multi-Resolution Hash Encoding, and re-implement the encoding in PyTorch!
 
1. Introduction
Neural Radiance Fields (NeRF) are a powerful method for 3D scene reconstruction, but they come with significant drawbacks, primarily in terms of slow training and rendering speeds. To address these issues, various studies have explored voxel-based approaches. While these methods can reduce computation time, they often suffer from limited speed improvements or performance trade-offs.
Instant Neural Graphics Primitives (Instant-NGP) offers a breakthrough by utilizing multi-resolution decomposition and hashing, achieving state-of-the-art performance with remarkable speed.
In this article, I review the core of Instant-NGP and provide a PyTorch implementation of its core components.
- project: link
2. Background
Positional Encoding
For high-fidelity scene reconstruction, NeRF typically uses sinusoidal positional encoding: $$ \gamma(p) = \big (\sin(2^0 \pi p), \cos(2^0 \pi p), \dots, \sin(2^{L-1} \pi p), \cos(2^{L-1} \pi p) \big) $$
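As a minimal sketch of the formula above, the encoding can be written in a few lines of PyTorch. The function name `positional_encoding` and the default `num_freqs=10` are assumptions chosen for illustration, not taken from any particular NeRF codebase.

```python
import math
import torch

def positional_encoding(p: torch.Tensor, num_freqs: int = 10) -> torch.Tensor:
    """Map p [N, D] to [N, 2 * num_freqs * D] using sin/cos at frequencies 2^k * pi."""
    freqs = (2.0 ** torch.arange(num_freqs, dtype=p.dtype, device=p.device)) * math.pi  # [L]
    angles = p.unsqueeze(-1) * freqs                                    # [N, D, L]
    enc = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)   # [N, D, L, 2]
    return enc.flatten(start_dim=1)                                     # [N, D * L * 2]

points = torch.rand(4, 3)
encoded = positional_encoding(points, num_freqs=10)
print(encoded.shape)  # torch.Size([4, 60])
```

Note that the encoding itself has no learnable parameters, so all of the scene information has to live in the MLP weights.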
There are alternative encodings, such as Integrated Positional Encoding (IPE) in Mip-NeRF, but the fundamental principle remains the same: the information is encoded according to different frequencies.
However, NeRF must evaluate an MLP (typically 8 layers with 256 or 512 hidden units) for every sampled point during rendering. This is one of the primary reasons for NeRF's slow speed.
Voxel-based Methods
One of the main approaches to address these drawbacks is to reduce the computational burden of inference and training by pre-computing and storing data at a few key locations.
This involves:
- Learning a parametric encoding for the vertices of a 3D voxel grid by introducing learnable parameters, rather than using a fixed positional encoding.
- Using linear interpolation to approximate points between vertices, thereby improving speed (as shown in Plenoxels (CVPR 2022)).
 
However, voxel-based methods have the disadvantage of requiring significantly more memory compared to NeRF, and they often involve complex training processes, including various regularization techniques.
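To make the idea concrete, here is a rough sketch of a dense voxel grid with learnable vertex features and trilinear interpolation. It is not how Plenoxels is implemented; the class name `DenseVoxelGrid` and its sizes are hypothetical, and `torch.nn.functional.grid_sample` is used only as a convenient stand-in for the interpolation step.

```python
import torch
import torch.nn.functional as F

class DenseVoxelGrid(torch.nn.Module):
    """Learnable features on every vertex of a dense res^3 grid (illustrative only)."""

    def __init__(self, resolution: int = 32, feat_dim: int = 2):
        super().__init__()
        # One feature vector per vertex: memory grows cubically with resolution.
        self.grid = torch.nn.Parameter(
            torch.zeros(1, feat_dim, resolution, resolution, resolution))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [N, 3] in [0, 1]^3; grid_sample expects coordinates in [-1, 1],
        # with (x, y, z) indexing the (W, H, D) axes of the grid.
        coords = (2.0 * x - 1.0).view(1, -1, 1, 1, 3)
        # For 5-D inputs, mode='bilinear' performs trilinear interpolation.
        out = F.grid_sample(self.grid, coords, mode='bilinear', align_corners=True)
        return out.reshape(self.grid.size(1), -1).t()  # [N, feat_dim]

features = DenseVoxelGrid(resolution=32, feat_dim=2)(torch.rand(4, 3))
print(features.shape)

# Parameter count of a dense grid with 2 features per vertex
for res in (32, 128, 512):
    print(f"resolution {res}: {2 * res ** 3:,} parameters")
```

The parameter counts printed at the end illustrate the memory problem: each 4x increase in resolution multiplies the number of stored features by 64.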
3. Method
Overview
Instant-NGP uses a similar approach to existing voxel-based methods by mapping parametric encodings to the vertices of a voxel. However, it introduces several key differences:
- Multi-level Decomposition
 The scene is represented at multiple resolution levels, with each level storing features at a different scale, from coarse structure to fine detail.
- Hash Function
 As the voxel resolution increases, the number of vertices that need to be stored grows cubically. Instead of storing every vertex on a one-to-one basis, a hash function maps the vertices into a fixed-size table to bound the memory required.
The following figure visualizes the forward process of Multi-Resolution Hash Encoding:
 
- Each voxel's vertex at different resolutions (red and blue) is stored in a table with a learnable feature vector of dimension $F$. The table-to-vertex mapping is defined through a hash function over the vertex coordinates.
- For any point in space, its encoding is determined by a linear interpolation between the features of all corner vertices of the hypercube to which the point belongs.
- This interpolated value is combined with the view-direction encoding and used as input to the decoding network $m(\mathbf{y}; \phi)$.
Instant-NGP maximizes the capabilities of parametric encoding and multi-level decomposition, allowing for an extremely shallow decoding network—typically a 2-layer network with 64 hidden dimensions. This leads to much faster point-wise inference and convergence compared to other NeRF models while still achieving SOTA performance.
The next step involves volume rendering using ray casting, similar to other NeRF-like models.
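For readers unfamiliar with that step, the following is a minimal sketch of the standard NeRF-style compositing along each ray. It is not specific to Instant-NGP; the function name `composite` and the tensor layout (`R` rays, `S` samples per ray) are assumptions made for this illustration.

```python
import torch

def composite(sigmas: torch.Tensor, colors: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
    """sigmas [R, S], colors [R, S, 3], deltas [R, S] -> rgb [R, 3]."""
    alphas = 1.0 - torch.exp(-sigmas * deltas)            # opacity of each sample
    # Transmittance: fraction of light that reaches sample i without being absorbed.
    trans = torch.cumprod(1.0 - alphas + 1e-10, dim=-1)
    trans = torch.cat([torch.ones_like(trans[:, :1]), trans[:, :-1]], dim=-1)
    weights = alphas * trans                              # [R, S]
    return (weights.unsqueeze(-1) * colors).sum(dim=-2)   # [R, 3]

sigmas = torch.rand(2, 64)          # densities predicted by the decoder
colors = torch.rand(2, 64, 3)       # colors predicted by the decoder
deltas = torch.full((2, 64), 0.05)  # distances between adjacent samples
print(composite(sigmas, colors, deltas).shape)
```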
3.1. Multi-Level Decomposition
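For a total of $L$ levels, the resolution $N_{l}$ of the voxel grid at level $l$ is chosen between $N_{\text{min}}$ and $N_{\text{max}}$ as a geometric progression, as defined in the Instant-NGP paper: $$ N_{l} = \lfloor N_{\text{min}} \cdot b^{l} \rfloor, \quad b = \exp \left( \frac{\ln N_{\text{max}} - \ln N_{\text{min}}}{L - 1} \right) $$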
To optimize memory usage, rather than declaring a feature table that directly corresponds to each voxel resolution $N_l$, a fixed-size feature table of size $T$ is declared. If the grid size is smaller than $T$, a feature table matching the voxel size is declared to maintain a one-to-one correspondence.
In the following PyTorch custom implementation, the per-level scale $b$ is calculated using the formula above, and the feature tables are initialized accordingly based on whether the voxel size is smaller than $T$ or not.
import numpy as np
import torch.nn as nn

self.one2one = []  # per level: True if the grid fits in the table without hashing
self.units = []    # per level: voxel resolution N_l
hash_size = 2 ** self.log2_hashmap_size  # T in Eqn
for i in range(self.n_levels):
    # Compute the resolution first, because grid_size below depends on it.
    self.units.append(int(np.round(self.N_min * (self.per_level_scale ** i))))
    grid_size = (self.units[i] + 1) ** 3  # number of vertices at this level
    self.one2one.append(grid_size < hash_size)
    table_size = grid_size if self.one2one[i] else hash_size
    torch_hash = nn.Embedding(int(table_size), self.feat_dim)  # self.feat_dim: F in Eqn
    nn.init.uniform_(torch_hash.weight, -self.init_std, self.init_std)
    setattr(self, f'torch_hash_{i}', torch_hash)

- The self.one2one array indicates which levels have a one-to-one correspondence.
- self.units stores the voxel size per level.
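To see how these two arrays turn out in practice, the short script below prints the resolution and table size for every level. The hyperparameters ($N_{\text{min}}=16$, $N_{\text{max}}=512$, $L=16$, $T=2^{19}$) are assumed values for illustration, and the growth factor is computed as $b = \exp((\ln N_{\text{max}} - \ln N_{\text{min}})/(L-1))$ as in the Instant-NGP paper.

```python
import math

# Assumed hyperparameters, for illustration only
N_min, N_max, n_levels, log2_T = 16, 512, 16, 19

b = math.exp((math.log(N_max) - math.log(N_min)) / (n_levels - 1))  # per-level scale
T = 2 ** log2_T

def level_layout(level: int):
    """Return (resolution, one2one flag, table size) for one level."""
    res = int(round(N_min * b ** level))  # rounded, as in the code above
    dense = (res + 1) ** 3                # number of grid vertices
    one2one = dense < T
    return res, one2one, dense if one2one else T

for l in range(n_levels):
    res, one2one, size = level_layout(l)
    print(f"level {l:2d}: N_l={res:4d}  one2one={one2one}  table_size={size}")
```

With these settings the coarse levels are stored densely, while the finer levels are capped at $T$ entries and rely on hashing.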
3.2. Hash Grids Encoding
To encode a point $\mathbf{x} \in \mathbb{R}^{d}$ at level $l$, the point is first scaled by that level's resolution, so that each grid cell becomes a hypercube of size 1: $$ \mathbf{x}_{l} = \mathbf{x} \cdot N_{l} $$
This places the point within a hypercube defined by its diagonal vertices $\lfloor \mathbf{x}_{l} \rfloor$ and $\lceil \mathbf{x}_{l} \rceil$.
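Subsequently, each vertex of this hypercube is mapped to an entry of the feature table using a spatial hash function: $$ h(\mathbf{x}) = \left( \bigoplus_{i=1}^{d} x_i \pi_i \right) \bmod T $$
where $\oplus$ denotes bitwise XOR, $T$ is the table size, and $\pi_i$ are large prime numbers, with $\pi_1 = 1$ (e.g., $[1, 2\,654\,435\,761, 805\,459\,861]$).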
After the feature mapping for all $2^d$ vertices is completed, the relative positions within the hypercube are used to interpolate each vertex feature, resulting in the final encoding for level $l$.
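As a small illustration of the hashing step, the function below XORs the products of the integer vertex coordinates with the primes listed above and wraps the result into the table. The name `spatial_hash` is hypothetical, and plain Python integers are used so that overflow is not a concern.

```python
PRIMES = (1, 2654435761, 805459861)

def spatial_hash(vertex, table_size):
    """XOR the coordinate-prime products, then wrap into [0, table_size)."""
    h = 0
    for coord, prime in zip(vertex, PRIMES):
        h ^= coord * prime
    return h % table_size

T = 2 ** 19  # assumed table size
for v in [(0, 0, 0), (1, 0, 0), (0, 1, 0), (3, 7, 2)]:
    print(v, "->", spatial_hash(v, T))
```

Because the first prime is 1, vertices along the first axis map to consecutive table entries, while the other axes are scattered across the table.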
Hash Grids & Tri-linear Interpolation
Assume the forward process of Instant-NGP receives $N$ points as input. For a typical NeRF dataset, these points are 3D, so the input shape will be $[N,\ 3]$.
Our goal is to compute:
- The $2^d$ level-wise corner vertex coordinates of the points $\mathbf{x}$ (i.e., $L \times 2^d$ vertices per point in total).
- The level-wise trilinear interpolation weights for these points.
Let's implement this step by step!
- First, for a given level $l$, distribute the points $\mathbf{x}$ over voxels with grid size $N_l$ and calculate corner vertices by adding offsets ($[0,0,0] \sim [1,1,1]$) to $\lfloor \mathbf{x}_{l} \rfloor$.

    corners = []
    N_level = self.units[l]  # N_min to N_max resolution
    for i in range(2 ** x.size(-1)):  # for 2^3 corners
        x_l = torch.floor(N_level * x)
        offsets = [int(bit) for bit in '{0:03b}'.format(i)]
        for c in range(x.size(-1)):
            x_l[..., c] = x_l[..., c] + offsets[c]  # 3-dim (x,y,z)
        corners.append(x_l)
    corners = torch.stack(corners, dim=-2)
- Next, compute trilinear weights using the relative position differences between corners and $\mathbf{x}_l$.

    # get trilinear weights
    x_ = x.unsqueeze(1) * N_level
    weights = (1 - torch.abs(x_ - corners)).prod(dim=-1, keepdim=True) + self.eps
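The two steps above can be sanity-checked on a single point. The standalone sketch below (with a made-up point and a made-up linear test function `f`) confirms that the eight weights sum to 1 and that trilinear interpolation reproduces a linear function exactly.

```python
import torch

x_l = torch.tensor([[2.25, 5.5, 0.75]])  # a point already scaled by N_l
base = torch.floor(x_l)                   # lower corner, [1, 3]
offsets = torch.tensor(
    [[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)],
    dtype=x_l.dtype)                      # [8, 3], from [0,0,0] to [1,1,1]
corners = base.unsqueeze(1) + offsets     # [1, 8, 3]

# Weight of a corner = product over axes of (1 - distance to that corner)
weights = (1 - torch.abs(x_l.unsqueeze(1) - corners)).prod(dim=-1)  # [1, 8]

def f(p):
    """Arbitrary linear function standing in for a feature stored at the vertices."""
    return 2.0 * p[..., 0] - p[..., 1] + 0.5 * p[..., 2]

interpolated = (weights * f(corners)).sum(dim=-1)
print(weights.sum().item(), interpolated.item(), f(x_l).item())
```

Since the weights already sum to 1, the `eps` term and the later renormalization only guard against numerical edge cases.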
These processes can be wrapped in the following function.
def hash_grids(self, x):
    # input: x [N, 3]
    # output:
    #   level_wise_corners: [L, N, 8, 3]
    #   level_wise_weights: [N, 8, L, 1]
    corners_all = []
    weights_all = []
    for l in range(self.n_levels):
        # get level-wise grid corners
        corners = []
        N_level = self.units[l]  # N_min to N_max resolution
        for i in range(2 ** x.size(-1)):  # 2^3 corners
            x_l = torch.floor(N_level * x)
            offsets = [int(bit) for bit in '{0:03b}'.format(i)]
            for c in range(x.size(-1)):
                x_l[..., c] = x_l[..., c] + offsets[c]  # 3-dim (x,y,z)
            corners.append(x_l)
        corners = torch.stack(corners, dim=-2)  # [N, 8, 3]
        # get trilinear weights
        x_ = x.unsqueeze(1) * N_level  # [N, 1, 3]
        weights = (1 - torch.abs(x_ - corners)).prod(dim=-1, keepdim=True) + self.eps  # [N, 8, 1]
        corners_all.append(corners)
        weights_all.append(weights)
    corners_all = torch.stack(corners_all, dim=0)  # [L, N, 8, 3]
    weights_all = torch.stack(weights_all, dim=-2)  # [N, 8, L, 1]
    # normalize over the 8 corners so that the weights sum to 1
    weights_all = weights_all / weights_all.sum(dim=-3, keepdim=True)
    return corners_all, weights_all

Hash Table Mapping
The method for table mapping varies depending on whether there is a one-to-one correspondence. Using self.one2one declared in 3.1, we handle the two cases:
- For one-to-one correspondence, the index is directly derived from the coordinates.

    for l in range(self.n_levels):
        c_ = c[l].view(c[l].size(0) * c[l].size(1), c[l].size(2))  # [N*8, 3]
        c_ = c_.long()  # int64, so products with the large primes cannot overflow
        if self.one2one[l]:  # grid_size < hash_size
            stride = self.units[l] + 1  # vertices per axis, matching grid_size in 3.1
            ids = c_[:, 0] + (stride * c_[:, 1]) + ((stride ** 2) * c_[:, 2])
            ids %= (stride ** 3)
- Otherwise, the index is calculated using the hash function defined in 3.2.

        # cf. self.primes = [1, 2654435761, 805459861]
        else:
            ids = (c_[:, 0] * self.primes[0]) ^ (c_[:, 1] * self.primes[1]) ^ (c_[:, 2] * self.primes[2])
            ids %= (2 ** self.log2_hashmap_size)
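The difference between the two cases can be checked by brute force on a small grid. In the sketch below, the helper names are hypothetical, the dense index assumes a stride of `res + 1` so that all $(N_l+1)^3$ vertices get distinct slots, and the hash table is made deliberately tiny to force collisions.

```python
from itertools import product

PRIMES = (1, 2654435761, 805459861)

def dense_index(v, res):
    """Row-major index over the (res + 1)^3 vertices of a grid with res cells per axis."""
    stride = res + 1
    return v[0] + stride * v[1] + stride ** 2 * v[2]

def hashed_index(v, table_size):
    """Spatial hash: XOR of coordinate-prime products, modulo the table size."""
    return ((v[0] * PRIMES[0]) ^ (v[1] * PRIMES[1]) ^ (v[2] * PRIMES[2])) % table_size

res = 16
vertices = list(product(range(res + 1), repeat=3))         # all grid vertices
dense_ids = {dense_index(v, res) for v in vertices}
hashed_ids = {hashed_index(v, 2 ** 10) for v in vertices}  # deliberately tiny table

print(len(vertices), len(dense_ids), len(hashed_ids))
```

Dense indexing gives every vertex its own slot, whereas with hashing several vertices share an entry. Instant-NGP does not resolve these collisions explicitly and leaves it to gradient-based training to sort out which vertex dominates each entry.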
The entire mapping process can also be wrapped into a single function.
def table_mapping(self, c):
    # input: 8 corners [L, N, 8, 3]
    # output: list of L index tensors, each of shape [N * 8]
    ids_all = []
    with torch.no_grad():
        for l in range(self.n_levels):
            c_ = c[l].view(c[l].size(0) * c[l].size(1), c[l].size(2))  # [N*8, 3]
            c_ = c_.long()  # int64, so products with the large primes cannot overflow
            if self.one2one[l]:  # grid_size < hash_size
                stride = self.units[l] + 1  # vertices per axis, matching grid_size in 3.1
                ids = c_[:, 0] + (stride * c_[:, 1]) + ((stride ** 2) * c_[:, 2])
                ids %= (stride ** 3)
            else:
                ids = (c_[:, 0] * self.primes[0]) ^ (c_[:, 1] * self.primes[1]) ^ (c_[:, 2] * self.primes[2])
                ids %= (2 ** self.log2_hashmap_size)
            ids_all.append(ids)
    return ids_all  # [L * [N*8]]

3.3. Multi-Resolution Hash Encoding
For each level, we look up the feature vectors in that level's feature table (declared as nn.Embedding), perform trilinear interpolation, and then concatenate the results across levels to obtain the final encoding.
def hash_enc(self, corners, weights):
    # input:  corners [L, N, 8, 3]
    #         weights [N, 8, L, 1]
    # output: interpolated embeddings [N, L*F]
    level_embedd_all = []
    ids_all = self.table_mapping(corners)  # [L * [N*8]]
    for l in range(self.n_levels):
        hash_table = getattr(self, f'torch_hash_{l}')
        hash_table.to(corners.device)
        level_embedd = hash_table(ids_all[l])  # [N*8] -> [N*8, F]
        level_embedd = level_embedd.view(corners.size(1), corners.size(2), self.feat_dim)  # [N, 8, F]
        level_embedd_all.append(level_embedd)
    # Trilinear Interpolation
    # weights: [N, 8, L, 1]
    level_embedd_all = torch.stack(level_embedd_all, dim=-2)  # [N, 8, L, F]
    level_embedd_all = torch.sum(weights * level_embedd_all, dim=-3)  # [N, L, F]
    return level_embedd_all.reshape(weights.size(0), self.n_levels * self.feat_dim)

For input $\mathbf{x}$ of shape $[N,\ 3]$, the multi-resolution hash encoding is then obtained as follows.
corners_all, weights_all = self.hash_grids(x)
encodings = self.hash_enc(corners_all, weights_all)

Closing
The implementation above shows that the encoding can feed the decoding network of any NeRF-like model, as long as the decoder's input dimension matches the encoding size $L \times F$.
This makes it easy to combine other NeRF models with Multi-Resolution Hash Encoding using this code.
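As a rough sketch of what such a combination could look like, the snippet below feeds an encoding of size $L \times F$ into a small MLP. `TinyDecoder`, its single-network layout, and the sizes used here are assumptions for illustration; the actual Instant-NGP decoder is structured differently (separate density and color networks), and random tensors stand in for the real encodings.

```python
import torch
import torch.nn as nn

n_levels, feat_dim, dir_dim = 16, 2, 27  # assumed sizes, for illustration only

class TinyDecoder(nn.Module):
    """Hypothetical shallow decoder: encoding + direction features -> (density, rgb)."""

    def __init__(self, enc_dim: int, dir_dim: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(enc_dim + dir_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 4),
        )

    def forward(self, enc: torch.Tensor, dir_enc: torch.Tensor):
        out = self.net(torch.cat([enc, dir_enc], dim=-1))
        sigma = torch.relu(out[..., :1])   # non-negative density
        rgb = torch.sigmoid(out[..., 1:])  # colors in [0, 1]
        return sigma, rgb

encodings = torch.rand(1024, n_levels * feat_dim)  # stand-in for the hash_enc output
dir_enc = torch.rand(1024, dir_dim)                # stand-in for a view-direction encoding
decoder = TinyDecoder(enc_dim=n_levels * feat_dim, dir_dim=dir_dim)
sigma, rgb = decoder(encodings, dir_enc)
print(sigma.shape, rgb.shape)
```

The only coupling between the encoder and the decoder is `enc_dim`, which must equal `n_levels * feat_dim`.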
However, the implementation may not be as fast as the original Instant-NGP for several reasons:
- The PyTorch implementation loops over levels and corners in Python and launches many small operations, whereas the original CUDA/C++ version uses dedicated kernels.
- Instant-NGP uses the tiny-cuda-nn (tcnn) library for the decoding network, whose fused CUDA kernels further optimize inference speed.
- There are additional implementation details, such as pruning hypercubes without opaque particles so that empty space is skipped, to improve inference efficiency.