mxnet.rtc.CudaModule
class mxnet.rtc.CudaModule(source, options=(), exports=())
Compile and run CUDA code from Python.
In CUDA 7.5, you need to prefix your kernel definitions with extern "C" to avoid C++ name mangling:
    import mxnet as mx

    source = r'''
    extern "C" __global__ void axpy(const float *x, float *y, float alpha) {
        int i = threadIdx.x + blockIdx.x * blockDim.x;
        y[i] += alpha * x[i];
    }
    '''
    module = mx.rtc.CudaModule(source)
    func = module.get_kernel("axpy", "const float *x, float *y, float alpha")
    x = mx.nd.ones((10,), ctx=mx.gpu(0))
    y = mx.nd.zeros((10,), ctx=mx.gpu(0))
    # Launch a single block of 10 threads: one thread per element.
    func.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
    print(y)
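The launch call above uses a grid of (1, 1, 1) blocks, each with (10, 1, 1) threads. The pure-Python sketch below (no GPU or MXNet needed; emulate_axpy_launch is a hypothetical name, not part of mxnet.rtc) replays the kernel's index arithmetic sequentially to show which thread handles which element. It models only the x dimension, and it treats an out-of-range index as an error because the kernel itself has no bounds check.

```python
def emulate_axpy_launch(x, y, alpha, grid_dims, block_dims):
    """Sequential CPU emulation of the 1-D axpy launch (illustration only).

    Hypothetical helper: on a GPU all threads run in parallel; here they
    run one after another. Only the x components of the dims are used,
    matching the example where the y and z components are 1.
    """
    block_dim_x = block_dims[0]
    for block_idx in range(grid_dims[0]):
        for thread_idx in range(block_dim_x):
            # Same arithmetic as the kernel:
            # i = threadIdx.x + blockIdx.x * blockDim.x
            i = thread_idx + block_idx * block_dim_x
            # The kernel has no bounds check, so a launch with more
            # threads than elements would access memory out of range.
            if i >= len(y):
                raise IndexError(f"thread index {i} out of range for {len(y)} elements")
            y[i] += alpha * x[i]
    return y


x = [1.0] * 10
y = [0.0] * 10
print(emulate_axpy_launch(x, y, 3.0, (1, 1, 1), (10, 1, 1)))
# y now holds ten copies of 3.0
```

Splitting the same 10 elements into (2, 1, 1) blocks of (5, 1, 1) threads covers the same indices and gives the same result. Launching more threads than elements would require an if (i < n) guard in the kernel, with n passed as an extra argument.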
Starting from CUDA 8.0, you can instead export kernels by name through the exports argument. This also allows you to use templated kernels:
    import mxnet as mx

    source = r'''
    template<typename DType>
    __global__ void axpy(const DType *x, DType *y, DType alpha) {
        int i = threadIdx.x + blockIdx.x * blockDim.x;
        y[i] += alpha * x[i];
    }
    '''
    module = mx.rtc.CudaModule(source, exports=['axpy<float>', 'axpy<double>'])

    # float32 instantiation
    func32 = module.get_kernel("axpy<float>", "const float *x, float *y, float alpha")
    x = mx.nd.ones((10,), dtype='float32', ctx=mx.gpu(0))
    y = mx.nd.zeros((10,), dtype='float32', ctx=mx.gpu(0))
    func32.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
    print(y)

    # float64 instantiation: launch func64 (not func32) on the float64 arrays
    func64 = module.get_kernel("axpy<double>", "const double *x, double *y, double alpha")
    x = mx.nd.ones((10,), dtype='float64', ctx=mx.gpu(0))
    y = mx.nd.zeros((10,), dtype='float64', ctx=mx.gpu(0))
    func64.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
    print(y)
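The exported name and the get_kernel signature have to agree on the template argument: axpy<float> takes float pointers and a float scalar, axpy<double> takes double pointers and a double scalar. One way to keep them consistent is to derive both from the array dtype. The sketch below is a hypothetical helper (axpy_instance and C_TYPES are illustrative names, not part of mxnet.rtc) and covers only the float32 and float64 cases used in the example.

```python
# Hypothetical mapping from NumPy-style dtype names to C type names;
# only the two types used in the templated example are covered.
C_TYPES = {"float32": "float", "float64": "double"}


def axpy_instance(dtype):
    """Return (export name, get_kernel signature) for one axpy instantiation."""
    ctype = C_TYPES[dtype]
    name = f"axpy<{ctype}>"
    signature = f"const {ctype} *x, {ctype} *y, {ctype} alpha"
    return name, signature


# One (name, signature) pair per supported dtype, in C_TYPES order.
instances = {dt: axpy_instance(dt) for dt in C_TYPES}
exports = [name for name, _ in instances.values()]
print(exports)  # ['axpy<float>', 'axpy<double>']
```

The exports list could then be passed to CudaModule, and each pair unpacked into module.get_kernel(name, signature), so the float64 arrays are always paired with the double instantiation.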
Parameters
source (str) – Complete source code.
options (tuple of str) – Compiler flags. For example, use "-I/usr/local/cuda/include" to add the CUDA headers to the include path.
exports (tuple of str) – Names of the kernels to export, e.g. 'axpy<float>' for a template instantiation.
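options is documented as a tuple of str. A common Python slip is writing a single flag in parentheses without a trailing comma, which produces a plain str rather than a one-element tuple. The sketch below (make_options is a hypothetical helper, not an mxnet.rtc function) builds the tuple from a list of include directories using the -I form shown in the parameter description.

```python
def make_options(include_dirs=(), extra_flags=()):
    """Build an options tuple of compiler flags (hypothetical helper)."""
    # One "-I<dir>" flag per include directory, followed by any extra flags.
    flags = tuple("-I" + d for d in include_dirs)
    return flags + tuple(extra_flags)


# Pitfall: parentheses alone do not make a tuple.
not_a_tuple = ("-I/usr/local/cuda/include")
one_flag = ("-I/usr/local/cuda/include",)
print(type(not_a_tuple).__name__, type(one_flag).__name__)  # str tuple
print(make_options(["/usr/local/cuda/include"]))  # ('-I/usr/local/cuda/include',)
```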
__init__(source, options=(), exports=())
Initialize self. See help(type(self)) for accurate signature.
Methods
__init__(source[, options, exports]) – Initialize self.
get_kernel(name, signature) – Get CUDA kernel from compiled module.
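get_kernel takes the kernel's parameter list as a C-style signature string. As we understand the launch API, pointer parameters (such as float *) receive NDArrays and non-pointer parameters receive Python numbers, in signature order. The sketch below is a hypothetical pre-launch check built on that assumption: split_signature and check_launch_args are illustrative names, plain Python lists stand in for NDArrays, and the parser only handles simple parameter lists like the ones in the examples.

```python
def split_signature(signature):
    """Split a get_kernel-style signature into (is_pointer, name) pairs.

    Hypothetical helper: handles simple C parameter lists such as
    "const float *x, float *y, float alpha", not full C declarations.
    """
    params = []
    for raw in signature.split(","):
        decl = raw.strip()
        is_pointer = "*" in decl
        # The parameter name is the last identifier once '*' is removed.
        name = decl.replace("*", " ").split()[-1]
        params.append((is_pointer, name))
    return params


def check_launch_args(signature, args):
    """Check that pointer params get array-like args and scalars get numbers."""
    params = split_signature(signature)
    if len(params) != len(args):
        raise ValueError(f"expected {len(params)} arguments, got {len(args)}")
    for (is_pointer, name), arg in zip(params, args):
        is_number = isinstance(arg, (int, float))
        # A mismatch is a number for a pointer or a non-number for a scalar.
        if is_pointer == is_number:
            kind = "an array" if is_pointer else "a number"
            raise TypeError(f"parameter {name!r} expects {kind}")


sig = "const float *x, float *y, float alpha"
print(split_signature(sig))
# [(True, 'x'), (True, 'y'), (False, 'alpha')]
check_launch_args(sig, [[1.0], [0.0], 3.0])  # passes silently
```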