Skip to content

Commit 672facf

Browse files
committed
fix index merge inside a where clause
1 parent 5864011 commit 672facf

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

src/index_notation/transformations.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "taco/index_notation/index_notation_rewriter.h"
55
#include "taco/index_notation/index_notation_nodes.h"
66
#include "taco/error/error_messages.h"
7+
#include "taco/index_notation/index_notation_visitor.h"
78
#include "taco/storage/index.h"
89
#include "taco/util/collections.h"
910
#include "taco/lower/iterator.h"
@@ -288,6 +289,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
288289
Assignment innerAssignment;
289290
vector<IndexVar> indexAccessVars;
290291
vector<IndexVar> indexVarsUntilBranch;
292+
vector<IndexVar> indexVarsAfterBranch;
291293
unsigned int pathIdx = 0;
292294
vector<int> path;
293295

@@ -301,6 +303,9 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
301303
if (pathIdx < path.size()) {
302304
indexVarsUntilBranch.push_back(forall.getIndexVar());
303305
}
306+
if (pathIdx >= path.size()) {
307+
indexVarsAfterBranch.push_back(forall.getIndexVar());
308+
}
304309

305310
if (isa<Assignment>(forall.getStmt())) {
306311
innerAssignment = to<Assignment>(forall.getStmt());
@@ -454,7 +459,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
454459

455460
// check if there are common outer loops in producerAccessOrder and consumerAccessOrder
456461
vector<IndexVar> commonLoopVars;
457-
for (auto& var : getAssignment.indexAccessVars) {
462+
for (auto& var : getAssignment.indexVarsAfterBranch) {
458463
auto itC = find(consumerLoopVars.begin(), consumerLoopVars.end(), var);
459464
auto itP = find(producerLoopVars.begin(), producerLoopVars.end(), var);
460465
if (itC != consumerLoopVars.end() && itP != producerLoopVars.end()) {

test/tests-workspaces.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,61 @@ TEST(workspaces, loopfuse) {
702702
ASSERT_TENSOR_EQ(expected, A);
703703
}
704704

705+
706+
707+
// Regression test: computes the four-operand contraction
//   A(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n)
// twice -- once with an explicit schedule (reorder, two loopfuse calls, then
// parallelize over l) and once by compiling the same expression with no
// explicit schedule -- and asserts that both results are equal.
TEST(workspaces, loopcontractfuse) {
708+
int N = 16;
709+
Tensor<double> A("A", {N, N, N}, Format{Dense, Dense, Dense});
710+
// B is the only operand declared with Sparse levels; the fill loop below
// still inserts a value at every (i, j, k) coordinate.
Tensor<double> B("B", {N, N, N}, Format{Dense, Sparse, Sparse});
711+
Tensor<double> C("C", {N, N}, Format{Dense, Dense});
712+
Tensor<double> D("D", {N, N}, Format{Dense, Dense});
713+
Tensor<double> E("E", {N, N}, Format{Dense, Dense});
714+
715+
// Fill the operands: B(i,j,k) = i, C(i,j) = j, D(i,j) = E(i,j) = i*j.
for (int i = 0; i < N; i++) {
716+
for (int j = 0; j < N; j++) {
717+
for (int k = 0; k < N; k++) {
718+
B.insert({i, j, k}, (double) i);
719+
}
720+
C.insert({i, j}, (double) j);
721+
// The cast binds to i only, so the product is evaluated as double(i) * j.
E.insert({i, j}, (double) i*j);
722+
D.insert({i, j}, (double) i*j);
723+
}
724+
}
725+
726+
// These IndexVars are declared after the fill loops, so they do not clash
// with the int loop counters of the same names (those are out of scope here).
IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n");
727+
A(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n);
728+
729+
IndexStmt stmt = A.getAssignment().concretize();
730+
731+
// Print the unscheduled concrete statement to stdout.
std::cout << stmt << endl;
732+
// path1 is empty and path2 is {1}; each is passed to one loopfuse call below.
// NOTE(review): the meaning of the (2, true, path) arguments is defined by
// loopfuse elsewhere and is not visible here -- confirm before changing them.
vector<int> path1;
733+
vector<int> path2 = {1};
734+
// Schedule: reorder the loops to l,i,m,j,k,n, then apply loopfuse twice.
stmt = stmt
735+
.reorder({l,i,m, j, k, n})
736+
.loopfuse(2, true, path1)
737+
.loopfuse(2, true, path2)
738+
;
739+
// Parallelize over l, declaring that there are no output races.
stmt = stmt
740+
.parallelize(l, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces)
741+
;
742+
743+
744+
stmt = stmt.concretize();
745+
cout << "final stmt: " << stmt << endl;
746+
// printCodeToFile is defined outside this view; presumably it writes the
// generated code for stmt to a file named after the first argument -- verify.
printCodeToFile("loopcontractfuse", stmt);
747+
748+
// Compile and run the scheduled statement. concretize() is called on stmt
// again here, although stmt was already concretized above.
A.compile(stmt.concretize());
749+
A.assemble();
750+
A.compute();
751+
752+
// Reference result: the same expression, compiled with no explicit schedule.
Tensor<double> expected("expected", {N, N, N}, Format{Dense, Dense, Dense});
753+
expected(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n);
754+
expected.compile();
755+
expected.assemble();
756+
expected.compute();
757+
ASSERT_TENSOR_EQ(expected, A);
758+
}
759+
705760
TEST(workspaces, precompute2D_mul) {
706761
int N = 16;
707762
Tensor<double> A("A", {N, N}, Format{Dense, Dense});

0 commit comments

Comments
 (0)