42 lines
1.7 KiB
Diff
42 lines
1.7 KiB
Diff
From 214dc959acc809e1959643272c344ee5335d5a69 Mon Sep 17 00:00:00 2001
|
|
From: Tom Rix <trix@redhat.com>
|
|
Date: Thu, 1 Feb 2024 11:29:47 -0500
|
|
Subject: [PATCH] LazyNVRTC: use HIP (const char **) signatures under USE_ROCM
|
|
|
|
---
|
|
aten/src/ATen/cuda/detail/LazyNVRTC.cpp | 9 +++++++++
|
|
1 file changed, 9 insertions(+)
|
|
|
|
diff --git a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
|
|
index 1b85e7776e..bb6f88783a 100644
|
|
--- a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
|
|
+++ b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
|
|
@@ -134,8 +134,13 @@ nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
|
|
const char *src,
|
|
const char *name,
|
|
int numHeaders,
|
|
+#if !defined(USE_ROCM)
|
|
const char * const *headers,
|
|
const char * const *includeNames) {
|
|
+#else
|
|
+ const char **headers,
|
|
+ const char **includeNames) {
|
|
+#endif
|
|
auto fn = reinterpret_cast<decltype(&nvrtcCreateProgram)>(getNVRTCLibrary().sym(__func__));
|
|
if (!fn)
|
|
throw std::runtime_error("Can't get nvrtcCreateProgram");
|
|
@@ -150,7 +155,11 @@ NVRTC_STUB2(nvrtcGetPTX, nvrtcProgram, char *);
|
|
NVRTC_STUB2(nvrtcGetCUBINSize, nvrtcProgram, size_t *);
|
|
NVRTC_STUB2(nvrtcGetCUBIN, nvrtcProgram, char *);
|
|
#endif
|
|
+#if !defined(USE_ROCM)
|
|
NVRTC_STUB3(nvrtcCompileProgram, nvrtcProgram, int, const char * const *);
|
|
+#else
|
|
+NVRTC_STUB3(nvrtcCompileProgram, nvrtcProgram, int, const char **);
|
|
+#endif
|
|
_STUB_1(NVRTC, nvrtcGetErrorString, const char *, nvrtcResult);
|
|
NVRTC_STUB2(nvrtcGetProgramLogSize,nvrtcProgram, size_t*);
|
|
NVRTC_STUB2(nvrtcGetProgramLog, nvrtcProgram, char *);
|
|
--
|
|
2.43.0
|
|
|