diff --git a/next/0001-Optionally-use-hipblaslt.patch b/next/0001-Optionally-use-hipblaslt.patch new file mode 100644 index 0000000..1e5ca4b --- /dev/null +++ b/next/0001-Optionally-use-hipblaslt.patch @@ -0,0 +1,506 @@ +From f1d65e958afa65882dbfea8b392ab847a84d41ed Mon Sep 17 00:00:00 2001 +From: Tom Rix +Date: Sat, 29 Jun 2024 04:18:34 -0700 +Subject: [PATCH] Optionally use hipblaslt + +--- + aten/src/ATen/cuda/CUDABlas.cpp | 46 ++++++++++++++++++------ + aten/src/ATen/cuda/CUDAContextLight.h | 4 +++ + aten/src/ATen/cuda/CublasHandlePool.cpp | 10 ++++-- + aten/src/ATen/cuda/tunable/TunableGemm.h | 18 +++++++--- + aten/src/ATen/native/cuda/Blas.cpp | 18 +++++++++- + cmake/Dependencies.cmake | 3 ++ + cmake/public/LoadHIP.cmake | 2 +- + 7 files changed, 82 insertions(+), 19 deletions(-) + +diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp +index ce991a9bcad4..3f0d17b52778 100644 +--- a/aten/src/ATen/cuda/CUDABlas.cpp ++++ b/aten/src/ATen/cuda/CUDABlas.cpp +@@ -14,7 +14,9 @@ + #include + + #ifdef USE_ROCM ++#ifdef USE_HIPBLASLT + #include ++#endif + // until hipblas has an API to accept flags, we must use rocblas here + #include + #include +@@ -182,6 +184,9 @@ uint32_t _getAlignment(uintptr_t address) { + static size_t _parseChosenWorkspaceSize() { + const char * val = getenv("CUBLASLT_WORKSPACE_SIZE"); + #ifdef USE_ROCM ++#ifndef USE_HIPBLASLT ++ return 0; ++#endif + if (!val) { + // accept either env var + val = getenv("HIPBLASLT_WORKSPACE_SIZE"); +@@ -235,6 +240,7 @@ namespace at::cuda::blas { + } while (0) + + ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + namespace { + // Following the pattern of CuSparseDescriptor + // Defined here for now because this is the only place cublas_lt interface is +@@ -318,7 +324,6 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor< + }; + } // namespace + +- + template + inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { + cudaDataType_t abcType 
= CUDA_R_32F; +@@ -452,7 +457,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { + " scaleType ", + scaleType); + } +- ++#endif + + template + inline void bgemm_internal_cublas(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { +@@ -608,10 +613,13 @@ void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(double)) + template <> + void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(float)) + { ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { + bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(float)); + } +- else { ++ else ++#endif ++ { + bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(float)); + } + } +@@ -651,10 +659,13 @@ void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex + void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::Half)) + { ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { + bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(at::Half)); + } +- else { ++ else ++#endif ++ { + bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(at::Half)); + } + } +@@ -662,10 +673,13 @@ void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::Half)) + template <> + void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) + { ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { + bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(at::BFloat16)); + } +- else { ++ else ++#endif ++ { + bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(at::BFloat16)); + } + } +@@ -781,11 +795,13 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { + } + } + ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + template + inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + // forward to bgemm implementation but set strides and batches to 0 + bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 
0, beta, c, ldc, 0, 0); + } ++#endif + + template + inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) { +@@ -1008,10 +1024,13 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)) + template <> + void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)) + { ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { + gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(float)); + } +- else { ++ else ++#endif ++ { + gemm_internal_cublas(CUDABLAS_GEMM_ARGS(float)); + } + } +@@ -1051,10 +1070,13 @@ void gemm_internal>(CUDABLAS_GEMM_ARGTYPES(c10::complex + void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)) + { ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { + gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::Half)); + } +- else { ++ else ++#endif ++ { + gemm_internal_cublas(CUDABLAS_GEMM_ARGS(at::Half)); + } + } +@@ -1062,10 +1084,13 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)) + template <> + void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) + { ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { + gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::BFloat16)); + } +- else { ++ else ++#endif ++ { + gemm_internal_cublas(CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + } +@@ -1177,7 +1202,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { + } + } + +- ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + template + void gemm_and_bias( + bool transpose_mat1, +@@ -1410,7 +1435,7 @@ void scaled_gemm( + ScalarType result_dtype, + void* amax_ptr, + bool use_fast_accum) { +-#if CUDA_VERSION >= 11080 || defined(USE_ROCM) ++#if CUDA_VERSION >= 11080 || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + const auto computeType = CUBLAS_COMPUTE_32F; + const auto scaleType 
= CUDA_R_32F; + const int8_t fastAccuMode = use_fast_accum ? 1 : 0; +@@ -1681,6 +1706,7 @@ void int8_gemm( + " scaleType ", + scaleType); + } ++#endif + + template <> + void trsm(CUDABLAS_TRSM_ARGTYPES(float)) { +diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h +index f2b657ced51b..f0ee613c4208 100644 +--- a/aten/src/ATen/cuda/CUDAContextLight.h ++++ b/aten/src/ATen/cuda/CUDAContextLight.h +@@ -9,7 +9,9 @@ + + // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also + // added bf16 support ++#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))) + #include ++#endif + + #ifdef CUDART_VERSION + #include +@@ -80,7 +82,9 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator(); + /* Handles */ + TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle(); + TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle(); ++#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))) + TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle(); ++#endif + + TORCH_CUDA_CPP_API void clearCublasWorkspaces(); + +diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp +index 8eac525b3695..abfdf7a23847 100644 +--- a/aten/src/ATen/cuda/CublasHandlePool.cpp ++++ b/aten/src/ATen/cuda/CublasHandlePool.cpp +@@ -29,7 +29,7 @@ namespace at::cuda { + + namespace { + +-#if defined(USE_ROCM) ++#if defined(USE_ROCM) && defined(USE_HIPBLASLT) + void createCublasLtHandle(cublasLtHandle_t *handle) { + TORCH_CUDABLAS_CHECK(cublasLtCreate(handle)); + } +@@ -191,8 +191,9 @@ cublasHandle_t getCurrentCUDABlasHandle() { + return handle; + } + +-cublasLtHandle_t getCurrentCUDABlasLtHandle() { + #ifdef USE_ROCM ++#if defined(USE_HIPBLASLT) ++cublasLtHandle_t getCurrentCUDABlasLtHandle() { + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + +@@ -213,9 +214,12 @@ cublasLtHandle_t getCurrentCUDABlasLtHandle() { + + auto handle 
= myPoolWindow->reserve(device); + return handle; ++} ++#endif + #else ++cublasLtHandle_t getCurrentCUDABlasLtHandle() { + return reinterpret_cast(getCurrentCUDABlasHandle()); +-#endif + } ++#endif + + } // namespace at::cuda +diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h +index 53e6154120c9..fa1d664696db 100644 +--- a/aten/src/ATen/cuda/tunable/TunableGemm.h ++++ b/aten/src/ATen/cuda/tunable/TunableGemm.h +@@ -11,7 +11,9 @@ + + #include + #ifdef USE_ROCM ++#ifdef USE_HIPBLASLT + #include ++#endif + #include + #endif + #include +@@ -65,6 +67,7 @@ class DefaultGemmStridedBatchedOp : public Callable> + } + }; + ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + template + class DefaultScaledGemmOp : public Callable> { + public: +@@ -94,6 +97,7 @@ class DefaultScaledGemmOp : public Callable> { + return OK; + } + }; ++#endif + + template + inline bool IsZero(T v) { +@@ -191,6 +195,7 @@ static void AddRocblasValidator() { + } + } + ++#ifdef USE_HIPBLASLT + static void AddHipblasltValidator() { + auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); + if (validators.find("HIPBLASLT_VERSION") == validators.end()) { +@@ -205,6 +210,7 @@ static void AddHipblasltValidator() { + [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? 
OK : FAIL; }); + } + } ++#endif + + static void AddRocmValidator() { + auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); +@@ -243,7 +249,7 @@ class GemmTunableOp : public TunableOp, StreamTimer> { + } + AddRocblasValidator(); + } +- ++#ifdef USE_HIPBLASLT + static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + rocm_validators = true; +@@ -257,7 +263,7 @@ class GemmTunableOp : public TunableOp, StreamTimer> { + } + AddHipblasltValidator(); + } +- ++#endif + if (rocm_validators) { + AddRocmValidator(); + } +@@ -286,7 +292,7 @@ class GemmStridedBatchedTunableOp : public TunableOp + } + AddRocblasValidator(); + } +- ++#ifdef USE_HIPBLASLT + static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + rocm_validators = true; +@@ -300,7 +306,7 @@ class GemmStridedBatchedTunableOp : public TunableOp + } + AddHipblasltValidator(); + } +- ++#endif + if (rocm_validators) { + AddRocmValidator(); + } +@@ -312,6 +318,7 @@ class GemmStridedBatchedTunableOp : public TunableOp + } + }; + ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + template + class ScaledGemmTunableOp : public TunableOp, StreamTimer> { + public: +@@ -321,10 +328,12 @@ class ScaledGemmTunableOp : public TunableOp, StreamTimer> + auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); + + #if defined(USE_ROCM) ++#ifdef USE_HIPBLASLT + for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + AddHipblasltValidator(); ++#endif + AddRocmValidator(); + #endif + } +@@ -337,6 +346,7 @@ class ScaledGemmTunableOp : public TunableOp, StreamTimer> + "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + } + }; ++#endif + + #undef XSTRINGIFY + #undef STRINGIFY 
+diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp +index 84c59a4fd0d7..56ad5de3bf2d 100644 +--- a/aten/src/ATen/native/cuda/Blas.cpp ++++ b/aten/src/ATen/native/cuda/Blas.cpp +@@ -173,6 +173,7 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa + } + + static bool getDisableAddmmCudaLt() { ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT"); + #ifdef USE_ROCM + // if we enable tunable op, it'll take priority over just hipblaslt (heuristics) +@@ -196,10 +197,14 @@ static bool getDisableAddmmCudaLt() { + } + return false; + #endif ++#else ++ return true; ++#endif + } + + #ifdef USE_ROCM + static bool isSupportedHipLtROCmArch(int index) { ++#ifdef USE_HIPBLASLT + hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index); + std::string device_arch = prop->gcnArchName; + static const std::vector archs = {"gfx90a", "gfx940", "gfx941", "gfx942"}; +@@ -210,6 +215,7 @@ static bool isSupportedHipLtROCmArch(int index) { + } + } + TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!"); ++#endif + return false; + } + #endif +@@ -235,6 +241,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma + at::ScalarType scalar_type = self.scalar_type(); + c10::MaybeOwned self_; + if (&result != &self) { ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + #if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || defined(USE_ROCM) + // Strangely, if mat2 has only 1 row or column, we get + // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic. 
+@@ -276,13 +283,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma + scalar_type != at::ScalarType::BFloat16)); + #endif + } ++#endif + #endif + if (!useLtInterface) { + self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm"); + } + self__sizes = self_->sizes(); + } else { +-#if defined(USE_ROCM) ++#if defined(USE_ROCM) && defined(USE_HIPBLASLT) + useLtInterface = !disable_addmm_cuda_lt && + result.dim() == 2 && result.is_contiguous() && + isSupportedHipLtROCmArch(self.device().index()) && +@@ -334,6 +342,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj()); + + if (useLtInterface) { ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + #if defined(USE_ROCM) + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, +@@ -394,6 +403,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma + activation_epilogue + ); + }); ++#endif + #endif + } else + { +@@ -803,6 +813,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) { + } + + static bool _scaled_mm_allowed_device() { ++#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + auto dprops = at::cuda::getCurrentDeviceProperties(); + #ifdef USE_ROCM + std::string device_arch = dprops->gcnArchName; +@@ -817,6 +828,9 @@ static bool _scaled_mm_allowed_device() { + #else + return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9); + #endif ++#else ++ return false; ++#endif + } + + // Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax +@@ -850,6 +864,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, + // Check sizes + bool allowed_device = _scaled_mm_allowed_device(); + TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+"); ++#if 
!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) + TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix"); + TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); + TORCH_CHECK( +@@ -1025,6 +1040,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, + #if defined(USE_ROCM) && ROCM_VERSION >= 60000 && ROCM_VERSION < 60200 + // ROCm's hipBLASLt does not support amax before 6.2, so calculate separately + amax = at::max(at::abs(out.to(kFloat))); ++#endif + #endif + + return {out, amax}; +diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake +index f1f2eb7cec31..8d05e834bbc5 100644 +--- a/cmake/Dependencies.cmake ++++ b/cmake/Dependencies.cmake +@@ -1052,6 +1052,9 @@ if(USE_ROCM) + list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP) + list(APPEND HIP_CXX_FLAGS -std=c++17) + list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2) ++ if(hipblaslt_FOUND) ++ list(APPEND HIP_CXX_FLAGS -DUSE_HIPBLASLT) ++ endif() + if(HIP_NEW_TYPE_ENUMS) + list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS) + endif() +diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake +index fa39156031ff..df4836847fdf 100644 +--- a/cmake/public/LoadHIP.cmake ++++ b/cmake/public/LoadHIP.cmake +@@ -155,7 +155,7 @@ if(HIP_FOUND) + find_package_and_print_version(hiprand REQUIRED) + find_package_and_print_version(rocblas REQUIRED) + find_package_and_print_version(hipblas REQUIRED) +- find_package_and_print_version(hipblaslt REQUIRED) ++ find_package_and_print_version(hipblaslt) + find_package_and_print_version(miopen REQUIRED) + find_package_and_print_version(hipfft REQUIRED) + find_package_and_print_version(hipsparse REQUIRED) +-- +2.45.2 + diff --git a/next/0001-disable-use-of-aotriton.patch b/next/0001-disable-use-of-aotriton.patch new file mode 100644 index 0000000..61ffd1e --- /dev/null +++ b/next/0001-disable-use-of-aotriton.patch @@ -0,0 +1,94 @@ +From 038ce9e44776e23f21c1816daa259bc0ea335088 Mon Sep 17 00:00:00 2001 +From: Tom Rix +Date: Sat, 29 
Jun 2024 07:06:09 -0700 +Subject: [PATCH] disable use of aotriton + +--- + .../ATen/native/transformers/cuda/sdp_utils.cpp | 17 +++++++++++++++-- + 1 file changed, 15 insertions(+), 2 deletions(-) + +diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +index 214b02d8262e..7b3eb9dcd8cd 100644 +--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp ++++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +@@ -19,9 +19,12 @@ + #include + #include + ++#ifdef USE_FLASH_ATTENTION + #if USE_ROCM + #include + #endif ++#endif ++ + + /** + * Note [SDPA Runtime Dispatch] +@@ -182,6 +185,9 @@ bool check_sm_version(cudaDeviceProp * dprops) { + + bool check_flash_attention_hardware_support(sdp_params const& params, bool debug) { + // Check that the gpu is capable of running flash attention ++#ifndef USE_FLASH_ATTENTION ++ return false; ++#else + using sm80 = SMVersion<8, 0>; + using sm90 = SMVersion<9, 0>; + #if USE_ROCM +@@ -209,9 +215,13 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug + } + #endif + return true; ++#endif + } + + bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) { ++#ifndef USE_FLASH_ATTENTION ++ return false; ++#else + // Mem Efficient attention supports hardware in the range [sm_50, sm_90] + using sm50 = SMVersion<5, 0>; + using sm90 = SMVersion<9, 0>; +@@ -240,6 +250,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) + } + #endif + return true; ++#endif + } + + bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89( +@@ -554,7 +565,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) { + #ifndef USE_FLASH_ATTENTION + TORCH_WARN_ONCE(!debug, "Torch was not compiled with flash attention."); + return false; +-#endif ++#else + + // Define gate functions that determine if a flash kernel can be ran + // Replace with std::to_array when we migrate to c++20 +@@ 
-597,13 +608,14 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) { + } + } + return true; ++#endif + } + + bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { + #ifndef USE_MEM_EFF_ATTENTION + TORCH_WARN_ONCE(!debug, "Torch was not compiled with memory efficient attention."); + return false; +-#endif ++#else + // Constraints specific to mem efficient attention + constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes = + array_of(at::kHalf, at::kFloat, at::kBFloat16); +@@ -663,6 +675,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { + } + #endif + return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug); ++#endif + } + + SDPBackend select_sdp_backend(sdp_params const& kernel_params) { +-- +2.45.2 +