Skip to content

Commit 99adfca

Browse files
committed
add a few initial test cases
1 parent 7529359 commit 99adfca

File tree

2 files changed

+405
-1
lines changed

2 files changed

+405
-1
lines changed

test/tests-indexstmt.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#include "test_tensors.h"
33
#include "taco/tensor.h"
44
#include "taco/index_notation/index_notation.h"
5+
#include "taco/index_notation/kernel.h"
6+
#include "taco/index_notation/transformations.h"
57

68
using namespace taco;
79
const IndexVar i("i"), j("j"), k("k");
@@ -84,4 +86,43 @@ TEST(indexstmt, spmm) {
8486
}
8587

8688

87-
89+
TEST(indexstmt, sddmmPlusSpmm) {
90+
Type t(type<double>(), {3,3});
91+
const IndexVar i("i"), j("j"), k("k"), l("l");
92+
93+
TensorVar A("A", t, Format{Dense, Dense});
94+
TensorVar B("B", t, Format{Dense, Sparse});
95+
TensorVar C("C", t, Format{Dense, Dense});
96+
TensorVar D("D", t, Format{Dense, Dense});
97+
TensorVar E("E", t, Format{Dense, Dense});
98+
99+
TensorVar tmp("tmp", Type(), Format());
100+
101+
// A(i,j) = B(i,j) * C(i,k) * D(j,k) * E(j,l)
102+
IndexStmt fused =
103+
forall(i,
104+
forall(j,
105+
forall(k,
106+
forall(l, A(i,l) += B(i,j) * C(i,k) * D(j,k) * E(j,l))
107+
)
108+
)
109+
);
110+
111+
std::cout << "before topological sort: " << fused << std::endl;
112+
fused = reorderLoopsTopologically(fused);
113+
std::cout << "after topological sort: " << fused << std::endl;
114+
115+
Kernel kernel = compile(fused);
116+
117+
IndexStmt fusedNested =
118+
forall(i,
119+
forall(j,
120+
where(
121+
forall(l, A(i,l) += tmp * E(j,l)), // consumer
122+
forall(k, tmp += B(i,j) * C(i,k) * D(j,k)) // producer
123+
)
124+
)
125+
);
126+
127+
std::cout << "nested loop stmt: " << fusedNested << std::endl;
128+
}

0 commit comments

Comments
 (0)