112 lines
4.9 KiB
Diff
112 lines
4.9 KiB
Diff
From 027dad1eaed51c1172e2497da611e3267d42d2f0 Mon Sep 17 00:00:00 2001
|
|
From: Tom Rix <Tom.Rix@amd.com>
|
|
Date: Fri, 28 Mar 2025 09:16:03 -0700
|
|
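Subject: [PATCH] python-torch: disable the Composable Kernel (CK) backend
|
|
|
|
Remove the Composable Kernel (CK) GEMM paths from the ROCm build:
|
|
|
|
- CMakeLists.txt: drop native/hip/bgemm_kernels from the HIP globs and
|
|
  exclude the bgemm_kernels, ck*.hip and native/transformers HIP sources
|
|
  on every platform, not only on Windows.
|
|
- CUDABlas.cpp: guard the BlasBackend::Ck branches of gemm_internal and
|
|
  bgemm_internal with USE_ROCM_NO_CK, a macro this patch never defines,
|
|
  so those branches are compiled out.
|
|
- Context.cpp: make setBlasPreferredBackend() return immediately. Note
|
|
  that this ignores every requested backend, not only Ck.
|
|
|
|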
---
|
|
aten/src/ATen/CMakeLists.txt | 7 +++----
|
|
aten/src/ATen/Context.cpp | 1 +
|
|
aten/src/ATen/cuda/CUDABlas.cpp | 10 +++++-----
|
|
3 files changed, 9 insertions(+), 9 deletions(-)
|
|
|
|
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
|
|
index 085af373ec22..84808880e51c 100644
|
|
--- a/aten/src/ATen/CMakeLists.txt
|
|
+++ b/aten/src/ATen/CMakeLists.txt
|
|
@@ -134,7 +134,7 @@ file(GLOB native_cuda_cu "native/cuda/*.cu")
|
|
file(GLOB native_cuda_cpp "native/cuda/*.cpp")
|
|
file(GLOB native_cuda_h "native/cuda/*.h" "native/cuda/*.cuh")
|
|
file(GLOB native_cuda_linalg_cpp "native/cuda/linalg/*.cpp")
|
|
-file(GLOB native_hip_h "native/hip/*.h" "native/hip/*.cuh" "native/hip/bgemm_kernels/*.h")
|
|
+file(GLOB native_hip_h "native/hip/*.h" "native/hip/*.cuh" )
|
|
file(GLOB native_cudnn_cpp "native/cudnn/*.cpp")
|
|
file(GLOB native_sparse_cuda_cu "native/sparse/cuda/*.cu")
|
|
file(GLOB native_sparse_cuda_cpp "native/sparse/cuda/*.cpp")
|
|
@@ -145,7 +145,7 @@ file(GLOB native_nested_h "native/nested/*.h")
|
|
file(GLOB native_nested_cuda_cu "native/nested/cuda/*.cu")
|
|
file(GLOB native_nested_cuda_cpp "native/nested/cuda/*.cpp")
|
|
|
|
-file(GLOB native_hip_hip "native/hip/*.hip" "native/hip/bgemm_kernels/*.hip")
|
|
+file(GLOB native_hip_hip "native/hip/*.hip" )
|
|
file(GLOB native_hip_cpp "native/hip/*.cpp")
|
|
file(GLOB native_hip_linalg_cpp "native/hip/linalg/*.cpp")
|
|
file(GLOB native_miopen_cpp "native/miopen/*.cpp")
|
|
@@ -361,13 +361,12 @@ endif()
|
|
${native_quantized_hip_hip}
|
|
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
|
|
)
|
|
- if(WIN32) # Windows doesn't support Composable Kernels and Triton
|
|
file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
|
|
file(GLOB native_hip_ck "native/hip/ck*.hip")
|
|
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
|
|
${native_hip_bgemm} ${native_hip_ck}
|
|
${native_transformers_hip_hip} ${native_transformers_hip_cpp})
|
|
- endif()
|
|
+
|
|
# TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
|
|
list(APPEND all_hip_cpp
|
|
${native_nested_hip_cpp}
|
|
diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
|
|
index f598fc3a39d3..03dab6ff38fe 100644
|
|
--- a/aten/src/ATen/Context.cpp
|
|
+++ b/aten/src/ATen/Context.cpp
|
|
@@ -355,6 +355,7 @@ at::BlasBackend Context::blasPreferredBackend() {
|
|
}
|
|
|
|
void Context::setBlasPreferredBackend(at::BlasBackend b) {
|
|
+ return;
|
|
#ifdef _MSC_VER
|
|
TORCH_WARN_ONCE(
|
|
"torch.backends.cuda.preferred_blas_library is an experimental feature. "
|
|
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
|
|
index a62b028fd4ff..cba38426ea1f 100644
|
|
--- a/aten/src/ATen/cuda/CUDABlas.cpp
|
|
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
|
|
@@ -708,7 +708,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
|
}
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM_NO_CK
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
|
}
|
|
@@ -1061,7 +1061,7 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
|
|
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
|
|
#endif
|
|
}
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM_NO_CK
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
|
|
}
|
|
@@ -1077,7 +1077,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
|
}
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM_NO_CK
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
|
|
}
|
|
@@ -1125,7 +1125,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
|
}
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM_NO_CK
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
|
}
|
|
@@ -1141,7 +1141,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
|
|
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
|
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
|
}
|
|
-#ifdef USE_ROCM
|
|
+#ifdef USE_ROCM_NO_CK
|
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
|
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
|
}
|
|
--
|
|
2.48.1
|
|
|