120 lines
5 KiB
Diff
120 lines
5 KiB
Diff
From 0f33e0a7bbd1522ee74f8fc1fbe3af7563318c79 Mon Sep 17 00:00:00 2001
|
|
From: Tom Rix <Tom.Rix@amd.com>
|
|
Date: Fri, 28 Mar 2025 15:33:09 -0700
|
|
Subject: [PATCH] Add cmake variable USE_ROCM_CK
|
|
|
|
To control the use of the ROCm Composable Kernel (CK).
|
|
|
|
CK is not compatible with all rocBLAS GPUs, so the user
|
|
must explicitly choose to use CK.
|
|
|
|
Signed-off-by: Tom Rix <Tom.Rix@amd.com>
|
|
---
|
|
CMakeLists.txt | 1 +
|
|
aten/src/ATen/CMakeLists.txt | 8 ++++++--
|
|
aten/src/ATen/cuda/CUDABlas.cpp | 10 +++++-----
|
|
cmake/Dependencies.cmake | 3 +++
|
|
4 files changed, 15 insertions(+), 7 deletions(-)
|
|
|
|
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
|
index f3fee2f7ffc2..73903acce452 100644
|
|
--- a/CMakeLists.txt
|
|
+++ b/CMakeLists.txt
|
|
@@ -249,6 +249,7 @@ cmake_dependent_option(
|
|
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
|
|
"USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
|
|
cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)
|
|
+cmake_dependent_option(USE_ROCM_CK "Use ROCm Composable Kernel" ON "USE_ROCM" ON)
|
|
option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
|
|
cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
|
|
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
|
|
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
|
|
index 085af373ec22..af268ab88572 100644
|
|
--- a/aten/src/ATen/CMakeLists.txt
|
|
+++ b/aten/src/ATen/CMakeLists.txt
|
|
@@ -361,13 +361,17 @@ endif()
|
|
${native_quantized_hip_hip}
|
|
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
|
|
)
|
|
- if(WIN32) # Windows doesn't support Composable Kernels and Triton
|
|
+ if(NOT USE_ROCM_CK) # Drop the Composable Kernel (bgemm and ck*) HIP sources when CK is disabled
|
|
file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
|
|
file(GLOB native_hip_ck "native/hip/ck*.hip")
|
|
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
|
|
- ${native_hip_bgemm} ${native_hip_ck}
|
|
+ ${native_hip_bgemm} ${native_hip_ck})
|
|
+ endif()
|
|
+ if(WIN32) # Windows doesn't support Composable Kernels and Triton
|
|
+ exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
|
|
${native_transformers_hip_hip} ${native_transformers_hip_cpp})
|
|
endif()
|
|
+
|
|
# TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
|
|
list(APPEND all_hip_cpp
|
|
${native_nested_hip_cpp}
|
|
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
|
|
index a62b028fd4ff..a3dbf76848ea 100644
|
|
--- a/aten/src/ATen/cuda/CUDABlas.cpp
|
|
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
|
|
@@ -708,7 +708,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
|
}
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM_CK
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
|
}
|
|
@@ -1061,7 +1061,7 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
|
|
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
|
|
#endif
|
|
}
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM_CK
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
|
|
}
|
|
@@ -1077,7 +1077,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
|
}
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM_CK
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
|
|
}
|
|
@@ -1125,7 +1125,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
|
}
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM_CK
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
|
}
|
|
@@ -1141,7 +1141,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
|
}
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM_CK
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
|
}
|
|
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
|
index 30917bdf39f5..2ca6091030f1 100644
|
|
--- a/cmake/Dependencies.cmake
|
|
+++ b/cmake/Dependencies.cmake
|
|
@@ -1046,6 +1046,9 @@ if(USE_ROCM)
|
|
if(HIPBLASLT_VEC_EXT)
|
|
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT)
|
|
endif()
|
|
+ if(USE_ROCM_CK)
|
|
+ list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK)
|
|
+ endif()
|
|
list(APPEND HIP_HIPCC_FLAGS --offload-compress)
|
|
if(WIN32)
|
|
add_definitions(-DROCM_ON_WINDOWS)
|
|
--
|
|
2.48.1
|
|
|