mirror of
https://github.com/NVIDIA/cuda-samples.git
synced 2024-11-24 19:59:17 +08:00
Use CTK-provided type for cuTensorMapEncodeTiled
This commit is contained in:
parent
5925483b33
commit
a3b5b817e3
|
@ -32,9 +32,9 @@
|
|||
* - Create a TensorMap (TMA descriptor)
|
||||
* - Load a 2D tile of data into shared memory
|
||||
*
|
||||
* Compile with:
|
||||
* Compile and run with:
|
||||
*
|
||||
* nvcc -arch sm_90 globalToShmemTMACopy.cu -o globalToShmemTMACopy
|
||||
* nvcc -arch sm_90 -run globalToShmemTMACopy.cu
|
||||
*
|
||||
* It can be that the compiler issues the following note. This can be safely ignored.
|
||||
*
|
||||
|
@ -42,10 +42,12 @@
|
|||
* GCC 4.6
|
||||
*
|
||||
*/
|
||||
#include <cstdio> // fprintf
|
||||
#include <cstdio> // printf
|
||||
#include <vector> // std::vector
|
||||
|
||||
#include <cuda.h> // CUtensorMap
|
||||
#include <cudaTypedefs.h> // PFN_cuTensorMapEncodeTiled
|
||||
#include <cuda.h> // CUtensormap
|
||||
|
||||
#include <cuda_awbarrier_primitives.h> // __mbarrier_*
|
||||
|
||||
#include "util.h" // CUDA_CHECK macro
|
||||
|
@ -63,19 +65,11 @@ constexpr int SMEM_H = 8; // Height of shared memory buffer (in # elements)
|
|||
* CUDA Driver API
|
||||
*/
|
||||
|
||||
// The type of the cuTensorMapEncodeTiled function.
|
||||
using cuTensorMapEncodeTiled_t = decltype(cuTensorMapEncodeTiled);
|
||||
|
||||
// Get function pointer to driver API cuTensorMapEncodeTiled
|
||||
cuTensorMapEncodeTiled_t * get_cuTensorMapEncodeTiled() {
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
|
||||
void* cuda_ptr = nullptr;
|
||||
unsigned long long flags = cudaEnableDefault;
|
||||
|
||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER__ENTRY__POINT.html
|
||||
CUDA_CHECK(cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuda_ptr, flags));
|
||||
|
||||
return reinterpret_cast<cuTensorMapEncodeTiled_t*>(cuda_ptr);
|
||||
PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
|
||||
void* driver_ptr = nullptr;
|
||||
cudaDriverEntryPointQueryResult driver_status;
|
||||
CUDA_CHECK(cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, cudaEnableDefault, &driver_status));
|
||||
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(driver_ptr);
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
Loading…
Reference in New Issue
Block a user