202 lines
9.2 KiB
Diff
202 lines
9.2 KiB
Diff
From 193854993cd939de186de19589c1add4c4b2cf66 Mon Sep 17 00:00:00 2001
|
|
From: Tom Rix <Tom.Rix@amd.com>
|
|
Date: Mon, 21 Jul 2025 11:35:03 -0700
|
|
Subject: [PATCH] Add cmake variable USE_ROCM_CK
|
|
|
|
---
|
|
CMakeLists.txt | 1 +
|
|
aten/src/ATen/CMakeLists.txt | 40 ++++++++++++++++-----------------
|
|
aten/src/ATen/cuda/CUDABlas.cpp | 22 +++++++++---------
|
|
cmake/Dependencies.cmake | 3 +++
|
|
4 files changed, 35 insertions(+), 31 deletions(-)
|
|
|
|
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
|
index a5d25e6afa0f..afc1b53efa64 100644
|
|
--- a/CMakeLists.txt
|
|
+++ b/CMakeLists.txt
|
|
@@ -240,6 +240,7 @@ cmake_dependent_option(
|
|
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
|
|
"USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
|
|
cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)
|
|
+cmake_dependent_option(USE_ROCM_CK "Use ROCm Composable Kernel" ON "USE_ROCM;NOT WIN32" OFF)
|
|
option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
|
|
cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
|
|
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
|
|
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
|
|
index c9cfd74b501e..59f6178218ee 100644
|
|
--- a/aten/src/ATen/CMakeLists.txt
|
|
+++ b/aten/src/ATen/CMakeLists.txt
|
|
@@ -373,26 +373,26 @@ if(USE_ROCM)
|
|
# is header only, so this should be ok, except that the CMake build generates
|
|
# a ck/config.h. We just do that part here. Without this, the ck.h from the
|
|
# ROCM SDK may get accidentally used instead.
|
|
- function(_pytorch_rocm_generate_ck_conf)
|
|
- set(CK_ENABLE_INT8 "ON")
|
|
- set(CK_ENABLE_FP16 "ON")
|
|
- set(CK_ENABLE_FP32 "ON")
|
|
- set(CK_ENABLE_FP64 "ON")
|
|
- set(CK_ENABLE_BF16 "ON")
|
|
- set(CK_ENABLE_FP8 "ON")
|
|
- set(CK_ENABLE_BF8 "ON")
|
|
- set(CK_USE_XDL "ON")
|
|
- set(CK_USE_WMMA "ON")
|
|
- configure_file(
|
|
- "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
|
|
- "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
|
|
- )
|
|
- endfunction()
|
|
+# function(_pytorch_rocm_generate_ck_conf)
|
|
+# set(CK_ENABLE_INT8 "ON")
|
|
+# set(CK_ENABLE_FP16 "ON")
|
|
+# set(CK_ENABLE_FP32 "ON")
|
|
+# set(CK_ENABLE_FP64 "ON")
|
|
+# set(CK_ENABLE_BF16 "ON")
|
|
+# set(CK_ENABLE_FP8 "ON")
|
|
+# set(CK_ENABLE_BF8 "ON")
|
|
+# set(CK_USE_XDL "ON")
|
|
+# set(CK_USE_WMMA "ON")
|
|
+# configure_file(
|
|
+# "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
|
|
+# "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
|
|
+# )
|
|
+# endfunction()
|
|
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
|
|
- list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
|
|
- list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
|
|
- list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
|
|
- _pytorch_rocm_generate_ck_conf()
|
|
+# list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
|
|
+# list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
|
|
+# list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
|
|
+# _pytorch_rocm_generate_ck_conf()
|
|
|
|
# Next two lines are needed because TunableOp uses third-party/fmt
|
|
list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
|
|
@@ -409,7 +409,7 @@ endif()
|
|
${native_quantized_hip_hip}
|
|
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
|
|
)
|
|
- if(WIN32) # Windows doesn't support Composable Kernels
|
|
+ if(NOT USE_ROCM_CK) # Leave the Composable Kernel HIP sources out of the build when USE_ROCM_CK is off
|
|
file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
|
|
file(GLOB native_hip_ck "native/hip/ck*.hip")
|
|
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
|
|
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
|
|
index 89350a11bea7..e5b7960177cf 100644
|
|
--- a/aten/src/ATen/cuda/CUDABlas.cpp
|
|
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
|
|
@@ -752,7 +752,7 @@ template <>
|
|
void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double))
|
|
{
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM
|
|
// hipblaslt does not support double gemm yet
|
|
bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGS(double));
|
|
#else
|
|
@@ -836,7 +836,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
|
|
bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
|
}
|
|
}
|
|
-#if defined(USE_ROCM) && !defined(_MSC_VER)
|
|
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
|
}
|
|
@@ -1270,14 +1270,14 @@ template <>
|
|
void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
|
|
{
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM
|
|
// hipblaslt does not support double gemm yet
|
|
gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
|
|
#else
|
|
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
|
|
#endif
|
|
}
|
|
-#if defined(USE_ROCM) && !defined(_MSC_VER)
|
|
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
|
|
}
|
|
@@ -1293,7 +1293,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
|
}
|
|
-#if defined(USE_ROCM) && !defined(_MSC_VER)
|
|
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100
|
|
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
|
@@ -1311,7 +1311,7 @@ template <>
|
|
void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>))
|
|
{
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM
|
|
// hipblaslt does not support complex gemm yet
|
|
gemm_internal_cublas<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
|
|
#else
|
|
@@ -1327,7 +1327,7 @@ template <>
|
|
void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>))
|
|
{
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM
|
|
// hipblaslt does not support complex gemm yet
|
|
gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
|
|
#else
|
|
@@ -1345,7 +1345,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
|
}
|
|
-#if defined(USE_ROCM) && !defined(_MSC_VER)
|
|
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
|
}
|
|
@@ -1361,7 +1361,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
|
}
|
|
-#if defined(USE_ROCM) && !defined(_MSC_VER)
|
|
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
|
}
|
|
@@ -1382,7 +1382,7 @@ void gemm_internal<at::Half, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half,
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
gemm_internal_cublaslt<at::Half, float>(CUDABLAS_GEMM_ARGS(at::Half));
|
|
}
|
|
-#if defined(USE_ROCM) && !defined(_MSC_VER)
|
|
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
TORCH_CHECK(false, "gemm input type at::Half and output type float is not supported for ROCm");
|
|
}
|
|
@@ -1398,7 +1398,7 @@ void gemm_internal<at::BFloat16, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::B
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
gemm_internal_cublaslt<at::BFloat16, float>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
|
}
|
|
-#if defined(USE_ROCM) && !defined(_MSC_VER)
|
|
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
TORCH_CHECK(false, "gemm input type at::Half and output type float is not supported for ROCm");
|
|
}
|
|
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
|
index a93386c27f8d..be1368999d38 100644
|
|
--- a/cmake/Dependencies.cmake
|
|
+++ b/cmake/Dependencies.cmake
|
|
@@ -1031,6 +1031,9 @@ if(USE_ROCM)
|
|
if(HIPBLASLT_VEC_EXT)
|
|
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT)
|
|
endif()
|
|
+ if(USE_ROCM_CK)
|
|
+ list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK)
|
|
+ endif()
|
|
list(APPEND HIP_HIPCC_FLAGS --offload-compress)
|
|
if(WIN32)
|
|
add_definitions(-DROCM_ON_WINDOWS)
|
|
--
|
|
2.49.0
|
|
|