|
2 | 2 | #include "test_tensors.h" |
3 | 3 | #include "taco/tensor.h" |
4 | 4 | #include "taco/index_notation/index_notation.h" |
| 5 | +#include "taco/index_notation/kernel.h" |
| 6 | +#include "taco/index_notation/transformations.h" |
5 | 7 |
|
6 | 8 | using namespace taco; |
7 | 9 | const IndexVar i("i"), j("j"), k("k"); |
@@ -84,4 +86,43 @@ TEST(indexstmt, spmm) { |
84 | 86 | } |
85 | 87 |
|
86 | 88 |
|
87 | | - |
| 89 | +TEST(indexstmt, sddmmPlusSpmm) { |
| 90 | + Type t(type<double>(), {3,3}); |
| 91 | + const IndexVar i("i"), j("j"), k("k"), l("l"); |
| 92 | + |
| 93 | + TensorVar A("A", t, Format{Dense, Dense}); |
| 94 | + TensorVar B("B", t, Format{Dense, Sparse}); |
| 95 | + TensorVar C("C", t, Format{Dense, Dense}); |
| 96 | + TensorVar D("D", t, Format{Dense, Dense}); |
| 97 | + TensorVar E("E", t, Format{Dense, Dense}); |
| 98 | + |
| 99 | + TensorVar tmp("tmp", Type(), Format()); |
| 100 | + |
| 101 | + // A(i,j) = B(i,j) * C(i,k) * D(j,k) * E(j,l) |
| 102 | + IndexStmt fused = |
| 103 | + forall(i, |
| 104 | + forall(j, |
| 105 | + forall(k, |
| 106 | + forall(l, A(i,l) += B(i,j) * C(i,k) * D(j,k) * E(j,l)) |
| 107 | + ) |
| 108 | + ) |
| 109 | + ); |
| 110 | + |
| 111 | + std::cout << "before topological sort: " << fused << std::endl; |
| 112 | + fused = reorderLoopsTopologically(fused); |
| 113 | + std::cout << "after topological sort: " << fused << std::endl; |
| 114 | + |
| 115 | + Kernel kernel = compile(fused); |
| 116 | + |
| 117 | + IndexStmt fusedNested = |
| 118 | + forall(i, |
| 119 | + forall(j, |
| 120 | + where( |
| 121 | + forall(l, A(i,l) += tmp * E(j,l)), // consumer |
| 122 | + forall(k, tmp += B(i,j) * C(i,k) * D(j,k)) // producer |
| 123 | + ) |
| 124 | + ) |
| 125 | + ); |
| 126 | + |
| 127 | + std::cout << "nested loop stmt: " << fusedNested << std::endl; |
| 128 | +} |
0 commit comments