Skip to content

Commit 0474698

Browse files
committed
speed up spmm & rspmm
1 parent 70a7091 commit 0474698

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

torchdrug/layers/functional/extension/rspmm.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,24 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> coo2csr3d(const SparseTensor &sparse)
4545
Tensor col_ind = index.select(0, 1);
4646
Tensor layer_ind = index.select(0, 2);
4747
Tensor value = sparse.values();
48-
Tensor nnz_per_row = at::zeros({sparse.size(0)}, row_ind.options());
49-
nnz_per_row.scatter_add_(0, row_ind, at::ones_like(row_ind));
48+
// scatter_add_ on int64 is very slow because it cannot use hardware atomic adds,
49+
// so accumulate the per-row counts in int32 and cast the result back to int64
50+
Tensor nnz_per_row = at::zeros({sparse.size(0)}, row_ind.options().dtype(at::ScalarType::Int));
51+
nnz_per_row.scatter_add_(0, row_ind, at::ones(row_ind.sizes(), nnz_per_row.options()));
52+
nnz_per_row = nnz_per_row.toType(at::ScalarType::Long);
5053
Tensor row_ptr = nnz_per_row.cumsum(0) - nnz_per_row;
5154
return std::make_tuple(row_ptr, col_ind, layer_ind, value);
5255
}
5356

5457
SparseTensor csr2coo3d(const Tensor &row_ptr_, const Tensor &col_ind, const Tensor &layer_ind, const Tensor &value,
5558
IntArrayRef size) {
5659
Tensor row_ptr = row_ptr_.masked_select(row_ptr_ < col_ind.size(0));
57-
Tensor row_ind_ = at::zeros_like(col_ind);
58-
row_ind_.scatter_add_(0, row_ptr, at::ones_like(row_ptr));
59-
Tensor row_ind = row_ind_.cumsum(0) - 1;
60+
// scatter_add_ on int64 is very slow because it cannot use hardware atomic adds,
61+
// so accumulate in int32 and cast back to int64 before the cumulative sum
62+
Tensor row_ind = at::zeros(col_ind.sizes(), col_ind.options().dtype(at::ScalarType::Int));
63+
row_ind.scatter_add_(0, row_ptr, at::ones(row_ptr.sizes(), row_ind.options()));
64+
row_ind = row_ind.toType(at::ScalarType::Long);
65+
row_ind = row_ind.cumsum(0) - 1;
6066
Tensor index = at::stack({row_ind, col_ind, layer_ind}, 0);
6167
return at::_sparse_coo_tensor_unsafe(index, value, size, value.options().layout(kSparse));
6268
}

torchdrug/layers/functional/extension/spmm.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,23 @@ std::tuple<Tensor, Tensor, Tensor> coo2csr(const SparseTensor &sparse) {
4040
Tensor row_ind = index.select(0, 0);
4141
Tensor col_ind = index.select(0, 1);
4242
Tensor value = sparse.values();
43-
Tensor nnz_per_row = at::zeros({sparse.size(0)}, row_ind.options());
44-
nnz_per_row.scatter_add_(0, row_ind, at::ones_like(row_ind));
43+
// scatter_add_ on int64 is very slow because it cannot use hardware atomic adds,
44+
// so accumulate the per-row counts in int32 and cast the result back to int64
45+
Tensor nnz_per_row = at::zeros({sparse.size(0)}, row_ind.options().dtype(at::ScalarType::Int));
46+
nnz_per_row.scatter_add_(0, row_ind, at::ones(row_ind.sizes(), nnz_per_row.options()));
47+
nnz_per_row = nnz_per_row.toType(at::ScalarType::Long);
4548
Tensor row_ptr = nnz_per_row.cumsum(0) - nnz_per_row;
4649
return std::make_tuple(row_ptr, col_ind, value);
4750
}
4851

4952
SparseTensor csr2coo(const Tensor &row_ptr_, const Tensor &col_ind, const Tensor &value, IntArrayRef size) {
5053
Tensor row_ptr = row_ptr_.masked_select(row_ptr_ < col_ind.size(0));
51-
Tensor row_ind_ = at::zeros_like(col_ind);
52-
row_ind_.scatter_add_(0, row_ptr, at::ones_like(row_ptr));
53-
Tensor row_ind = row_ind_.cumsum(0) - 1;
54+
// scatter_add_ on int64 is very slow because it cannot use hardware atomic adds,
55+
// so accumulate in int32 and cast back to int64 before the cumulative sum
56+
Tensor row_ind = at::zeros(col_ind.sizes(), col_ind.options().dtype(at::ScalarType::Int));
57+
row_ind.scatter_add_(0, row_ptr, at::ones(row_ptr.sizes(), row_ind.options()));
58+
row_ind = row_ind.toType(at::ScalarType::Long);
59+
row_ind = row_ind.cumsum(0) - 1;
5460
Tensor index = at::stack({row_ind, col_ind}, 0);
5561
return at::_sparse_coo_tensor_unsafe(index, value, size, value.options().layout(kSparse));
5662
}

0 commit comments

Comments
 (0)