From 027dad1eaed51c1172e2497da611e3267d42d2f0 Mon Sep 17 00:00:00 2001 From: Tom Rix Date: Fri, 28 Mar 2025 09:16:03 -0700 Subject: [PATCH] python-torch: disable ck Disable the Composable Kernels (CK) BLAS paths in the ROCm build: stop globbing the native/hip/bgemm_kernels headers and sources, exclude the CK, bgemm and HIP transformer sources from ATen_HIP_SRCS on every platform (previously only on WIN32), and guard the BlasBackend::Ck dispatch branches in CUDABlas.cpp with USE_ROCM_NO_CK instead of USE_ROCM so they are no longer compiled. Note: Context::setBlasPreferredBackend() now returns immediately, so every preferred-backend request (not only Ck) is silently ignored. --- aten/src/ATen/CMakeLists.txt | 7 +++---- aten/src/ATen/Context.cpp | 1 + aten/src/ATen/cuda/CUDABlas.cpp | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 085af373ec22..84808880e51c 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -134,7 +134,7 @@ file(GLOB native_cuda_cu "native/cuda/*.cu") file(GLOB native_cuda_cpp "native/cuda/*.cpp") file(GLOB native_cuda_h "native/cuda/*.h" "native/cuda/*.cuh") file(GLOB native_cuda_linalg_cpp "native/cuda/linalg/*.cpp") -file(GLOB native_hip_h "native/hip/*.h" "native/hip/*.cuh" "native/hip/bgemm_kernels/*.h") +file(GLOB native_hip_h "native/hip/*.h" "native/hip/*.cuh" ) file(GLOB native_cudnn_cpp "native/cudnn/*.cpp") file(GLOB native_sparse_cuda_cu "native/sparse/cuda/*.cu") file(GLOB native_sparse_cuda_cpp "native/sparse/cuda/*.cpp") @@ -145,7 +145,7 @@ file(GLOB native_nested_h "native/nested/*.h") file(GLOB native_nested_cuda_cu "native/nested/cuda/*.cu") file(GLOB native_nested_cuda_cpp "native/nested/cuda/*.cpp") -file(GLOB native_hip_hip "native/hip/*.hip" "native/hip/bgemm_kernels/*.hip") +file(GLOB native_hip_hip "native/hip/*.hip" ) file(GLOB native_hip_cpp "native/hip/*.cpp") file(GLOB native_hip_linalg_cpp "native/hip/linalg/*.cpp") file(GLOB native_miopen_cpp "native/miopen/*.cpp") @@ -361,13 +361,12 @@ endif() ${native_quantized_hip_hip} ${native_transformers_hip_hip} ${native_transformers_src_hip_hip} ) - if(WIN32) # Windows doesn't support Composable Kernels and Triton file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip") file(GLOB native_hip_ck "native/hip/ck*.hip") exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}" ${native_hip_bgemm} ${native_hip_ck} ${native_transformers_hip_hip} ${native_transformers_hip_cpp}) - 
endif() + # TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources) list(APPEND all_hip_cpp ${native_nested_hip_cpp} diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index f598fc3a39d3..03dab6ff38fe 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -355,6 +355,7 @@ at::BlasBackend Context::blasPreferredBackend() { } void Context::setBlasPreferredBackend(at::BlasBackend b) { + return; #ifdef _MSC_VER TORCH_WARN_ONCE( "torch.backends.cuda.preferred_blas_library is an experimental feature. " diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index a62b028fd4ff..cba38426ea1f 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -708,7 +708,7 @@ void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(at::BFloat16)); } -#ifdef USE_ROCM +#ifdef USE_ROCM_NO_CK else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::bgemm_internal_ck(CUDABLAS_BGEMM_ARGS(at::BFloat16)); } @@ -1061,7 +1061,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)) gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(double)); #endif } -#ifdef USE_ROCM +#ifdef USE_ROCM_NO_CK else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(double)); } @@ -1077,7 +1077,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(float)); } -#ifdef USE_ROCM +#ifdef USE_ROCM_NO_CK else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(float)); } @@ -1125,7 +1125,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)) if (at::globalContext().blasPreferredBackend() == 
BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::Half)); } -#ifdef USE_ROCM +#ifdef USE_ROCM_NO_CK else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(at::Half)); } @@ -1141,7 +1141,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::BFloat16)); } -#ifdef USE_ROCM +#ifdef USE_ROCM_NO_CK else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(at::BFloat16)); } -- 2.48.1