@@ -118,6 +118,7 @@ IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const {
118118 }
119119 })
120120 );
121+ cout << " currentOrdering: " << util::join (currentOrdering) << endl;
121122
122123 if (!content->pattern_ordered && currentOrdering == getreplacepattern ()) {
123124 taco_iassert (getreplacepattern ().size () == 2 );
@@ -294,10 +295,27 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
294295 using IndexNotationVisitor::visit;
295296 Assignment innerAssignment;
296297 vector<IndexVar> indexAccessVars;
298+ vector<IndexVar> indexVarsUntilBranch;
299+ unsigned int pathIdx = 0 ;
300+ vector<int > path;
301+
302+ // insert constructor with path
303+ GetAssignment (vector<int >& _path) : path(_path) {}
297304
298305 void visit (const ForallNode* node) {
299306 Forall forall (node);
307+ cout << " Forall: " << forall << endl;
308+ cout << " pathIdx: " << pathIdx << endl;
309+ // print path
310+ cout << " path: " ;
311+ for (const auto & p : path) {
312+ cout << p << " " << std::endl;
313+ }
314+ cout << endl;
300315 indexAccessVars.push_back (forall.getIndexVar ());
316+ if (pathIdx < path.size ()) {
317+ indexVarsUntilBranch.push_back (forall.getIndexVar ());
318+ }
301319
302320 if (isa<Assignment>(forall.getStmt ())) {
303321 innerAssignment = to<Assignment>(forall.getStmt ());
@@ -306,8 +324,22 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
306324 IndexNotationVisitor::visit (node);
307325 }
308326 }
327+
328+ void visit (const WhereNode* node) {
329+ Where where (node);
330+ cout << " Where: " << where << endl;
331+
332+ if (!path[pathIdx]) { // if path[pathIdx] == 0, go to the producer
333+ pathIdx++;
334+ IndexNotationVisitor::visit (node->producer );
335+ } else {
336+ pathIdx++;
337+ IndexNotationVisitor::visit (node->consumer );
338+ }
339+
340+ }
309341 };
310- GetAssignment getAssignment;
342+ GetAssignment getAssignment ( getPath ()) ;
311343 stmt.accept (&getAssignment);
312344
313345 std::cout << getAssignment.innerAssignment << std::endl;
@@ -322,7 +354,9 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
322354 struct GetProducerAndConsumer : public IndexNotationVisitor {
323355 using IndexNotationVisitor::visit;
324356 int pos;
357+ int pathIdx = 0 ;
325358 bool isProducerOnLeft;
359+ vector<int > path;
326360 IndexExpr result;
327361 IndexExpr producer;
328362 IndexExpr consumer;
@@ -332,7 +366,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
332366 map<IndexVar, pair<Type, Dimension>> varTypes;
333367 IndexExpr op;
334368
335- GetProducerAndConsumer (int _pos, int _isProducerOnLeft) : pos(_pos), isProducerOnLeft(_isProducerOnLeft), result(nullptr ), producer(nullptr ), consumer(nullptr ), varTypes({}) {}
369+ GetProducerAndConsumer (int _pos, int _isProducerOnLeft, vector< int >& _path ) : pos(_pos), isProducerOnLeft(_isProducerOnLeft), path(_path ), result(nullptr ), producer(nullptr ), consumer(nullptr ), varTypes({}) {}
336370
337371 void addIndexVar (Access access) {
338372 // get the dimension and type of each index variable in tensor
@@ -363,6 +397,20 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
363397 IndexNotationVisitor::visit (assignment.getRhs ());
364398 }
365399
400+ void visit (const WhereNode* node) {
401+ Where where (node);
402+ cout << " Where: " << where << endl;
403+
404+ // select the path to visit
405+ if (!path[pathIdx]) { // if path[pathIdx] == 0, go to the producer
406+ pathIdx++;
407+ IndexNotationVisitor::visit (node->producer );
408+ } else {
409+ pathIdx++;
410+ IndexNotationVisitor::visit (node->consumer );
411+ }
412+ }
413+
366414 // lhs is a multiplication in the tensor contraction
367415 void visit (const MulNode* node) {
368416 Mul mul (node);
@@ -390,7 +438,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
390438 pos--;
391439 }
392440 };
393- GetProducerAndConsumer getProducerAndConsumer (getPos (), getIsProducerOnLeft ());
441+ GetProducerAndConsumer getProducerAndConsumer (getPos (), getIsProducerOnLeft (), getPath () );
394442 stmt.accept (&getProducerAndConsumer);
395443
396444 std::cout << " result: " << getProducerAndConsumer.result << std::endl;
@@ -432,6 +480,19 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
432480 consumerLoopVars.push_back (var);
433481 }
434482 }
483+ cout << " producerLoopVars2: " ; printVector (producerLoopVars);
484+ cout << " consumerLoopVars2: " ; printVector (consumerLoopVars);
485+
486+ // remove indices from producerLoops and consumerLoops that are in getAssignment.indexVarsUntilBranch
487+ cout << " indexVarsUntilBranch: " ; printVector (getAssignment.indexVarsUntilBranch );
488+ for (auto & var : getAssignment.indexVarsUntilBranch ) {
489+ producerLoopVars.erase (remove (producerLoopVars.begin (), producerLoopVars.end (), var), producerLoopVars.end ());
490+ consumerLoopVars.erase (remove (consumerLoopVars.begin (), consumerLoopVars.end (), var), consumerLoopVars.end ());
491+ }
492+
493+ cout << " producerLoopVars3: " ; printVector (producerLoopVars);
494+ cout << " consumerLoopVars3: " ; printVector (consumerLoopVars);
495+
435496
436497 // check if there are common outer loops in producerAccessOrder and consumerAccessOrder
437498 vector<IndexVar> commonLoopVars;
@@ -446,16 +507,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
446507 break ;
447508 }
448509 }
449- // for (auto& var : producerLoopVars) {
450- // auto it = find(consumerLoopVars.begin(), consumerLoopVars.end(), var);
451- // if (it != consumerLoopVars.end()) {
452- // commonLoopVars.push_back(var);
453- // temporaryVars.erase(remove(temporaryVars.begin(), temporaryVars.end(), var), temporaryVars.end());
454- // }
455- // else {
456- // break;
457- // }
458- // }
459510 cout << " commonOuterLoops: " ; printVector (commonLoopVars);
460511 cout << " temporaryVars: " ; printVector (temporaryVars);
461512
@@ -479,7 +530,8 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
479530 }
480531 };
481532 populateDimension (getProducerAndConsumer.varTypes );
482- TensorVar intermediateTensor (" ws" , Type (Float64, temporaryDims));
533+ Access resultAccess = to<Access>(getProducerAndConsumer.result );
534+ TensorVar intermediateTensor (" t_" + resultAccess.getTensorVar ().getName (), Type (Float64, temporaryDims));
483535 Access workspace (intermediateTensor, temporaryVars);
484536 cout << " intermediateTensor: " << intermediateTensor << endl;
485537 cout << " workspace: " << workspace << endl;
@@ -503,15 +555,17 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
503555 // T(i,k) += B(i,j) * C(j,k) is the producer and A(i,j) += T(i,k) * D(k,l) is the consumer
504556 struct ProducerConsumerRewriter : public IndexNotationRewriter {
505557 using IndexNotationRewriter::visit;
558+ vector<int >& path;
559+ vector<int > visited;
506560 Assignment& producer;
507561 Assignment& consumer;
508562 vector<IndexVar>& commonLoopVars;
509563 vector<IndexVar>& producerLoopVars;
510564 vector<IndexVar>& consumerLoopVars;
511565
512566 // constructor
513- ProducerConsumerRewriter (Assignment& producer, Assignment& consumer, vector<IndexVar>& commonLoopVars, vector<IndexVar>& producerLoopVars, vector<IndexVar>& consumerLoopVars) :
514- producer (producer), consumer(consumer), commonLoopVars(commonLoopVars), producerLoopVars(producerLoopVars), consumerLoopVars(consumerLoopVars) {}
567+ ProducerConsumerRewriter (vector< int >& _path, Assignment& producer, Assignment& consumer, vector<IndexVar>& commonLoopVars, vector<IndexVar>& producerLoopVars, vector<IndexVar>& consumerLoopVars) :
568+ path (_path), producer(producer), consumer(consumer), commonLoopVars(commonLoopVars), producerLoopVars(producerLoopVars), consumerLoopVars(consumerLoopVars) {}
515569
516570 IndexStmt generateForalls (IndexStmt innerStmt, vector<IndexVar> indexVars) {
517571 auto returnStmt = innerStmt;
@@ -524,16 +578,38 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
524578
525579 // should find the path to get to this loop to perform the rewrite
526580 void visit (const ForallNode* node) {
527- IndexStmt consumer = generateForalls (this ->consumer , consumerLoopVars);
528- IndexStmt producer = generateForalls (this ->producer , producerLoopVars);
529- Where where (consumer, producer);
530- stmt = generateForalls (where, commonLoopVars);
531- return ;
581+ if (visited == path) {
582+ IndexStmt consumer = generateForalls (this ->consumer , consumerLoopVars);
583+ IndexStmt producer = generateForalls (this ->producer , producerLoopVars);
584+ Where where (consumer, producer);
585+ stmt = generateForalls (where, commonLoopVars);
586+ return ;
587+ }
588+ IndexNotationRewriter::visit (node);
589+ }
590+
591+ void visit (const WhereNode* node) {
592+ Where where (node);
593+ cout << " Where: " << where << endl;
594+
595+ visited.push_back (0 );
596+ IndexStmt producer = rewrite (node->producer );
597+ visited.pop_back ();
598+ visited.push_back (1 );
599+ IndexStmt consumer = rewrite (node->consumer );
600+ visited.pop_back ();
601+ if (producer == node->producer && consumer == node->consumer ) {
602+ stmt = node;
603+ }
604+ else {
605+ stmt = new WhereNode (consumer, producer);
606+ }
607+
532608 }
533609
534610 };
535611
536- ProducerConsumerRewriter rewriter (producerAssignment, consumerAssignment, commonLoopVars, producerLoopVars, consumerLoopVars);
612+ ProducerConsumerRewriter rewriter (getPath (), producerAssignment, consumerAssignment, commonLoopVars, producerLoopVars, consumerLoopVars);
537613 stmt = rewriter.rewrite (stmt);
538614 cout << " stmt: " << stmt << endl;
539615
0 commit comments