# Last commit: a1e3f5f "Finalize" (JeffreyXiang)
from .. import config
import importlib
import torch
import torch.nn as nn
from .. import SparseTensor
# Cache of imported convolution backend modules, keyed by backend name
# (the value of config.CONV); populated lazily on first use by the
# SparseConv3d / SparseInverseConv3d layers.
_backends = dict()
class SparseConv3d(nn.Module):
    """3D convolution over a SparseTensor, delegated to a pluggable backend.

    The implementation lives in a module named ``conv_<config.CONV>``
    (resolved relative to this module's ``__name__``). That module is imported
    on first use and cached in the module-level ``_backends`` dict. The
    backend's ``sparse_conv3d_init`` sets up this layer and
    ``sparse_conv3d_forward`` runs it.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, padding,
        bias, indice_key: forwarded unchanged to the backend's
            ``sparse_conv3d_init``. Their exact semantics (e.g. what
            ``padding=None`` means, how ``indice_key`` is used) are defined
            by the backend.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
        super().__init__()
        # Pin the backend that sets up this layer so forward() keeps using it
        # even if config.CONV is changed after construction.
        self._conv_backend_name = config.CONV
        self._get_conv_backend().sparse_conv3d_init(
            self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key
        )

    def _get_conv_backend(self):
        """Return the cached backend module for this layer, importing it if needed."""
        # Instances pickled before _conv_backend_name existed lack it; fall
        # back to the currently configured backend in that case.
        name = getattr(self, '_conv_backend_name', config.CONV)
        backend = _backends.get(name)
        if backend is None:
            # Lazily importing here (not only in __init__) also covers
            # instances restored via pickle in a fresh process, where
            # __init__ never ran and _backends starts empty. The '..' prefix
            # is resolved against this module's own __name__, so for a plain
            # module it names a sibling module in the same package.
            backend = importlib.import_module(f'..conv_{name}', __name__)
            _backends[name] = backend
        return backend

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Apply the sparse convolution to ``x`` using this layer's backend."""
        return self._get_conv_backend().sparse_conv3d_forward(self, x)
class SparseInverseConv3d(nn.Module):
    """Inverse 3D sparse convolution, delegated to a pluggable backend.

    The implementation lives in a module named ``conv_<config.CONV>``
    (resolved relative to this module's ``__name__``). That module is imported
    on first use and cached in the module-level ``_backends`` dict. The
    backend's ``sparse_inverse_conv3d_init`` sets up this layer and
    ``sparse_inverse_conv3d_forward`` runs it; the precise operation is
    defined by the backend.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, bias,
        indice_key: forwarded unchanged to the backend's
            ``sparse_inverse_conv3d_init``. Unlike SparseConv3d there is no
            ``padding`` argument. ``indice_key`` is presumably used by the
            backend to pair this layer with a matching forward convolution
            (TODO confirm against the backend).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super().__init__()
        # Pin the backend that sets up this layer so forward() keeps using it
        # even if config.CONV is changed after construction.
        self._conv_backend_name = config.CONV
        self._get_conv_backend().sparse_inverse_conv3d_init(
            self, in_channels, out_channels, kernel_size, stride, dilation, bias, indice_key
        )

    def _get_conv_backend(self):
        """Return the cached backend module for this layer, importing it if needed."""
        # Instances pickled before _conv_backend_name existed lack it; fall
        # back to the currently configured backend in that case.
        name = getattr(self, '_conv_backend_name', config.CONV)
        backend = _backends.get(name)
        if backend is None:
            # Lazily importing here (not only in __init__) also covers
            # instances restored via pickle in a fresh process, where
            # __init__ never ran and _backends starts empty. The '..' prefix
            # is resolved against this module's own __name__, so for a plain
            # module it names a sibling module in the same package.
            backend = importlib.import_module(f'..conv_{name}', __name__)
            _backends[name] = backend
        return backend

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Apply the inverse sparse convolution to ``x`` using this layer's backend."""
        return self._get_conv_backend().sparse_inverse_conv3d_forward(self, x)