From 193854993cd939de186de19589c1add4c4b2cf66 Mon Sep 17 00:00:00 2001
From: Tom Rix
Date: Mon, 21 Jul 2025 11:35:03 -0700
Subject: [PATCH] Add cmake variable USE_ROCM_CK

Make ROCm Composable Kernel (CK) support an explicit build-time choice
instead of tying it to the host platform.

* USE_ROCM_CK defaults to ON when USE_ROCM is enabled on a non-Windows
  host and is forced OFF otherwise.  This keeps the previous WIN32 /
  _MSC_VER exclusions, which this patch replaces with the new option.
* The bundled CK include directories and the generated ck/config.h are
  only set up when USE_ROCM_CK is ON.
* The CK GEMM/BGEMM sources and the BlasBackend::Ck dispatch branches in
  CUDABlas.cpp are only compiled when USE_ROCM_CK is ON.
* The hipBLASLt fallbacks for double and complex GEMM stay keyed on
  USE_ROCM; they do not depend on CK and are left untouched.
---
Review notes (this section is ignored by git am):
 - This patch was reflowed from a whitespace-collapsed copy.  Indentation
   of context lines and the C++ template arguments in the CUDABlas.cpp
   context lines were reconstructed by inference.  TODO: confirm against
   the target tree, or regenerate with git format-patch, before applying.
 - The post-image blob ids on the "index" lines predate this revision.
 - The author e-mail address is missing from the From: header.

 CMakeLists.txt                  |  1 +
 aten/src/ATen/CMakeLists.txt    | 12 +++++++-----
 aten/src/ATen/cuda/CUDABlas.cpp | 14 +++++++-------
 cmake/Dependencies.cmake        |  3 +++
 4 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index a5d25e6afa0f..afc1b53efa64 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -240,6 +240,7 @@ cmake_dependent_option(
     BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
     "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
 cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)
+cmake_dependent_option(USE_ROCM_CK "Use ROCm Composable Kernel" ON "USE_ROCM;NOT WIN32" OFF)
 option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
 cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
 cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index c9cfd74b501e..59f6178218ee 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -390,8 +390,10 @@ if(USE_ROCM)
   endfunction()
   list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
-  _pytorch_rocm_generate_ck_conf()
+  if(USE_ROCM_CK)
+    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
+    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
+    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
+    _pytorch_rocm_generate_ck_conf()
+  endif()
 
   # Next two lines are needed because TunableOp uses third-party/fmt
@@ -409,7 +411,7 @@ endif()
     ${native_quantized_hip_hip}
     ${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
   )
-  if(WIN32) # Windows doesn't support Composable Kernels
+  if(NOT USE_ROCM_CK) # Drop the Composable Kernel sources when CK is disabled
     file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
     file(GLOB native_hip_ck "native/hip/ck*.hip")
     exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 89350a11bea7..e5b7960177cf 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -836,7 +836,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
       bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
     }
   }
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
   else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
     at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
   }
@@ -1270,14 +1270,14 @@ template <>
 void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
 {
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
 #ifdef USE_ROCM
     // hipblaslt does not support double gemm yet
     gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
 #else
     gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
 #endif
   }
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
   else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
     at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
   }
@@ -1293,7 +1293,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
   }
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
   else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
     if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100
       gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
@@ -1345,7 +1345,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
   }
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
   else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
     at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
   }
@@ -1361,7 +1361,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
   }
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
   else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
     at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
   }
@@ -1382,7 +1382,7 @@ void gemm_internal<at::Half, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half,
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     gemm_internal_cublaslt<at::Half, float>(CUDABLAS_GEMM_ARGS(at::Half));
   }
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
   else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
     TORCH_CHECK(false, "gemm input type at::Half and output type float is not supported for ROCm");
   }
@@ -1398,7 +1398,7 @@ void gemm_internal<at::BFloat16, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::B
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     gemm_internal_cublaslt<at::BFloat16, float>(CUDABLAS_GEMM_ARGS(at::BFloat16));
   }
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
   else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
     TORCH_CHECK(false, "gemm input type at::Half and output type float is not supported for ROCm");
   }
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index a93386c27f8d..be1368999d38 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1031,6 +1031,9 @@ if(USE_ROCM)
     if(HIPBLASLT_VEC_EXT)
       list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT)
     endif()
+    if(USE_ROCM_CK)
+      list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK)
+    endif()
     list(APPEND HIP_HIPCC_FLAGS --offload-compress)
     if(WIN32)
       add_definitions(-DROCM_ON_WINDOWS)
-- 
2.49.0