Skip to content

Commit 6a31a08

Browse files
committed
add recursive functionality for kernel fusion
1 parent 99adfca commit 6a31a08

File tree

2 files changed

+156
-32
lines changed

2 files changed

+156
-32
lines changed

src/index_notation/transformations.cpp

Lines changed: 98 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const {
118118
}
119119
})
120120
);
121+
cout << "currentOrdering: " << util::join(currentOrdering) << endl;
121122

122123
if (!content->pattern_ordered && currentOrdering == getreplacepattern()) {
123124
taco_iassert(getreplacepattern().size() == 2);
@@ -294,10 +295,27 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
294295
using IndexNotationVisitor::visit;
295296
Assignment innerAssignment;
296297
vector<IndexVar> indexAccessVars;
298+
vector<IndexVar> indexVarsUntilBranch;
299+
unsigned int pathIdx = 0;
300+
vector<int> path;
301+
302+
// insert constructor with path
303+
GetAssignment(vector<int>& _path) : path(_path) {}
297304

298305
void visit(const ForallNode* node) {
299306
Forall forall(node);
307+
cout << "Forall: " << forall << endl;
308+
cout << "pathIdx: " << pathIdx << endl;
309+
// print path
310+
cout << "path: ";
311+
for (const auto& p : path) {
312+
cout << p << " " << std::endl;
313+
}
314+
cout << endl;
300315
indexAccessVars.push_back(forall.getIndexVar());
316+
if (pathIdx < path.size()) {
317+
indexVarsUntilBranch.push_back(forall.getIndexVar());
318+
}
301319

302320
if (isa<Assignment>(forall.getStmt())) {
303321
innerAssignment = to<Assignment>(forall.getStmt());
@@ -306,8 +324,22 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
306324
IndexNotationVisitor::visit(node);
307325
}
308326
}
327+
328+
// Descends into exactly one side of a where statement, selected by the next
// entry of `path`: 0 visits the producer, any non-zero value visits the
// consumer. The entry is consumed (pathIdx is advanced) before descending.
// NOTE(review): path[pathIdx] is indexed without a bounds check in this
// method; presumably `path` holds one entry per where on the route -- confirm.
void visit(const WhereNode* node) {
329+
Where where(node);
330+
// debug trace of the where statement being visited
cout << "Where: " << where << endl;
331+
332+
if (!path[pathIdx]) { // path[pathIdx] == 0: go to the producer
333+
pathIdx++;
334+
IndexNotationVisitor::visit(node->producer);
335+
} else { // non-zero entry: go to the consumer
336+
pathIdx++;
337+
IndexNotationVisitor::visit(node->consumer);
338+
}
339+
340+
}
309341
};
310-
GetAssignment getAssignment;
342+
GetAssignment getAssignment(getPath());
311343
stmt.accept(&getAssignment);
312344

313345
std::cout << getAssignment.innerAssignment << std::endl;
@@ -322,7 +354,9 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
322354
struct GetProducerAndConsumer : public IndexNotationVisitor {
323355
using IndexNotationVisitor::visit;
324356
int pos;
357+
int pathIdx = 0;
325358
bool isProducerOnLeft;
359+
vector<int> path;
326360
IndexExpr result;
327361
IndexExpr producer;
328362
IndexExpr consumer;
@@ -332,7 +366,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
332366
map<IndexVar, pair<Type, Dimension>> varTypes;
333367
IndexExpr op;
334368

335-
GetProducerAndConsumer(int _pos, int _isProducerOnLeft) : pos(_pos), isProducerOnLeft(_isProducerOnLeft), result(nullptr), producer(nullptr), consumer(nullptr), varTypes({}) {}
369+
// Stores _pos and _isProducerOnLeft, copies _path into the `path` member, and
// initialises result, producer and consumer from nullptr and varTypes as empty.
// NOTE(review): _isProducerOnLeft is declared int here although the
// isProducerOnLeft member it initialises is bool.
GetProducerAndConsumer(int _pos, int _isProducerOnLeft, vector<int>& _path) : pos(_pos), isProducerOnLeft(_isProducerOnLeft), path(_path), result(nullptr), producer(nullptr), consumer(nullptr), varTypes({}) {}
336370

337371
void addIndexVar(Access access) {
338372
// get the dimension and type of each index variable in tensor
@@ -363,6 +397,20 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
363397
IndexNotationVisitor::visit(assignment.getRhs());
364398
}
365399

400+
// Descends into exactly one side of a where statement, selected by the next
// entry of `path`: 0 visits the producer, any non-zero value visits the
// consumer. The entry is consumed (pathIdx is advanced) before descending.
// NOTE(review): path[pathIdx] is indexed without a bounds check in this
// method; presumably `path` holds one entry per where on the route -- confirm.
void visit(const WhereNode* node) {
401+
Where where(node);
402+
// debug trace of the where statement being visited
cout << "Where: " << where << endl;
403+
404+
// select which side of the where to visit
405+
if (!path[pathIdx]) { // path[pathIdx] == 0: go to the producer
406+
pathIdx++;
407+
IndexNotationVisitor::visit(node->producer);
408+
} else { // non-zero entry: go to the consumer
409+
pathIdx++;
410+
IndexNotationVisitor::visit(node->consumer);
411+
}
412+
}
413+
366414
// lhs is a multiplication in the tensor contraction
367415
void visit(const MulNode* node) {
368416
Mul mul(node);
@@ -390,7 +438,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
390438
pos--;
391439
}
392440
};
393-
GetProducerAndConsumer getProducerAndConsumer(getPos(), getIsProducerOnLeft());
441+
GetProducerAndConsumer getProducerAndConsumer(getPos(), getIsProducerOnLeft(), getPath());
394442
stmt.accept(&getProducerAndConsumer);
395443

