| from .. import config | |
| import importlib | |
| import torch | |
| import torch.nn as nn | |
| from .. import SparseTensor | |
| _backends = {} | |
class SparseConv3d(nn.Module):
    """Sparse 3D convolution layer that delegates to a pluggable backend.

    The implementation lives in the sibling module ``conv_<config.CONV>``,
    which is imported on first use and cached in ``_backends``. That module
    must expose ``sparse_conv3d_init`` and ``sparse_conv3d_forward``.

    All constructor arguments are forwarded unchanged to the backend's
    ``sparse_conv3d_init``; their exact semantics are defined there.
    NOTE(review): presumably they follow the usual convolution conventions
    (channels, kernel size, stride, dilation, padding, bias flag) and
    ``indice_key`` names a cache of sparse indices; confirm in the backend.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
        super().__init__()
        self._resolve_backend().sparse_conv3d_init(
            self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key
        )

    @staticmethod
    def _resolve_backend():
        """Return the backend module for the current ``config.CONV``.

        The module is imported and cached on first use. Both ``__init__`` and
        ``forward`` go through this helper, so an instance whose constructor
        never ran in this process (for example one restored by unpickling)
        still finds its backend instead of raising ``KeyError``.

        Raises:
            ImportError: If no ``conv_<config.CONV>`` module can be imported.
        """
        # Read the setting once so the membership test and the lookup agree.
        backend_name = config.CONV
        if backend_name not in _backends:
            _backends[backend_name] = importlib.import_module(f'..conv_{backend_name}', __name__)
        return _backends[backend_name]

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Apply the backend's ``sparse_conv3d_forward`` to ``x`` and return its result."""
        return self._resolve_backend().sparse_conv3d_forward(self, x)
class SparseInverseConv3d(nn.Module):
    """Sparse 3D inverse convolution layer that delegates to a pluggable backend.

    The implementation lives in the sibling module ``conv_<config.CONV>``,
    which is imported on first use and cached in ``_backends``. That module
    must expose ``sparse_inverse_conv3d_init`` and
    ``sparse_inverse_conv3d_forward``.

    All constructor arguments are forwarded unchanged to the backend's
    ``sparse_inverse_conv3d_init``; their exact semantics are defined there.
    Unlike ``SparseConv3d``, this constructor takes no ``padding`` argument.
    NOTE(review): presumably ``indice_key`` pairs this layer with the forward
    convolution whose indices it inverts; confirm in the backend.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super().__init__()
        self._resolve_backend().sparse_inverse_conv3d_init(
            self, in_channels, out_channels, kernel_size, stride, dilation, bias, indice_key
        )

    @staticmethod
    def _resolve_backend():
        """Return the backend module for the current ``config.CONV``.

        The module is imported and cached on first use. Both ``__init__`` and
        ``forward`` go through this helper, so an instance whose constructor
        never ran in this process (for example one restored by unpickling)
        still finds its backend instead of raising ``KeyError``.

        Raises:
            ImportError: If no ``conv_<config.CONV>`` module can be imported.
        """
        # Read the setting once so the membership test and the lookup agree.
        backend_name = config.CONV
        if backend_name not in _backends:
            _backends[backend_name] = importlib.import_module(f'..conv_{backend_name}', __name__)
        return _backends[backend_name]

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Apply the backend's ``sparse_inverse_conv3d_forward`` to ``x`` and return its result."""
        return self._resolve_backend().sparse_inverse_conv3d_forward(self, x)