@@ -45,18 +45,24 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> coo2csr3d(const SparseTensor &sparse)
     Tensor col_ind = index.select(0, 1);
     Tensor layer_ind = index.select(0, 2);
     Tensor value = sparse.values();
-    Tensor nnz_per_row = at::zeros({sparse.size(0)}, row_ind.options());
-    nnz_per_row.scatter_add_(0, row_ind, at::ones_like(row_ind));
+    // scatter_add is super slow for int64, due to non-hardware atomic operations
+    // use int32 instead
+    Tensor nnz_per_row = at::zeros({sparse.size(0)}, row_ind.options().dtype(at::ScalarType::Int));
+    nnz_per_row.scatter_add_(0, row_ind, at::ones(row_ind.sizes(), nnz_per_row.options()));
+    nnz_per_row = nnz_per_row.toType(at::ScalarType::Long);
     Tensor row_ptr = nnz_per_row.cumsum(0) - nnz_per_row;
     return std::make_tuple(row_ptr, col_ind, layer_ind, value);
 }
 
 SparseTensor csr2coo3d(const Tensor &row_ptr_, const Tensor &col_ind, const Tensor &layer_ind, const Tensor &value,
                        IntArrayRef size) {
     Tensor row_ptr = row_ptr_.masked_select(row_ptr_ < col_ind.size(0));
-    Tensor row_ind_ = at::zeros_like(col_ind);
-    row_ind_.scatter_add_(0, row_ptr, at::ones_like(row_ptr));
-    Tensor row_ind = row_ind_.cumsum(0) - 1;
+    // scatter_add is super slow for int64, due to non-hardware atomic operations
+    // use int32 instead
+    Tensor row_ind = at::zeros(col_ind.sizes(), col_ind.options().dtype(at::ScalarType::Int));
+    row_ind.scatter_add_(0, row_ptr, at::ones(row_ptr.sizes(), row_ind.options()));
+    row_ind = row_ind.toType(at::ScalarType::Long);
+    row_ind = row_ind.cumsum(0) - 1;
     Tensor index = at::stack({row_ind, col_ind, layer_ind}, 0);
     return at::_sparse_coo_tensor_unsafe(index, value, size, value.options().layout(kSparse));
 }
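
The counting step in coo2csr3d above follows a pattern worth isolating: accumulate per-row element counts with scatter_add_ in int32 (where CUDA has a hardware atomic add, per the comment in the diff), then widen back to int64. Below is a minimal, self-contained sketch of that pattern, assuming only the ATen headers; the helper name count_per_row is hypothetical and not part of the commit.

#include <ATen/ATen.h>

// Count how many entries fall into each row.
// row_ind must be an int64 index tensor; num_rows is the number of rows.
at::Tensor count_per_row(const at::Tensor &row_ind, int64_t num_rows) {
    // Accumulate in int32: int64 scatter_add_ on CUDA falls back to
    // non-hardware (software) atomics and is much slower than the
    // hardware int32 atomic add.
    at::Tensor counts = at::zeros({num_rows},
                                  row_ind.options().dtype(at::ScalarType::Int));
    counts.scatter_add_(0, row_ind, at::ones(row_ind.sizes(), counts.options()));
    // Widen back to int64 so downstream index arithmetic keeps its dtype.
    return counts.toType(at::ScalarType::Long);
}

The CSR row pointer is then the exclusive prefix sum of these counts, which is what nnz_per_row.cumsum(0) - nnz_per_row computes in the diff.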
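The reverse direction in csr2coo3d uses the same trick to expand a row pointer back into per-element row indices: scatter a 1 at every row's start offset, then let a running sum turn those markers into row numbers. A sketch under the same assumptions (the helper name expand_row_ptr is illustrative, not from the commit):

#include <ATen/ATen.h>

// Recover a per-element row index from a CSR row pointer.
// row_ptr holds the start offset of each row; nnz is the number of stored values.
at::Tensor expand_row_ptr(const at::Tensor &row_ptr, int64_t nnz) {
    // Drop start offsets that point past the last element (trailing empty rows),
    // mirroring the masked_select in the diff above.
    at::Tensor starts = row_ptr.masked_select(row_ptr < nnz);
    // scatter_add_ rather than scatter_: consecutive empty rows share the same
    // start offset, and each must still advance the row index by one.
    // Again accumulate in int32 for the fast atomic path.
    at::Tensor markers = at::zeros({nnz}, starts.options().dtype(at::ScalarType::Int));
    markers.scatter_add_(0, starts, at::ones(starts.sizes(), markers.options()));
    // Inclusive prefix sum minus one turns the markers into 0-based row indices.
    return markers.toType(at::ScalarType::Long).cumsum(0) - 1;
}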