396444
std::cout << "result: " << getProducerAndConsumer.result << std::endl;
@@ -432,6 +480,19 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
432480
consumerLoopVars.push_back(var);
433481
}
434482
}
483+
cout << "producerLoopVars2: "; printVector(producerLoopVars);
484+
cout << "consumerLoopVars2: "; printVector(consumerLoopVars);
485+
486+
// remove indices from producerLoops and consumerLoops that are in getAssignment.indexVarsUntilBranch
487+
cout << "indexVarsUntilBranch: "; printVector(getAssignment.indexVarsUntilBranch);
488+
for (auto& var : getAssignment.indexVarsUntilBranch) {
489+
producerLoopVars.erase(remove(producerLoopVars.begin(), producerLoopVars.end(), var), producerLoopVars.end());
490+
consumerLoopVars.erase(remove(consumerLoopVars.begin(), consumerLoopVars.end(), var), consumerLoopVars.end());
491+
}
492+
493+
cout << "producerLoopVars3: "; printVector(producerLoopVars);
494+
cout << "consumerLoopVars3: "; printVector(consumerLoopVars);
495+
435496

436497
// check if there are common outer loops in producerAccessOrder and consumerAccessOrder
437498
vector<IndexVar> commonLoopVars;
@@ -446,16 +507,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
446507
break;
447508
}
448509
}
449-
// for (auto& var : producerLoopVars) {
450-
// auto it = find(consumerLoopVars.begin(), consumerLoopVars.end(), var);
451-
// if (it != consumerLoopVars.end()) {
452-
// commonLoopVars.push_back(var);
453-
// temporaryVars.erase(remove(temporaryVars.begin(), temporaryVars.end(), var), temporaryVars.end());
454-
// }
455-
// else {
456-
// break;
457-
// }
458-
// }
459510
cout << "commonOuterLoops: "; printVector(commonLoopVars);
460511
cout << "temporaryVars: "; printVector(temporaryVars);
461512

