python-torch/0001-python-torch-disable-ck.patch
Tom Rix e80f34f74d Update gitcommit to 2.7-rc3
Signed-off-by: Tom Rix <Tom.Rix@amd.com>
2025-03-29 05:11:02 -07:00

112 lines
4.9 KiB
Diff

From 027dad1eaed51c1172e2497da611e3267d42d2f0 Mon Sep 17 00:00:00 2001
From: Tom Rix <Tom.Rix@amd.com>
Date: Fri, 28 Mar 2025 09:16:03 -0700
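Subject: [PATCH] python-torch: disable ck

Build the ROCm backend without the Composable Kernels (ck) code paths:

- aten/src/ATen/CMakeLists.txt: stop globbing native/hip/bgemm_kernels
  into native_hip_h and native_hip_hip, and apply the exclusion of the
  bgemm_kernels, ck*.hip and native transformers HIP sources
  unconditionally instead of only under WIN32.
- aten/src/ATen/Context.cpp: make Context::setBlasPreferredBackend()
  return immediately, so the requested backend is ignored.
- aten/src/ATen/cuda/CUDABlas.cpp: guard the BlasBackend::Ck dispatch
  branches of bgemm_internal/gemm_internal with USE_ROCM_NO_CK (a macro
  this patch does not define) instead of USE_ROCM.
---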
aten/src/ATen/CMakeLists.txt | 7 +++----
aten/src/ATen/Context.cpp | 1 +
aten/src/ATen/cuda/CUDABlas.cpp | 10 +++++-----
3 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index 085af373ec22..84808880e51c 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -134,7 +134,7 @@ file(GLOB native_cuda_cu "native/cuda/*.cu")
file(GLOB native_cuda_cpp "native/cuda/*.cpp")
file(GLOB native_cuda_h "native/cuda/*.h" "native/cuda/*.cuh")
file(GLOB native_cuda_linalg_cpp "native/cuda/linalg/*.cpp")
-file(GLOB native_hip_h "native/hip/*.h" "native/hip/*.cuh" "native/hip/bgemm_kernels/*.h")
+file(GLOB native_hip_h "native/hip/*.h" "native/hip/*.cuh" )
file(GLOB native_cudnn_cpp "native/cudnn/*.cpp")
file(GLOB native_sparse_cuda_cu "native/sparse/cuda/*.cu")
file(GLOB native_sparse_cuda_cpp "native/sparse/cuda/*.cpp")
@@ -145,7 +145,7 @@ file(GLOB native_nested_h "native/nested/*.h")
file(GLOB native_nested_cuda_cu "native/nested/cuda/*.cu")
file(GLOB native_nested_cuda_cpp "native/nested/cuda/*.cpp")
-file(GLOB native_hip_hip "native/hip/*.hip" "native/hip/bgemm_kernels/*.hip")
+file(GLOB native_hip_hip "native/hip/*.hip" )
file(GLOB native_hip_cpp "native/hip/*.cpp")
file(GLOB native_hip_linalg_cpp "native/hip/linalg/*.cpp")
file(GLOB native_miopen_cpp "native/miopen/*.cpp")
@@ -361,13 +361,12 @@ endif()
${native_quantized_hip_hip}
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
)
- if(WIN32) # Windows doesn't support Composable Kernels and Triton
file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
file(GLOB native_hip_ck "native/hip/ck*.hip")
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
${native_hip_bgemm} ${native_hip_ck}
${native_transformers_hip_hip} ${native_transformers_hip_cpp})
- endif()
+
# TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
list(APPEND all_hip_cpp
${native_nested_hip_cpp}
diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index f598fc3a39d3..03dab6ff38fe 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -355,6 +355,7 @@ at::BlasBackend Context::blasPreferredBackend() {
}
void Context::setBlasPreferredBackend(at::BlasBackend b) {
+ return;
#ifdef _MSC_VER
TORCH_WARN_ONCE(
"torch.backends.cuda.preferred_blas_library is an experimental feature. "
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index a62b028fd4ff..cba38426ea1f 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -708,7 +708,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
-#ifdef USE_ROCM
+#ifdef USE_ROCM_NO_CK
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
@@ -1061,7 +1061,7 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
#endif
}
-#ifdef USE_ROCM
+#ifdef USE_ROCM_NO_CK
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
}
@@ -1077,7 +1077,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
}
-#ifdef USE_ROCM
+#ifdef USE_ROCM_NO_CK
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
}
@@ -1125,7 +1125,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
-#ifdef USE_ROCM
+#ifdef USE_ROCM_NO_CK
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
@@ -1141,7 +1141,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
-#ifdef USE_ROCM
+#ifdef USE_ROCM_NO_CK
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
--
2.48.1