Merge branch 'shawnz_bugs_fix_cuda_a_dev' into 'cuda_a_dev'

Changes to fix bugs 5196362, 5184356, 5212196, 5214258, and 5214259

See merge request cuda-samples/cuda-samples!102
This commit is contained in:
Rob Armstrong 2025-04-11 07:07:55 -07:00
commit bded2585a4
31 changed files with 191 additions and 120 deletions

View File

@ -187,6 +187,7 @@ CUmodule loadCUBIN(char *cubin, int argc, char **argv) {
CUcontext context;
int major = 0, minor = 0;
char deviceName[256];
CUctxCreateParams ctxCreateParams = {};
// Picks the best CUDA device available
CUdevice cuDevice = findCudaDeviceDRV(argc, (const char **)argv);
@ -200,7 +201,7 @@ CUmodule loadCUBIN(char *cubin, int argc, char **argv) {
printf("> GPU Device has SM %d.%d compute capability\n", major, minor);
checkCudaErrors(cuInit(0));
checkCudaErrors(cuCtxCreate(&context, 0, cuDevice));
checkCudaErrors(cuCtxCreate(&context, &ctxCreateParams, 0, cuDevice));
checkCudaErrors(cuModuleLoadData(&module, cubin));
free(cubin);

View File

@ -247,7 +247,9 @@ int main(int argc, char **argv)
exit(EXIT_WAIVED);
}
if (device_prop.computeMode == cudaComputeModeProhibited) {
int computeMode;
checkCudaErrors(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, dev_id));
if (computeMode == cudaComputeModeProhibited) {
// This sample requires being run with a default or process exclusive mode
fprintf(stderr,
"This sample requires a device in either default or process "

View File

@ -271,6 +271,7 @@ static int initCUDA(int argc, char **argv, CUfunction *pMatrixMul, int *blk_size
CUfunction cuFunction = 0;
int major = 0, minor = 0;
char deviceName[100];
CUctxCreateParams ctxCreateParams = {};
cuDevice = findCudaDeviceDRV(argc, (const char **)argv);
@ -283,7 +284,7 @@ static int initCUDA(int argc, char **argv, CUfunction *pMatrixMul, int *blk_size
checkCudaErrors(cuDeviceTotalMem(&totalGlobalMem, cuDevice));
printf(" Total amount of global memory: %llu bytes\n", (long long unsigned int)totalGlobalMem);
checkCudaErrors(cuCtxCreate(&cuContext, 0, cuDevice));
checkCudaErrors(cuCtxCreate(&cuContext, &ctxCreateParams, 0, cuDevice));
// first search for the module path before we load the results
std::string module_path;

View File

@ -86,13 +86,14 @@ int main(int argc, char **argv)
CUfunction vecAdd_kernel;
CUmodule cuModule = 0;
CUcontext cuContext;
CUctxCreateParams ctxCreateParams = {};
// Initialize
checkCudaDrvErrors(cuInit(0));
cuDevice = findCudaDevice(argc, (const char **)argv);
// Create context
checkCudaDrvErrors(cuCtxCreate(&cuContext, 0, cuDevice));
checkCudaDrvErrors(cuCtxCreate(&cuContext, &ctxCreateParams, 0, cuDevice));
// first search for the module path before we load the results
string module_path;

View File

@ -127,6 +127,10 @@ int main(int argc, char **argv)
checkCudaErrors(cudaGetDevice(&cuda_device));
checkCudaErrors(cudaGetDeviceProperties(&deviceProp, cuda_device));
// Get device clock rate
int clockRate;
checkCudaErrors(cudaDeviceGetAttribute(&clockRate, cudaDevAttrClockRate, cuda_device));
// HyperQ is available in devices of Compute Capability 3.5 and higher
if (deviceProp.major < 3 || (deviceProp.major == 3 && deviceProp.minor < 5)) {
if (deviceProp.concurrentKernels == 0) {
@ -170,9 +174,9 @@ int main(int argc, char **argv)
#if defined(__arm__) || defined(__aarch64__)
// the kernel takes more time than the channel reset time on arm archs, so to
// prevent hangs reduce time_clocks.
clock_t time_clocks = (clock_t)(kernel_time * (deviceProp.clockRate / 100));
clock_t time_clocks = (clock_t)(kernel_time * (clockRate / 100));
#else
clock_t time_clocks = (clock_t)(kernel_time * deviceProp.clockRate);
clock_t time_clocks = (clock_t)(kernel_time * clockRate);
#endif
clock_t total_clocks = 0;

View File

@ -247,7 +247,9 @@ static void parentProcess(char *app)
}
// This sample requires two processes accessing each device, so we need
// to ensure exclusive or prohibited mode is not set
if (prop.computeMode != cudaComputeModeDefault) {
int computeMode;
checkCudaErrors(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, i));
if (computeMode != cudaComputeModeDefault) {
printf("Device %d is in an unsupported compute mode for this sample\n", i);
continue;
}

View File

@ -218,9 +218,11 @@ int main(int argc, char *argv[])
printf("\n");
printf("Relevant properties of this CUDA device\n");
int canOverlap;
checkCudaErrors(cudaDeviceGetAttribute(&canOverlap, cudaDevAttrGpuOverlap, cuda_device));
printf("(%s) Can overlap one CPU<>GPU data transfer with GPU kernel execution "
"(device property \"deviceOverlap\")\n",
deviceProp.deviceOverlap ? "X" : " ");
"(device property \"cudaDevAttrGpuOverlap\")\n",
canOverlap ? "X" : " ");
// printf("(%s) Can execute several GPU kernels simultaneously (compute
// capability >= 2.0)\n", deviceProp.major >= 2 ? "X": " ");
printf("(%s) Can overlap two CPU<>GPU data transfers with GPU kernel execution\n"

View File

@ -313,6 +313,7 @@ static CUresult initCUDA(int argc, char **argv, CUfunction *transform)
int major = 0, minor = 0, devID = 0;
char deviceName[100];
string module_path;
CUctxCreateParams ctxCreateParams = {};
cuDevice = findCudaDeviceDRV(argc, (const char **)argv);
@ -322,7 +323,7 @@ static CUresult initCUDA(int argc, char **argv, CUfunction *transform)
checkCudaErrors(cuDeviceGetName(deviceName, sizeof(deviceName), cuDevice));
printf("> GPU Device has SM %d.%d compute capability\n", major, minor);
checkCudaErrors(cuCtxCreate(&cuContext, 0, cuDevice));
checkCudaErrors(cuCtxCreate(&cuContext, &ctxCreateParams, 0, cuDevice));
// first search for the module_path before we try to load the results
std::ostringstream fatbin;

View File

@ -287,7 +287,9 @@ int main(int argc, char **argv)
exit(EXIT_WAIVED);
}
if (device_prop.computeMode == cudaComputeModeProhibited) {
int computeMode;
checkCudaErrors(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, dev_id));
if (computeMode == cudaComputeModeProhibited) {
// This sample requires being run with a default or process exclusive mode
fprintf(stderr,
"This sample requires a device in either default or process "

View File

@ -77,13 +77,14 @@ int main(int argc, char **argv)
printf("Vector Addition (Driver API)\n");
int N = 50000, devID = 0;
size_t size = N * sizeof(float);
CUctxCreateParams ctxCreateParams = {};
// Initialize
checkCudaErrors(cuInit(0));
cuDevice = findCudaDeviceDRV(argc, (const char **)argv);
// Create context
checkCudaErrors(cuCtxCreate(&cuContext, 0, cuDevice));
checkCudaErrors(cuCtxCreate(&cuContext, &ctxCreateParams, 0, cuDevice));
// first search for the module path before we load the results
string module_path;

View File

@ -59,6 +59,7 @@ template <class T> inline void getCudaAttribute(T *attribute, CUdevice_attribute
#endif /* CUDART_VERSION < 5000 */
////////////////////////////////////////////////////////////////////////////////
// Program main
////////////////////////////////////////////////////////////////////////////////
@ -128,14 +129,20 @@ int main(int argc, char **argv)
deviceProp.multiProcessorCount,
_ConvertSMVer2Cores(deviceProp.major, deviceProp.minor),
_ConvertSMVer2Cores(deviceProp.major, deviceProp.minor) * deviceProp.multiProcessorCount);
int clockRate;
checkCudaErrors(cudaDeviceGetAttribute(&clockRate, cudaDevAttrClockRate, dev));
printf(" GPU Max Clock rate: %.0f MHz (%0.2f "
"GHz)\n",
deviceProp.clockRate * 1e-3f,
deviceProp.clockRate * 1e-6f);
clockRate * 1e-3f,
clockRate * 1e-6f);
#if CUDART_VERSION >= 5000
// This is supported in CUDA 5.0 (runtime API device properties)
printf(" Memory Clock rate: %.0f Mhz\n", deviceProp.memoryClockRate * 1e-3f);
int memoryClockRate;
#if CUDART_VERSION >= 13000
checkCudaErrors(cudaDeviceGetAttribute(&memoryClockRate, cudaDevAttrMemoryClockRate, dev));
#else
memoryClockRate = deviceProp.memoryClockRate;
#endif
printf(" Memory Clock rate: %.0f Mhz\n", memoryClockRate * 1e-3f);
printf(" Memory Bus Width: %d-bit\n", deviceProp.memoryBusWidth);
if (deviceProp.l2CacheSize) {
@ -194,12 +201,15 @@ int main(int argc, char **argv)
deviceProp.maxGridSize[2]);
printf(" Maximum memory pitch: %zu bytes\n", deviceProp.memPitch);
printf(" Texture alignment: %zu bytes\n", deviceProp.textureAlignment);
int gpuOverlap;
checkCudaErrors(cudaDeviceGetAttribute(&gpuOverlap, cudaDevAttrGpuOverlap, dev));
printf(" Concurrent copy and kernel execution: %s with %d copy "
"engine(s)\n",
(deviceProp.deviceOverlap ? "Yes" : "No"),
(gpuOverlap ? "Yes" : "No"),
deviceProp.asyncEngineCount);
printf(" Run time limit on kernels: %s\n",
deviceProp.kernelExecTimeoutEnabled ? "Yes" : "No");
int kernelExecTimeout;
checkCudaErrors(cudaDeviceGetAttribute(&kernelExecTimeout, cudaDevAttrKernelExecTimeout, dev));
printf(" Run time limit on kernels: %s\n", kernelExecTimeout ? "Yes" : "No");
printf(" Integrated GPU sharing Host Memory: %s\n", deviceProp.integrated ? "Yes" : "No");
printf(" Support host page-locked memory mapping: %s\n", deviceProp.canMapHostMemory ? "Yes" : "No");
printf(" Alignment requirement for Surfaces: %s\n", deviceProp.surfaceAlignment ? "Yes" : "No");
@ -213,8 +223,11 @@ int main(int argc, char **argv)
printf(" Device supports Compute Preemption: %s\n",
deviceProp.computePreemptionSupported ? "Yes" : "No");
printf(" Supports Cooperative Kernel Launch: %s\n", deviceProp.cooperativeLaunch ? "Yes" : "No");
// The property cooperativeMultiDeviceLaunch is deprecated in CUDA 13.0
#if CUDART_VERSION < 13000
printf(" Supports MultiDevice Co-op Kernel Launch: %s\n",
deviceProp.cooperativeMultiDeviceLaunch ? "Yes" : "No");
#endif
printf(" Device PCI Domain ID / Bus ID / location ID: %d / %d / %d\n",
deviceProp.pciDomainID,
deviceProp.pciBusID,
@ -230,8 +243,10 @@ int main(int argc, char **argv)
"::cudaSetDevice() with this device)",
"Unknown",
NULL};
int computeMode;
checkCudaErrors(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, dev));
printf(" Compute Mode:\n");
printf(" < %s >\n", sComputeMode[deviceProp.computeMode]);
printf(" < %s >\n", sComputeMode[computeMode]);
}
// If there are 2 or more GPUs, query to determine whether RDMA is supported

View File

@ -192,6 +192,7 @@ CUresult cudaDeviceCreateConsumer(test_cuda_consumer_s *cudaConsumer)
{
CUdevice device;
CUresult status = CUDA_SUCCESS;
CUctxCreateParams ctxCreateParams = {};
if (CUDA_SUCCESS != (status = cuInit(0))) {
printf("Failed to initialize CUDA\n");
@ -203,7 +204,7 @@ CUresult cudaDeviceCreateConsumer(test_cuda_consumer_s *cudaConsumer)
return status;
}
if (CUDA_SUCCESS != (status = cuCtxCreate(&cudaConsumer->context, 0, device))) {
if (CUDA_SUCCESS != (status = cuCtxCreate(&cudaConsumer->context, &ctxCreateParams, 0, device))) {
printf("failed to create CUDA context\n");
return status;
}

View File

@ -184,6 +184,7 @@ CUresult cudaDeviceCreateProducer(test_cuda_producer_s *cudaProducer)
{
CUdevice device;
CUresult status = CUDA_SUCCESS;
CUctxCreateParams ctxCreateParams = {};
if (CUDA_SUCCESS != (status = cuInit(0))) {
printf("Failed to initialize CUDA\n");
@ -195,7 +196,7 @@ CUresult cudaDeviceCreateProducer(test_cuda_producer_s *cudaProducer)
return status;
}
if (CUDA_SUCCESS != (status = cuCtxCreate(&cudaProducer->context, 0, device))) {
if (CUDA_SUCCESS != (status = cuCtxCreate(&cudaProducer->context, &ctxCreateParams, 0, device))) {
printf("failed to create CUDA context\n");
return status;
}

View File

@ -302,7 +302,8 @@ CUresult cudaDeviceCreateConsumer(test_cuda_consumer_s *cudaConsumer, CUdevice d
major,
minor);
if (CUDA_SUCCESS != (status = cuCtxCreate(&cudaConsumer->context, 0, device))) {
CUctxCreateParams ctxCreateParams = {};
if (CUDA_SUCCESS != (status = cuCtxCreate(&cudaConsumer->context, &ctxCreateParams, 0, device))) {
printf("failed to create CUDA context\n");
return status;
}

View File

@ -316,7 +316,8 @@ CUresult cudaDeviceCreateProducer(test_cuda_producer_s *cudaProducer, CUdevice d
exit(2); // EXIT_WAIVED
}
if (CUDA_SUCCESS != (status = cuCtxCreate(&cudaProducer->context, 0, device))) {
CUctxCreateParams ctxCreateParams = {};
if (CUDA_SUCCESS != (status = cuCtxCreate(&cudaProducer->context, &ctxCreateParams, 0, device))) {
printf("failed to create CUDA context\n");
return status;
}

View File

@ -69,6 +69,9 @@
#include <thrust/sort.h>
#include <thrust/unique.h>
// for cuda::std::identity
#include <cuda/std/functional>
// Sample framework includes.
#include <helper_cuda.h>
#include <helper_functions.h>
@ -680,7 +683,7 @@ private:
thrust::make_counting_iterator(validEdgesCount),
dEdgesFlags,
dVertices_,
thrust::identity<uint>())
cuda::std::identity())
.get();
pools.uintEdges.put(dEdgesFlags);

View File

@ -322,7 +322,9 @@ static void parentProcess(char *app)
}
// This sample requires two processes accessing each device, so we need
// to ensure exclusive or prohibited mode is not set
if (prop.computeMode != cudaComputeModeDefault) {
int computeMode;
checkCudaErrors(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, i));
if (computeMode != cudaComputeModeDefault) {
printf("Device %d is in an unsupported compute mode for this sample\n", i);
continue;
}

View File

@ -122,9 +122,10 @@ static CUresult InitCUDAContext(CUDAContext *pContext, CUdevice hcuDevice, int d
CUmodule hcuModule = 0;
CUfunction hcuFunction = 0;
CUdeviceptr dptr = 0;
CUctxCreateParams ctxCreateParams = {};
// cuCtxCreate: Function works on floating contexts and current context
CUresult status = cuCtxCreate(&hcuContext, 0, hcuDevice);
CUresult status = cuCtxCreate(&hcuContext, &ctxCreateParams, 0, hcuDevice);
if (CUDA_SUCCESS != status) {
fprintf(stderr, "cuCtxCreate for <deviceID=%d> failed %d\n", deviceID, status);

View File

@ -97,13 +97,13 @@ void simpleIfGraph(void)
params.kernel.kernelParams = kernelArgs;
kernelArgs[0] = &dPtr;
kernelArgs[1] = &handle;
checkCudaErrors(cudaGraphAddNode(&kernelNode, graph, NULL, 0, &params));
checkCudaErrors(cudaGraphAddNode(&kernelNode, graph, NULL, NULL, 0, &params));
cudaGraphNodeParams cParams = {cudaGraphNodeTypeConditional};
cParams.conditional.handle = handle;
cParams.conditional.type = cudaGraphCondTypeIf;
cParams.conditional.size = 1;
checkCudaErrors(cudaGraphAddNode(&conditionalNode, graph, &kernelNode, 1, &cParams));
checkCudaErrors(cudaGraphAddNode(&conditionalNode, graph, &kernelNode, NULL, 0, &cParams));
cudaGraph_t bodyGraph = cParams.conditional.phGraph_out[0];
@ -111,7 +111,7 @@ void simpleIfGraph(void)
cudaGraphNode_t bodyNode;
params.kernel.func = (void *)ifGraphKernelC;
params.kernel.kernelParams = nullptr;
checkCudaErrors(cudaGraphAddNode(&bodyNode, bodyGraph, NULL, 0, &params));
checkCudaErrors(cudaGraphAddNode(&bodyNode, bodyGraph, NULL, NULL, 0, &params));
checkCudaErrors(cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0));
@ -182,7 +182,7 @@ void simpleDoWhileGraph(void)
cParams.conditional.handle = handle;
cParams.conditional.type = cudaGraphCondTypeWhile;
cParams.conditional.size = 1;
checkCudaErrors(cudaGraphAddNode(&conditionalNode, graph, NULL, 0, &cParams));
checkCudaErrors(cudaGraphAddNode(&conditionalNode, graph, NULL, NULL, 0, &cParams));
cudaGraph_t bodyGraph = cParams.conditional.phGraph_out[0];
@ -267,7 +267,8 @@ void capturedWhileGraph(void)
checkCudaErrors(cudaStreamBeginCapture(captureStream, cudaStreamCaptureModeGlobal));
// Obtain the handle of the graph
checkCudaErrors(cudaStreamGetCaptureInfo(captureStream, &status, NULL, &graph, &dependencies, &numDependencies));
checkCudaErrors(
cudaStreamGetCaptureInfo(captureStream, &status, NULL, &graph, &dependencies, NULL, &numDependencies));
// Create the conditional handle
cudaGraphConditionalHandle handle;
@ -277,7 +278,8 @@ void capturedWhileGraph(void)
capturedWhileKernel<<<1, 1, 0, captureStream>>>(dPtr, handle);
// Obtain the handle for node A
checkCudaErrors(cudaStreamGetCaptureInfo(captureStream, &status, NULL, &graph, &dependencies, &numDependencies));
checkCudaErrors(
cudaStreamGetCaptureInfo(captureStream, &status, NULL, &graph, &dependencies, NULL, &numDependencies));
// Insert conditional node B
cudaGraphNode_t conditionalNode;
@ -285,13 +287,13 @@ void capturedWhileGraph(void)
cParams.conditional.handle = handle;
cParams.conditional.type = cudaGraphCondTypeWhile;
cParams.conditional.size = 1;
checkCudaErrors(cudaGraphAddNode(&conditionalNode, graph, dependencies, numDependencies, &cParams));
checkCudaErrors(cudaGraphAddNode(&conditionalNode, graph, dependencies, NULL, numDependencies, &cParams));
cudaGraph_t bodyGraph = cParams.conditional.phGraph_out[0];
// Update stream capture dependencies to account for the node we manually added
checkCudaErrors(
cudaStreamUpdateCaptureDependencies(captureStream, &conditionalNode, 1, cudaStreamSetCaptureDependencies));
checkCudaErrors(cudaStreamUpdateCaptureDependencies(
captureStream, &conditionalNode, NULL, 1, cudaStreamSetCaptureDependencies));
// Insert kernel node D
capturedWhileEmptyKernel<<<1, 1, 0, captureStream>>>();
@ -380,13 +382,13 @@ void simpleIfElseGraph(void)
params.kernel.kernelParams = kernelArgs;
kernelArgs[0] = &dPtr;
kernelArgs[1] = &handle;
checkCudaErrors(cudaGraphAddNode(&kernelNode, graph, NULL, 0, &params));
checkCudaErrors(cudaGraphAddNode(&kernelNode, graph, NULL, NULL, 0, &params));
cudaGraphNodeParams cParams = {cudaGraphNodeTypeConditional};
cParams.conditional.handle = handle;
cParams.conditional.type = cudaGraphCondTypeIf;
cParams.conditional.size = 2; // Set size to 2 to indicate an ELSE graph will be used
checkCudaErrors(cudaGraphAddNode(&conditionalNode, graph, &kernelNode, 1, &cParams));
checkCudaErrors(cudaGraphAddNode(&conditionalNode, graph, &kernelNode, NULL, 0, &cParams));
cudaGraph_t bodyGraph = cParams.conditional.phGraph_out[0];
@ -394,7 +396,7 @@ void simpleIfElseGraph(void)
cudaGraphNode_t trueBodyNode;
params.kernel.func = (void *)ifGraphKernelC;
params.kernel.kernelParams = nullptr;
checkCudaErrors(cudaGraphAddNode(&trueBodyNode, bodyGraph, NULL, 0, &params));
checkCudaErrors(cudaGraphAddNode(&trueBodyNode, bodyGraph, NULL, NULL, 0, &params));
// Populate the body of the second graph in the conditional node, executed if the condition is false
bodyGraph = cParams.conditional.phGraph_out[1];
@ -402,7 +404,7 @@ void simpleIfElseGraph(void)
cudaGraphNode_t falseBodyNode;
params.kernel.func = (void *)ifGraphKernelD;
params.kernel.kernelParams = nullptr;
checkCudaErrors(cudaGraphAddNode(&falseBodyNode, bodyGraph, NULL, 0, &params));
checkCudaErrors(cudaGraphAddNode(&falseBodyNode, bodyGraph, NULL, NULL, 0, &params));
checkCudaErrors(cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0));
@ -484,25 +486,25 @@ void simpleSwitchGraph(void)
params.kernel.kernelParams = kernelArgs;
kernelArgs[0] = &dPtr;
kernelArgs[1] = &handle;
checkCudaErrors(cudaGraphAddNode(&kernelNode, graph, NULL, 0, &params));
checkCudaErrors(cudaGraphAddNode(&kernelNode, graph, NULL, NULL, 0, &params));
cudaGraphNodeParams cParams = {cudaGraphNodeTypeConditional};
cParams.conditional.handle = handle;
cParams.conditional.type = cudaGraphCondTypeSwitch;
cParams.conditional.size = 4;
checkCudaErrors(cudaGraphAddNode(&conditionalNode, graph, &kernelNode, 1, &cParams));
checkCudaErrors(cudaGraphAddNode(&conditionalNode, graph, &kernelNode, NULL, 0, &cParams));
// Populate the four graph bodies within the SWITCH conditional graph
cudaGraphNode_t bodyNode;
params.kernel.kernelParams = nullptr;
params.kernel.func = (void *)switchGraphKernelC;
checkCudaErrors(cudaGraphAddNode(&bodyNode, cParams.conditional.phGraph_out[0], NULL, 0, &params));
checkCudaErrors(cudaGraphAddNode(&bodyNode, cParams.conditional.phGraph_out[0], NULL, NULL, 0, &params));
params.kernel.func = (void *)switchGraphKernelD;
checkCudaErrors(cudaGraphAddNode(&bodyNode, cParams.conditional.phGraph_out[1], NULL, 0, &params));
checkCudaErrors(cudaGraphAddNode(&bodyNode, cParams.conditional.phGraph_out[1], NULL, NULL, 0, &params));
params.kernel.func = (void *)switchGraphKernelE;
checkCudaErrors(cudaGraphAddNode(&bodyNode, cParams.conditional.phGraph_out[2], NULL, 0, &params));
checkCudaErrors(cudaGraphAddNode(&bodyNode, cParams.conditional.phGraph_out[2], NULL, NULL, 0, &params));
params.kernel.func = (void *)switchGraphKernelF;
checkCudaErrors(cudaGraphAddNode(&bodyNode, cParams.conditional.phGraph_out[3], NULL, 0, &params));
checkCudaErrors(cudaGraphAddNode(&bodyNode, cParams.conditional.phGraph_out[3], NULL, NULL, 0, &params));
checkCudaErrors(cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0));

View File

@ -149,9 +149,9 @@ void createSimpleAllocFreeGraph(cudaGraphExec_t *graphExec, float **dPtr, size_t
checkCudaErrors(cudaGraphAddMemAllocNode(&allocNodeA, graph, NULL, 0, &allocParams));
*dPtr = (float *)allocParams.dptr;
cudaDeviceProp deviceProp;
checkCudaErrors(cudaGetDeviceProperties(&deviceProp, device));
clock_t time_clocks = (clock_t)((kernelTime / 1000.0) * deviceProp.clockRate);
int clockRate;
checkCudaErrors(cudaDeviceGetAttribute(&clockRate, cudaDevAttrClockRate, device));
clock_t time_clocks = (clock_t)((kernelTime / 1000.0) * clockRate);
void *blockDeviceArgs[1] = {(void *)&time_clocks};

View File

@ -344,9 +344,10 @@ static void childProcess(int devId, int id, char **argv)
CUdevice device;
CUstream stream;
int multiProcessorCount;
CUctxCreateParams ctx_params = {};
checkCudaErrors(cuDeviceGet(&device, devId));
checkCudaErrors(cuCtxCreate(&ctx, 0, device));
checkCudaErrors(cuCtxCreate(&ctx, &ctx_params, 0, device));
checkCudaErrors(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
// Obtain kernel function for the sample
@ -519,7 +520,8 @@ static void parentProcess(char *app)
}
if (allPeers) {
CUcontext ctx;
checkCudaErrors(cuCtxCreate(&ctx, 0, devices[i]));
CUctxCreateParams ctx_params = {};
checkCudaErrors(cuCtxCreate(&ctx, &ctx_params, 0, devices[i]));
ctxs.push_back(ctx);
// Enable peers here. This isn't necessary for IPC, but it will

View File

@ -585,9 +585,12 @@ int main(int argc, char **argv)
genTridiag(I, J, val_cpu, N, nz);
memcpy(val, val_cpu, sizeof(float) * nz);
checkCudaErrors(cudaMemAdvise(I, sizeof(int) * (N + 1), cudaMemAdviseSetReadMostly, 0));
checkCudaErrors(cudaMemAdvise(J, sizeof(int) * nz, cudaMemAdviseSetReadMostly, 0));
checkCudaErrors(cudaMemAdvise(val, sizeof(float) * nz, cudaMemAdviseSetReadMostly, 0));
cudaMemLocation deviceLoc;
deviceLoc.type = cudaMemLocationTypeDevice;
deviceLoc.id = 0; // Device location with initial device 0
checkCudaErrors(cudaMemAdvise(I, sizeof(int) * (N + 1), cudaMemAdviseSetReadMostly, deviceLoc));
checkCudaErrors(cudaMemAdvise(J, sizeof(int) * nz, cudaMemAdviseSetReadMostly, deviceLoc));
checkCudaErrors(cudaMemAdvise(val, sizeof(float) * nz, cudaMemAdviseSetReadMostly, deviceLoc));
checkCudaErrors(cudaMallocManaged((void **)&x, sizeof(float) * N));
@ -648,26 +651,30 @@ int main(int argc, char **argv)
int offset_p = device_count * totalThreadsPerGPU;
int offset_x = device_count * totalThreadsPerGPU;
checkCudaErrors(cudaMemPrefetchAsync(I, sizeof(int) * N, *deviceId, nStreams[device_count]));
checkCudaErrors(cudaMemPrefetchAsync(val, sizeof(float) * nz, *deviceId, nStreams[device_count]));
checkCudaErrors(cudaMemPrefetchAsync(J, sizeof(float) * nz, *deviceId, nStreams[device_count]));
// Create device location with specific device ID
cudaMemLocation deviceLoc;
deviceLoc.type = cudaMemLocationTypeDevice;
deviceLoc.id = *deviceId;
checkCudaErrors(cudaMemPrefetchAsync(I, sizeof(int) * N, deviceLoc, 0, nStreams[device_count]));
checkCudaErrors(cudaMemPrefetchAsync(val, sizeof(float) * nz, deviceLoc, 0, nStreams[device_count]));
checkCudaErrors(cudaMemPrefetchAsync(J, sizeof(float) * nz, deviceLoc, 0, nStreams[device_count]));
if (offset_Ax <= N) {
for (int i = 0; i < perGPUIter; i++) {
cudaMemAdvise(
Ax + offset_Ax, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetPreferredLocation, *deviceId);
Ax + offset_Ax, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetPreferredLocation, deviceLoc);
cudaMemAdvise(
r + offset_r, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetPreferredLocation, *deviceId);
r + offset_r, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetPreferredLocation, deviceLoc);
cudaMemAdvise(
x + offset_x, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetPreferredLocation, *deviceId);
x + offset_x, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetPreferredLocation, deviceLoc);
cudaMemAdvise(
p + offset_p, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetPreferredLocation, *deviceId);
p + offset_p, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetPreferredLocation, deviceLoc);
cudaMemAdvise(
Ax + offset_Ax, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetAccessedBy, *deviceId);
cudaMemAdvise(r + offset_r, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetAccessedBy, *deviceId);
cudaMemAdvise(p + offset_p, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetAccessedBy, *deviceId);
cudaMemAdvise(x + offset_x, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetAccessedBy, *deviceId);
Ax + offset_Ax, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetAccessedBy, deviceLoc);
cudaMemAdvise(r + offset_r, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetAccessedBy, deviceLoc);
cudaMemAdvise(p + offset_p, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetAccessedBy, deviceLoc);
cudaMemAdvise(x + offset_x, sizeof(float) * totalThreadsPerGPU, cudaMemAdviseSetAccessedBy, deviceLoc);
offset_Ax += totalThreadsPerGPU * kNumGpusRequired;
offset_r += totalThreadsPerGPU * kNumGpusRequired;
@ -739,8 +746,11 @@ int main(int argc, char **argv)
deviceId++;
}
checkCudaErrors(cudaMemPrefetchAsync(x, sizeof(float) * N, cudaCpuDeviceId));
checkCudaErrors(cudaMemPrefetchAsync(dot_result, sizeof(double), cudaCpuDeviceId));
// Use cudaMemLocationTypeHost for optimal host memory location
cudaMemLocation hostLoc;
hostLoc.type = cudaMemLocationTypeHost;
checkCudaErrors(cudaMemPrefetchAsync(x, sizeof(float) * N, hostLoc, 0));
checkCudaErrors(cudaMemPrefetchAsync(dot_result, sizeof(double), hostLoc, 0));
deviceId = bestFitDeviceIds.begin();
device_count = 0;

View File

@ -8,7 +8,7 @@ find_package(CUDAToolkit REQUIRED)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_CUDA_ARCHITECTURES 53 72 75 80 86 87 90)
set(CMAKE_CUDA_ARCHITECTURES 75 80 86 87 90)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wno-deprecated-gpu-targets")
if(ENABLE_CUDA_DEBUG)

View File

@ -150,9 +150,10 @@ int main(int argc, char *argv[])
CUcontext context;
CUmodule module;
CUfunction kernel;
CUctxCreateParams ctxCreateParams = {};
CUDA_SAFE_CALL(cuInit(0));
CUDA_SAFE_CALL(cuDeviceGet(&cuDevice, 0));
CUDA_SAFE_CALL(cuCtxCreate(&context, 0, cuDevice));
CUDA_SAFE_CALL(cuCtxCreate(&context, &ctxCreateParams, 0, cuDevice));
// Dynamically determine the arch to link for
int major = 0;

View File

@ -84,13 +84,17 @@ void findMultipleBestGPUs(int &num_of_devices, int *device_ids)
cudaDeviceProp deviceProp;
int devices_prohibited = 0;
int computeMode;
int clockRate;
while (current_device < device_count) {
cudaGetDeviceProperties(&deviceProp, current_device);
checkCudaErrors(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, current_device));
checkCudaErrors(cudaDeviceGetAttribute(&clockRate, cudaDevAttrClockRate, current_device));
// If this GPU is not running on Compute Mode prohibited,
// then we can add it to the list
int sm_per_multiproc;
if (deviceProp.computeMode != cudaComputeModeProhibited) {
if (computeMode != cudaComputeModeProhibited) {
if (deviceProp.major == 9999 && deviceProp.minor == 9999) {
sm_per_multiproc = 1;
}
@ -99,7 +103,7 @@ void findMultipleBestGPUs(int &num_of_devices, int *device_ids)
}
gpu_stats[current_device].compute_perf =
(uint64_t)deviceProp.multiProcessorCount * sm_per_multiproc * deviceProp.clockRate;
(uint64_t)deviceProp.multiProcessorCount * sm_per_multiproc * clockRate;
gpu_stats[current_device].device_id = current_device;
}
else {

View File

@ -94,8 +94,10 @@ int SineWaveSimulation::initCuda(uint8_t *vkDeviceUUID, size_t UUID_SIZE)
// Find the GPU which is selected by Vulkan
while (current_device < device_count) {
cudaGetDeviceProperties(&deviceProp, current_device);
int computeMode;
checkCudaErrors(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, current_device));
if ((deviceProp.computeMode != cudaComputeModeProhibited)) {
if ((computeMode != cudaComputeModeProhibited)) {
// Compare the cuda device UUID with vulkan UUID
int ret = memcmp((void *)&deviceProp.uuid, vkDeviceUUID, UUID_SIZE);
if (ret == 0) {

View File

@ -830,6 +830,7 @@ private:
int devices_prohibited = 0;
cudaDeviceProp deviceProp;
int computeMode;
checkCudaErrors(cudaGetDeviceCount(&device_count));
if (device_count == 0) {
@ -840,8 +841,8 @@ private:
// Find the GPU which is selected by Vulkan
while (current_device < device_count) {
cudaGetDeviceProperties(&deviceProp, current_device);
if ((deviceProp.computeMode != cudaComputeModeProhibited)) {
checkCudaErrors(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, current_device));
if ((computeMode != cudaComputeModeProhibited)) {
// Compare the cuda device UUID with vulkan UUID
int ret = memcmp(&deviceProp.uuid, &vkDeviceUUID, VK_UUID_SIZE);
if (ret == 0) {

View File

@ -335,9 +335,11 @@ void runMatrixMultiplyKernel(unsigned int matrixDim,
checkCudaErrors(cudaMallocManaged(&dptrA, size));
checkCudaErrors(cudaMallocManaged(&dptrB, size));
checkCudaErrors(cudaMallocManaged(&dptrC, size));
checkCudaErrors(cudaMemPrefetchAsync(dptrA, size, cudaCpuDeviceId));
checkCudaErrors(cudaMemPrefetchAsync(dptrB, size, cudaCpuDeviceId));
checkCudaErrors(cudaMemPrefetchAsync(dptrC, size, cudaCpuDeviceId));
cudaMemLocation hostLoc;
hostLoc.type = cudaMemLocationTypeHost;
checkCudaErrors(cudaMemPrefetchAsync(dptrA, size, hostLoc, 0));
checkCudaErrors(cudaMemPrefetchAsync(dptrB, size, hostLoc, 0));
checkCudaErrors(cudaMemPrefetchAsync(dptrC, size, hostLoc, 0));
}
else {
checkCudaErrors(cudaMallocManaged(&dptrA, size, cudaMemAttachHost));
@ -402,9 +404,12 @@ void runMatrixMultiplyKernel(unsigned int matrixDim,
}
if (hintsRequired) {
if (deviceProp.concurrentManagedAccess) {
checkCudaErrors(cudaMemPrefetchAsync(dptrA, size, device_id, streamToRunOn));
checkCudaErrors(cudaMemPrefetchAsync(dptrB, size, device_id, streamToRunOn));
checkCudaErrors(cudaMemPrefetchAsync(dptrC, size, device_id, streamToRunOn));
cudaMemLocation deviceLoc;
deviceLoc.type = cudaMemLocationTypeDevice;
deviceLoc.id = device_id;
checkCudaErrors(cudaMemPrefetchAsync(dptrA, size, deviceLoc, 0, streamToRunOn));
checkCudaErrors(cudaMemPrefetchAsync(dptrB, size, deviceLoc, 0, streamToRunOn));
checkCudaErrors(cudaMemPrefetchAsync(dptrC, size, deviceLoc, 0, streamToRunOn));
}
else {
checkCudaErrors(cudaStreamAttachMemAsync(streamToRunOn, dptrA, 0, cudaMemAttachGlobal));
@ -437,9 +442,11 @@ void runMatrixMultiplyKernel(unsigned int matrixDim,
sdkStartTimer(&gpuTransferCallsTimer);
if (hintsRequired) {
if (deviceProp.concurrentManagedAccess) {
checkCudaErrors(cudaMemPrefetchAsync(dptrA, size, cudaCpuDeviceId));
checkCudaErrors(cudaMemPrefetchAsync(dptrB, size, cudaCpuDeviceId));
checkCudaErrors(cudaMemPrefetchAsync(dptrC, size, cudaCpuDeviceId));
cudaMemLocation hostLoc;
hostLoc.type = cudaMemLocationTypeHost;
checkCudaErrors(cudaMemPrefetchAsync(dptrA, size, hostLoc, 0));
checkCudaErrors(cudaMemPrefetchAsync(dptrB, size, hostLoc, 0));
checkCudaErrors(cudaMemPrefetchAsync(dptrC, size, hostLoc, 0));
}
else {
checkCudaErrors(cudaStreamAttachMemAsync(streamToRunOn, dptrA, 0, cudaMemAttachHost));

View File

@ -195,7 +195,7 @@ static CUresult buildKernel(CUcontext *phContext, CUdevice *phDevice, CUmodule *
// Initialize CUDA and obtain the device's compute capability.
int major = 0, minor = 0;
*phDevice = cudaDeviceInit(&major, &minor);
checkCudaErrors(cuCtxCreate(phContext, 0, *phDevice));
checkCudaErrors(cuCtxCreate(phContext, NULL, 0, *phDevice));
// Get the NVVM IR from file.
size_t size = 0;

View File

@ -89,7 +89,7 @@ initCUDA(CUcontext *phContext, CUdevice *phDevice, CUmodule *phModule, CUfunctio
assert(phContext && phDevice && phModule && phKernel && ptx);
// Create a CUDA context on the device.
checkCudaErrors(cuCtxCreate(phContext, 0, *phDevice));
checkCudaErrors(cuCtxCreate(phContext, NULL, 0, *phDevice));
// Load the PTX.
checkCudaErrors(cuModuleLoadDataEx(phModule, ptx, 0, 0, 0));

View File

@ -206,7 +206,7 @@ static CUresult buildKernel(CUcontext *phContext, CUdevice *phDevice, CUmodule *
*phDevice = cudaDeviceInit(&major, &minor);
// Create a context on the device.
checkCudaErrors(cuCtxCreate(phContext, 0, *phDevice));
checkCudaErrors(cuCtxCreate(phContext, NULL, 0, *phDevice));
// Get the NVVM IR from file.
size_t size = 0;