@@ -479,7 +530,8 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
479530
}
480531
};
481532
populateDimension(getProducerAndConsumer.varTypes);
482-
TensorVar intermediateTensor("ws", Type(Float64, temporaryDims));
533+
Access resultAccess = to<Access>(getProducerAndConsumer.result);
534+
TensorVar intermediateTensor("t_" + resultAccess.getTensorVar().getName(), Type(Float64, temporaryDims));
483535
Access workspace(intermediateTensor, temporaryVars);
484536
cout << "intermediateTensor: " << intermediateTensor << endl;
485537
cout << "workspace: " << workspace << endl;
@@ -503,15 +555,17 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
503555
// T(i,k) += B(i,j) * C(j,k) is the producer and A(i,j) += T(i,k) * D(k,l) is the consumer
504556
struct ProducerConsumerRewriter : public IndexNotationRewriter {
505557
using IndexNotationRewriter::visit;
558+
vector<int>& path;
559+
vector<int> visited;
506560
Assignment& producer;
507561
Assignment& consumer;
508562
vector<IndexVar>& commonLoopVars;
509563
vector<IndexVar>& producerLoopVars;
510564
vector<IndexVar>& consumerLoopVars;
511565

512566
// constructor
513-
ProducerConsumerRewriter(Assignment& producer, Assignment& consumer, vector<IndexVar>& commonLoopVars, vector<IndexVar>& producerLoopVars, vector<IndexVar>& consumerLoopVars) :
514-
producer(producer), consumer(consumer), commonLoopVars(commonLoopVars), producerLoopVars(producerLoopVars), consumerLoopVars(consumerLoopVars) {}
567+
// Binds every argument to a reference member (path, producer, consumer and the
// three loop-variable vectors are all declared as references), so the objects
// passed in must stay alive for as long as the rewriter is used. `visited` is
// not listed in the initialiser list and is therefore default-constructed.
ProducerConsumerRewriter(vector<int>& _path, Assignment& producer, Assignment& consumer, vector<IndexVar>& commonLoopVars, vector<IndexVar>& producerLoopVars, vector<IndexVar>& consumerLoopVars) :
568+
path(_path), producer(producer), consumer(consumer), commonLoopVars(commonLoopVars), producerLoopVars(producerLoopVars), consumerLoopVars(consumerLoopVars) {}
515569

516570
IndexStmt generateForalls(IndexStmt innerStmt, vector<IndexVar> indexVars) {
517571
auto returnStmt = innerStmt;
@@ -524,16 +578,38 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
524578

525579
// should find the path to get to this loop to perform the rewrite
526580
// Rewrites the forall reached via the requested where-branch route.
// Added lines (marked '+'): when the branch choices recorded in `visited`
// equal `path`, the result is where(consumer, producer) -- each side wrapped
// in foralls over its own loop variables -- wrapped in foralls over
// commonLoopVars; otherwise the default rewriter traversal continues below
// this forall.
// Removed lines (marked '-'): the same replacement was performed
// unconditionally on the first forall visited.
void visit(const ForallNode* node) {
527-
IndexStmt consumer = generateForalls(this->consumer, consumerLoopVars);
528-
IndexStmt producer = generateForalls(this->producer, producerLoopVars);
529-
Where where(consumer, producer);
530-
stmt = generateForalls(where, commonLoopVars);
531-
return;
581+
// only rewrite when the branches taken so far match the requested path
if (visited == path) {
582+
IndexStmt consumer = generateForalls(this->consumer, consumerLoopVars);
583+
IndexStmt producer = generateForalls(this->producer, producerLoopVars);
584+
Where where(consumer, producer);
585+
stmt = generateForalls(where, commonLoopVars);
586+
return;
587+
}
588+
// not the target forall: fall back to the default traversal
IndexNotationRewriter::visit(node);
589+
}
590+
591+
// Rewrites both sides of a where statement while recording in `visited`
// which side is being traversed: 0 is pushed while the producer is rewritten,
// 1 while the consumer is rewritten, and the entry is popped again afterwards.
// The original node is reused when neither side changed; otherwise a new
// WhereNode is built from the rewritten consumer and producer.
void visit(const WhereNode* node) {
592+
Where where(node);
593+
// debug trace of the where statement being visited
cout << "Where: " << where << endl;
594+
595+
// producer side is recorded as branch 0
visited.push_back(0);
596+
IndexStmt producer = rewrite(node->producer);
597+
visited.pop_back();
598+
// consumer side is recorded as branch 1
visited.push_back(1);
599+
IndexStmt consumer = rewrite(node->consumer);
600+
visited.pop_back();
601+
// nothing changed on either side: keep the existing node
if (producer == node->producer && consumer == node->consumer) {
602+
stmt = node;
603+
}
604+
else {
605+
stmt = new WhereNode(consumer, producer);
606+
}
607+
532608
}
533609

534610
};
535611

536-
ProducerConsumerRewriter rewriter(producerAssignment, consumerAssignment, commonLoopVars, producerLoopVars, consumerLoopVars);
612+
ProducerConsumerRewriter rewriter(getPath(), producerAssignment, consumerAssignment, commonLoopVars, producerLoopVars, consumerLoopVars);
537613
stmt = rewriter.rewrite(stmt);
538614
cout << "stmt: " << stmt << endl;
539615

test/tests-workspaces.cpp

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -668,36 +668,38 @@ TEST(workspaces, loopfuse) {
668668
}
669669

670670
IndexVar i("i"), j("j"), k("k"), l("l"), m("m");
671-
IndexExpr precomputedExpr = B(i,j) * C(j,k);
672-
IndexExpr precomputedExpr2 = precomputedExpr * D(k,l);
673-
// A(i,l) = precomputedExpr2;
674671
A(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m);
675672

676673
IndexStmt stmt = A.getAssignment().concretize();
677674
TensorVar ws("ws", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense});
678675
TensorVar t("t", Type(Float64, {(size_t)N, (size_t)N}), Format{Dense, Dense});
679676

