46 lines
1.3 KiB
Diff
46 lines
1.3 KiB
Diff
From 33d48f71db7530f00dbd8cff281b65aa8b355b2a Mon Sep 17 00:00:00 2001
|
|
From: Tom Rix <trix@redhat.com>
|
|
Date: Tue, 19 Mar 2024 11:32:37 -0400
|
|
Subject: [PATCH] disable use of aotriton when USE_FLASH_ATTENTION is not defined
|
|
|
|
---
|
|
aten/src/ATen/native/transformers/cuda/sdp_utils.cpp | 6 ++++++
|
|
1 file changed, 6 insertions(+)
|
|
|
|
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
|
|
index 96b839820efd..2d3dd0cb4b0f 100644
|
|
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
|
|
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
|
|
@@ -21,9 +21,11 @@
|
|
#include <cmath>
|
|
#include <functional>
|
|
|
|
+#ifdef USE_FLASH_ATTENTION
|
|
#if USE_ROCM
|
|
#include <aotriton/flash.h>
|
|
#endif
|
|
+#endif
|
|
|
|
/**
|
|
* Note [SDPA Runtime Dispatch]
|
|
@@ -183,6 +185,7 @@ bool check_sm_version(cudaDeviceProp * dprops) {
|
|
}
|
|
|
|
bool check_flash_attention_hardware_support(sdp_params const& params, bool debug) {
|
|
+#ifdef USE_FLASH_ATTENTION
|
|
// Check that the gpu is capable of running flash attention
|
|
using sm80 = SMVersion<8, 0>;
|
|
using sm90 = SMVersion<9, 0>;
|
|
@@ -211,6 +214,9 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
|
|
}
|
|
#endif
|
|
return true;
|
|
+#else
|
|
+ return false;
|
|
+#endif
|
|
}
|
|
|
|
bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) {
|
|
--
|
|
2.44.0
|
|
|