From f646e0f04ae591c8f2d8a0cd24b035725c57659b Mon Sep 17 00:00:00 2001 From: Tom Rix Date: Thu, 23 Jan 2025 08:24:22 -0800 Subject: [PATCH] torch: paper over c++ assert --- aten/src/ATen/native/sparse/FlattenIndicesCommon.h | 2 ++ .../ATen/native/sparse/SparseBinaryOpIntersectionCommon.h | 5 +++++ .../src/ATen/native/sparse/ValidateCompressedIndicesCommon.h | 2 ++ 3 files changed, 9 insertions(+) diff --git a/aten/src/ATen/native/sparse/FlattenIndicesCommon.h b/aten/src/ATen/native/sparse/FlattenIndicesCommon.h index 0e79ed809ae6..a3cec8aaf78b 100644 --- a/aten/src/ATen/native/sparse/FlattenIndicesCommon.h +++ b/aten/src/ATen/native/sparse/FlattenIndicesCommon.h @@ -69,11 +69,13 @@ Tensor _flatten_indices_impl(const Tensor& indices, IntArrayRef size) { [=] FUNCAPI (int64_t nnz_idx) -> int64_t { const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride; auto hash = static_cast(0); +#if 0 for (int64_t dim = 0; dim < sparse_dim; ++dim) { const auto dim_hash_coeff = hash_coeffs[dim]; const auto dim_index = ptr_indices_dim[dim * indices_dim_stride]; hash += dim_index * dim_hash_coeff; } +#endif return hash; }); } diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h index c0b94bf39d54..8de4900b7a01 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h @@ -279,12 +279,15 @@ void _sparse_binary_op_intersection_kernel_impl( if (!ptr_indices) { return hash; } +#if 0 +// /usr/lib/gcc/x86_64-redhat-linux/15/../../../../include/c++/15/array:219:2: error: reference to __host__ function '__glibcxx_assert_fail' in __host__ __device__ function const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride; for (int64_t dim = 0; dim < sparse_dim; ++dim) { const auto dim_hash_coeff = hash_coeffs[dim]; const auto dim_index = ptr_indices_dim[dim * indices_dim_stride]; hash += dim_index * dim_hash_coeff; } +#endif return hash; }); } @@ -364,6 +367,7 @@ void _sparse_binary_op_intersection_kernel_impl( if (hash_ptr) { hash = hash_ptr[nnz_idx]; } else if (sparse_dim) { +#if 0 // Compute hash value const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride; for (int64_t dim = 0; dim < sparse_dim; ++dim) { @@ -371,6 +375,7 @@ void _sparse_binary_op_intersection_kernel_impl( const auto dim_index = ptr_indices_dim[dim * indices_dim_stride]; hash += dim_index * dim_hash_coeff; } +#endif } // Perform hash values intersection diff --git a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h index ec4c084a39cc..9bc9655b0afa 100644 --- a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h +++ b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h @@ -341,6 +341,7 @@ void _validate_compressed_sparse_indices_kernel( // assuming idx contiguity per batch: int64_t tmp = batch_idx * nnz; // `nnz == idx_sizes[idx_ndims - 1]` is checked above as `nnz == idx.size(-1)` +#if 0 for (int i = idx_ndims - 1; i >= 0 && nnz > 0; // break early when nnz==0 i--) { @@ -348,6 +349,7 @@ void _validate_compressed_sparse_indices_kernel( idx_offset += (tmp - div * idx_sizes[i]) * idx_strides[i]; tmp = div; } +#endif const auto* RESTRICT ptr_idx_batch = ptr_idx + idx_offset; _check_idx_sorted_distinct_vals_slices_with_cidx< cdim_name, -- 2.48.1