680677
std::cout << stmt << endl;
681-
vector<int> path;
678+
vector<int> path1;
679+
vector<int> path2 = {0};
682680
stmt = stmt
683-
// .reorder({i,j,k,l})
684681
.reorder({i,j,k, l, m})
685-
.loopfuse(2, true, path)
682+
.loopfuse(3, true, path1)
683+
.loopfuse(2, true, path2)
684+
;
685+
stmt = stmt
686686
.parallelize(i, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces)
687687
;
688688

689689
stmt = stmt.concretize();
690690
cout << "final stmt: " << stmt << endl;
691-
692-
std::cout << "stmt: " << stmt << std::endl;
693691
printCodeToFile("loopfuse", stmt);
694692

695693
A.compile(stmt.concretize());
696694
A.assemble();
697695
A.compute();
698696

699-
return;
700-
697+
Tensor<double> expected("expected", {N, N}, Format{Dense, Dense});
698+
expected(i,m) = B(i,j) * C(j,k) * D(k,l) * E(l,m);
699+
expected.compile();
700+
expected.assemble();
701+
expected.compute();
702+
ASSERT_TENSOR_EQ(expected, A);
701703
}
702704

703705
TEST(workspaces, precompute2D_mul) {
@@ -962,3 +964,49 @@ TEST(workspaces, precompute_tensorContraction2) {
962964
}
963965

964966

967+
968+
// Builds the four-deep loop nest for A(i,l) += B(i,j) * C(i,k) * D(j,k) * E(j,l)
// over 3x3 tensors (B is Dense,Sparse; the others Dense,Dense), reorders the
// loops, and prints the statement at each step.
// NOTE(review): the test contains no assertions and the precompute/compile
// steps are commented out, so it currently only exercises the reordering and
// the printing.
TEST(workspaces, sddmmPlusSpmm) {
969+
Type t(type<double>(), {3,3});
970+
const IndexVar i("i"), j("j"), k("k"), l("l");
971+
972+
TensorVar A("A", t, Format{Dense, Dense});
973+
TensorVar B("B", t, Format{Dense, Sparse});
974+
TensorVar C("C", t, Format{Dense, Dense});
975+
TensorVar D("D", t, Format{Dense, Dense});
976+
TensorVar E("E", t, Format{Dense, Dense});
977+
978+
// only referenced by the commented-out precompute / fusedNested code below
TensorVar tmp("tmp", Type(), Format());
979+
980+
// A(i,l) += B(i,j) * C(i,k) * D(j,k) * E(j,l)
981+
IndexStmt fused =
982+
forall(i,
983+
forall(j,
984+
forall(k,
985+
forall(l, A(i,l) += B(i,j) * C(i,k) * D(j,k) * E(j,l))
986+
)
987+
)
988+
);
989+
990+
std::cout << "before topological sort: " << fused << std::endl;
991+
fused = reorderLoopsTopologically(fused);
992+
// std::vector<IndexVar> order{"i", "j", "k", "l"};
993+
fused = fused.reorder({i, j, k, l});
994+
std::cout << "after topological sort: " << fused << std::endl;
995+
996+
// fused = fused.precompute(B(i,j) * C(i,k) * D(j,k), {}, {}, tmp);
997+
// NOTE(review): with the precompute call above commented out, this prints the
// same statement as the "after topological sort" line.
std::cout << "after precompute: " << fused << std::endl;
998+
999+
// Kernel kernel = compile(fused);
1000+
1001+
// IndexStmt fusedNested =
1002+
// forall(i,
1003+
// forall(j,
1004+
// where(
1005+
// forall(l, A(i,l) += tmp * E(j,l)), // consumer
1006+
// forall(k, tmp += B(i,j) * C(i,k) * D(j,k)) // producer
1007+
// )
1008+
// )
1009+
// );
1010+
1011+
// std::cout << "nested loop stmt: " << fusedNested << std::endl;
1012+
}

0 commit comments

Comments
 (0)