python-torch/0001-add-rocm_version-fallback.patch
Tom Rix ba89707404 Stub in rocm to test in flight packages
Signed-off-by: Tom Rix <trix@redhat.com>
2023-12-20 16:07:51 -05:00

36 lines
1.3 KiB
Diff

From 1d35a0b1f5cb39fd0c44a486157dc739a02c71b6 Mon Sep 17 00:00:00 2001
From: Tom Rix <trix@redhat.com>
Date: Wed, 20 Dec 2023 11:23:18 -0500
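Subject: [PATCH] add rocm_version fallback

If rocm_version.h cannot be found, derive rocm_version from the output
of `hipconfig --version`, using only its major and minor fields (the
patch component is set to 0). If hipconfig is not installed, the
fallback is skipped without a warning, because this module is also
imported by non-ROCm builds.

Signed-off-by: Tom Rix <trix@redhat.com>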
---
torch/utils/hipify/cuda_to_hip_mappings.py | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py
index 73586440e7..9354057a39 100644
--- a/torch/utils/hipify/cuda_to_hip_mappings.py
+++ b/torch/utils/hipify/cuda_to_hip_mappings.py
@@ -57,6 +57,20 @@ if os.path.isfile(rocm_version_h):
if match:
patch = int(match.group(1))
rocm_version = (major, minor, patch)
+else:
+ try:
+ hip_version = subprocess.check_output(["hipconfig", "--version"]).decode("utf-8")
+ # Only the MAJOR and MINOR dotted fields are parsed; the patch component is reported as 0.
+ hip_split = hip_version.split('.')
+ rocm_version = (int(hip_split[0]), int(hip_split[1]), 0)
+ except (subprocess.CalledProcessError, ValueError, IndexError):
+ # ValueError/IndexError: output lacks a numeric MAJOR.MINOR prefix; warn instead of failing the module import.
+ print("Warning: hipconfig --version failed or returned unexpected output")
+ except (FileNotFoundError, PermissionError, NotADirectoryError):
+ # Do not print warning. This is okay. This file can also be imported for non-ROCm builds.
+ pass
+
+
# List of math functions that should be replaced inside device code only.
MATH_TRANSPILATIONS = collections.OrderedDict(
--
2.43.0