From 038ce9e44776e23f21c1816daa259bc0ea335088 Mon Sep 17 00:00:00 2001
From: Tom Rix <trix@redhat.com>
Date: Sat, 29 Jun 2024 07:06:09 -0700
Subject: [PATCH] disable use of aotriton

---
 .../ATen/native/transformers/cuda/sdp_utils.cpp | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index 214b02d8262e..7b3eb9dcd8cd 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -19,9 +19,12 @@
 #include <c10/core/SymInt.h>
 #include <c10/util/string_view.h>
 
+#ifdef USE_FLASH_ATTENTION
 #if USE_ROCM
 #include <aotriton/flash.h>
 #endif
+#endif
+
 
 /**
  * Note [SDPA Runtime Dispatch]
@@ -182,6 +185,9 @@ bool check_sm_version(cudaDeviceProp * dprops) {
 
 bool check_flash_attention_hardware_support(sdp_params const& params, bool debug) {
   // Check that the gpu is capable of running flash attention
+#ifndef USE_FLASH_ATTENTION
+  return false;
+#else
   using sm80 = SMVersion<8, 0>;
   using sm90 = SMVersion<9, 0>;
 #if USE_ROCM
@@ -209,9 +215,13 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
   }
 #endif
   return true;
+#endif
 }
 
 bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) {
+#ifndef USE_FLASH_ATTENTION
+  return false;
+#else
   // Mem Efficient attention supports hardware in the range [sm_50, sm_90]
   using sm50 = SMVersion<5, 0>;
   using sm90 = SMVersion<9, 0>;
@@ -240,6 +250,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
   }
 #endif
   return true;
+#endif
 }
 
 bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
@@ -554,7 +565,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
 #ifndef USE_FLASH_ATTENTION
   TORCH_WARN_ONCE(!debug, "Torch was not compiled with flash attention.");
   return false;
-#endif
+#else
 
   // Define gate functions that determine if a flash kernel can be ran
   // Replace with std::to_array when we migrate to c++20
@@ -597,13 +608,14 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
     }
   }
   return true;
+#endif
 }
 
 bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
 #ifndef USE_MEM_EFF_ATTENTION
   TORCH_WARN_ONCE(!debug, "Torch was not compiled with memory efficient attention.");
   return false;
-#endif
+#else
   // Constraints specific to mem efficient attention
   constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
       array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
@@ -663,6 +675,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
   }
 #endif
   return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
+#endif
 }
 
 SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
-- 
2.45.2
