88 lines
4 KiB
Diff
88 lines
4 KiB
Diff
From f646e0f04ae591c8f2d8a0cd24b035725c57659b Mon Sep 17 00:00:00 2001
|
|
From: Tom Rix <Tom.Rix@amd.com>
|
|
Date: Thu, 23 Jan 2025 08:24:22 -0800
|
|
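Subject: [PATCH] torch: compile out sparse index loops to paper over libstdc++ assert

Building the sparse kernels for the device with GCC 15's libstdc++ fails:

  /usr/lib/gcc/x86_64-redhat-linux/15/../../../../include/c++/15/array:219:2:
  error: reference to __host__ function '__glibcxx_assert_fail' in
  __host__ __device__ function

The error comes from <array>, presumably from std::array subscripting
inside the __host__ __device__ lambdas. This patch does not fix that
error. It only compiles out (#if 0) the loops that perform the indexing.
Consequences:

 - The hash lambda in _flatten_indices_impl now returns 0 for every nnz
   index.
 - In _sparse_binary_op_intersection_kernel_impl, the per-index hash is
   never accumulated and stays at its initial value.
 - In _validate_compressed_sparse_indices_kernel, idx_offset is no longer
   derived from batch_idx.

Sparse operations that rely on these hashes can therefore return wrong
results. Compressed-index validation may also inspect the wrong slice.
This is a build workaround only. Drop it once the toolchain or the
indexing code is fixed properly.
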
---
|
|
aten/src/ATen/native/sparse/FlattenIndicesCommon.h | 2 ++
|
|
.../ATen/native/sparse/SparseBinaryOpIntersectionCommon.h | 5 +++++
|
|
.../src/ATen/native/sparse/ValidateCompressedIndicesCommon.h | 2 ++
|
|
3 files changed, 9 insertions(+)
|
|
|
|
diff --git a/aten/src/ATen/native/sparse/FlattenIndicesCommon.h b/aten/src/ATen/native/sparse/FlattenIndicesCommon.h
|
|
index 0e79ed809ae6..a3cec8aaf78b 100644
|
|
--- a/aten/src/ATen/native/sparse/FlattenIndicesCommon.h
|
|
+++ b/aten/src/ATen/native/sparse/FlattenIndicesCommon.h
|
|
@@ -69,11 +69,13 @@ Tensor _flatten_indices_impl(const Tensor& indices, IntArrayRef size) {
|
|
[=] FUNCAPI (int64_t nnz_idx) -> int64_t {
|
|
const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
|
|
auto hash = static_cast<int64_t>(0);
|
|
+#if 0
|
|
for (int64_t dim = 0; dim < sparse_dim; ++dim) {
|
|
const auto dim_hash_coeff = hash_coeffs[dim];
|
|
const auto dim_index = ptr_indices_dim[dim * indices_dim_stride];
|
|
hash += dim_index * dim_hash_coeff;
|
|
}
|
|
+#endif
|
|
return hash;
|
|
});
|
|
}
|
|
diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h
|
|
index c0b94bf39d54..8de4900b7a01 100644
|
|
--- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h
|
|
+++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h
|
|
@@ -279,12 +279,15 @@ void _sparse_binary_op_intersection_kernel_impl(
|
|
if (!ptr_indices) {
|
|
return hash;
|
|
}
|
|
+#if 0
|
|
+// /usr/lib/gcc/x86_64-redhat-linux/15/../../../../include/c++/15/array:219:2: error: reference to __host__ function '__glibcxx_assert_fail' in __host__ __device__ function
|
|
const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
|
|
for (int64_t dim = 0; dim < sparse_dim; ++dim) {
|
|
const auto dim_hash_coeff = hash_coeffs[dim];
|
|
const auto dim_index = ptr_indices_dim[dim * indices_dim_stride];
|
|
hash += dim_index * dim_hash_coeff;
|
|
}
|
|
+#endif
|
|
return hash;
|
|
});
|
|
}
|
|
@@ -364,6 +367,7 @@ void _sparse_binary_op_intersection_kernel_impl(
|
|
if (hash_ptr) {
|
|
hash = hash_ptr[nnz_idx];
|
|
} else if (sparse_dim) {
|
|
+#if 0
|
|
// Compute hash value
|
|
const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
|
|
for (int64_t dim = 0; dim < sparse_dim; ++dim) {
|
|
@@ -371,6 +375,7 @@ void _sparse_binary_op_intersection_kernel_impl(
|
|
const auto dim_index = ptr_indices_dim[dim * indices_dim_stride];
|
|
hash += dim_index * dim_hash_coeff;
|
|
}
|
|
+#endif
|
|
}
|
|
|
|
// Perform hash values intersection
|
|
diff --git a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h
|
|
index ec4c084a39cc..9bc9655b0afa 100644
|
|
--- a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h
|
|
+++ b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h
|
|
@@ -341,6 +341,7 @@ void _validate_compressed_sparse_indices_kernel(
|
|
// assuming idx contiguity per batch:
|
|
int64_t tmp = batch_idx * nnz;
|
|
// `nnz == idx_sizes[idx_ndims - 1]` is checked above as `nnz == idx.size(-1)`
|
|
+#if 0
|
|
for (int i = idx_ndims - 1;
|
|
i >= 0 && nnz > 0; // break early when nnz==0
|
|
i--) {
|
|
@@ -348,6 +349,7 @@ void _validate_compressed_sparse_indices_kernel(
|
|
idx_offset += (tmp - div * idx_sizes[i]) * idx_strides[i];
|
|
tmp = div;
|
|
}
|
|
+#endif
|
|
const auto* RESTRICT ptr_idx_batch = ptr_idx + idx_offset;
|
|
_check_idx_sorted_distinct_vals_slices_with_cidx<
|
|
cdim_name,
|
|
--
|
|
2.48.1
|
|
|