python-torch/0001-disable-use-of-aotriton.patch
Tom Rix 86185b46a2 PyTorch 2.4
Signed-off-by: Tom Rix <trix@redhat.com>
2024-07-25 16:33:22 -06:00

94 lines
2.9 KiB
Diff

From 038ce9e44776e23f21c1816daa259bc0ea335088 Mon Sep 17 00:00:00 2001
From: Tom Rix <trix@redhat.com>
Date: Sat, 29 Jun 2024 07:06:09 -0700
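Subject: [PATCH] disable use of aotriton

On ROCm, only include <aotriton/flash.h> when USE_FLASH_ATTENTION is
defined, and make check_flash_attention_hardware_support() and
check_mem_efficient_hardware_support() return false when it is not.
In can_use_flash_attention() and can_use_mem_efficient_attention(),
turn the early "Torch was not compiled with ..." return into an #else
branch, so the remaining checks are only compiled when the matching
USE_FLASH_ATTENTION / USE_MEM_EFF_ATTENTION macro is defined.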
---
.../ATen/native/transformers/cuda/sdp_utils.cpp | 17 +++++++++++++++--
1 file changed, 15 insertions(+), 2 deletions(-)
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index 214b02d8262e..7b3eb9dcd8cd 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -19,9 +19,12 @@
#include <c10/core/SymInt.h>
#include <c10/util/string_view.h>
+#ifdef USE_FLASH_ATTENTION
#if USE_ROCM
#include <aotriton/flash.h>
#endif
+#endif
+
/**
* Note [SDPA Runtime Dispatch]
@@ -182,6 +185,9 @@ bool check_sm_version(cudaDeviceProp * dprops) {
bool check_flash_attention_hardware_support(sdp_params const& params, bool debug) {
// Check that the gpu is capable of running flash attention
+#ifndef USE_FLASH_ATTENTION
+ return false;
+#else
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
#if USE_ROCM
@@ -209,9 +215,13 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
}
#endif
return true;
+#endif
}
bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) {
+#ifndef USE_FLASH_ATTENTION
+ return false;
+#else
// Mem Efficient attention supports hardware in the range [sm_50, sm_90]
using sm50 = SMVersion<5, 0>;
using sm90 = SMVersion<9, 0>;
@@ -240,6 +250,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
}
#endif
return true;
+#endif
}
bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
@@ -554,7 +565,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
#ifndef USE_FLASH_ATTENTION
TORCH_WARN_ONCE(!debug, "Torch was not compiled with flash attention.");
return false;
-#endif
+#else
// Define gate functions that determine if a flash kernel can be ran
// Replace with std::to_array when we migrate to c++20
@@ -597,13 +608,14 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
}
}
return true;
+#endif
}
bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
#ifndef USE_MEM_EFF_ATTENTION
TORCH_WARN_ONCE(!debug, "Torch was not compiled with memory efficient attention.");
return false;
-#endif
+#else
// Constraints specific to mem efficient attention
constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
@@ -663,6 +675,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
}
#endif
return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
+#endif
}
SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
--
2.45.2