@@ -283,14 +283,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
283283 cout << endl;
284284 };
285285
286- cout << " pos: " << getPos () << std::endl;
287- cout << " isProducerOnLeft: " << getIsProducerOnLeft () << endl;
288- cout << " path: " ;
289- for (const auto & p : getPath ()) {
290- cout << p << " " << std::endl;
291- }
292- cout << endl;
293-
294286 struct GetAssignment : public IndexNotationVisitor {
295287 using IndexNotationVisitor::visit;
296288 Assignment innerAssignment;
@@ -304,14 +296,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
304296
305297 void visit (const ForallNode* node) {
306298 Forall forall (node);
307- cout << " Forall: " << forall << endl;
308- cout << " pathIdx: " << pathIdx << endl;
309- // print path
310- cout << " path: " ;
311- for (const auto & p : path) {
312- cout << p << " " << std::endl;
313- }
314- cout << endl;
299+
315300 indexAccessVars.push_back (forall.getIndexVar ());
316301 if (pathIdx < path.size ()) {
317302 indexVarsUntilBranch.push_back (forall.getIndexVar ());
@@ -327,7 +312,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
327312
328313 void visit (const WhereNode* node) {
329314 Where where (node);
330- cout << " Where: " << where << endl;
331315
332316 if (!path[pathIdx]) { // if path[pathIdx] == 0, go to the producer
333317 pathIdx++;
@@ -342,9 +326,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
342326 GetAssignment getAssignment (getPath ());
343327 stmt.accept (&getAssignment);
344328
345- std::cout << getAssignment.innerAssignment << std::endl;
346- cout << " Index access order: " ; printVector (getAssignment.indexAccessVars );
347-
348329 // saves the result, producer and consumer of the assignment
349330 // result = producer * consumer
350331 // eg: Assignment is A(i,j) += B(i,j) * C(j,k) * D(k,l)
@@ -364,7 +345,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
364345 set<IndexVar> producerVars;
365346 set<IndexVar> consumerVars;
366347 map<IndexVar, pair<Type, Dimension>> varTypes;
367- IndexExpr op;
368348
369349 GetProducerAndConsumer (int _pos, int _isProducerOnLeft, vector<int >& _path) : pos(_pos), isProducerOnLeft(_isProducerOnLeft), path(_path), result(nullptr ), producer(nullptr ), consumer(nullptr ), varTypes({}) {}
370350
@@ -383,12 +363,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
383363 // result is stored in the left hand side of the assignment
384364 result = assignment.getLhs ();
385365 resultVars = assignment.getLhs ().getIndexVars ();
386- std::cout << " result: " << result
387- << " , rhs: " << assignment.getRhs ()
388- << " , freeVars: " << assignment.getFreeVars ()
389- << " , indexVars: " << assignment.getIndexVars ()
390- << " , indexSetRelation: " << assignment.getIndexSetRel ()
391- << std::endl;
392366
393367 // add the index variables of the result to the map
394368 addIndexVar (to<Access>(assignment.getLhs ()));
@@ -420,7 +394,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
420394
421395 void visit (const AccessNode* node) {
422396 Access access (node);
423- cout << " pos: " << pos << " , access: " << access << endl;
424397 IndexExpr* it;
425398 set<IndexVar>* vars;
426399 if ((pos > 0 && isProducerOnLeft) || (pos <= 0 && !isProducerOnLeft)) { it = &producer; vars = &producerVars; }
@@ -441,13 +414,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
441414 GetProducerAndConsumer getProducerAndConsumer (getPos (), getIsProducerOnLeft (), getPath ());
442415 stmt.accept (&getProducerAndConsumer);
443416
444- std::cout << " result: " << getProducerAndConsumer.result << std::endl;
445- std::cout << " producer: " << getProducerAndConsumer.producer << std::endl;
446- std::cout << " consumer: " << getProducerAndConsumer.consumer << std::endl;
447- std::cout << " resultVars: " << getProducerAndConsumer.resultVars << std::endl;
448- cout << " producerVars: " ; printSet (getProducerAndConsumer.producerVars );
449- cout << " consumerVars: " ; printSet (getProducerAndConsumer.consumerVars );
450-
451417 // indices in the temporary comes from the producer indices (IndexVars)
452418 // that are either in result indices or in consumer indices
453419 // indices in the producer that are neither in producer indices nor in consumer indices
@@ -461,7 +427,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
461427 temporaryVars.push_back (var);
462428 }
463429 }
464- cout << " temporaryVars: " ; printVector (temporaryVars);
465430
466431 // get the producer index access pattern
467432 // get the consumer index access pattern
@@ -480,20 +445,13 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
480445 consumerLoopVars.push_back (var);
481446 }
482447 }
483- cout << " producerLoopVars2: " ; printVector (producerLoopVars);
484- cout << " consumerLoopVars2: " ; printVector (consumerLoopVars);
485448
486449 // remove indices from producerLoops and consumerLoops that are in getAssignment.indexVarsUntilBranch
487- cout << " indexVarsUntilBranch: " ; printVector (getAssignment.indexVarsUntilBranch );
488450 for (auto & var : getAssignment.indexVarsUntilBranch ) {
489451 producerLoopVars.erase (remove (producerLoopVars.begin (), producerLoopVars.end (), var), producerLoopVars.end ());
490452 consumerLoopVars.erase (remove (consumerLoopVars.begin (), consumerLoopVars.end (), var), consumerLoopVars.end ());
491453 }
492454
493- cout << " producerLoopVars3: " ; printVector (producerLoopVars);
494- cout << " consumerLoopVars3: " ; printVector (consumerLoopVars);
495-
496-
497455 // check if there are common outer loops in producerAccessOrder and consumerAccessOrder
498456 vector<IndexVar> commonLoopVars;
499457 for (auto & var : getAssignment.indexAccessVars ) {
@@ -507,16 +465,12 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
507465 break ;
508466 }
509467 }
510- cout << " commonOuterLoops: " ; printVector (commonLoopVars);
511- cout << " temporaryVars: " ; printVector (temporaryVars);
512468
513469 // remove commonLoopVars from producerLoopVars and consumerLoopVars
514470 for (auto & var : commonLoopVars) {
515471 producerLoopVars.erase (remove (producerLoopVars.begin (), producerLoopVars.end (), var), producerLoopVars.end ());
516472 consumerLoopVars.erase (remove (consumerLoopVars.begin (), consumerLoopVars.end (), var), consumerLoopVars.end ());
517473 }
518- cout << " producerLoopVars: " ; printVector (producerLoopVars);
519- cout << " consumerLoopVars: " ; printVector (consumerLoopVars);
520474
521475 // create the intermediate tensor
522476 vector<Dimension> temporaryDims;
@@ -533,22 +487,18 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
533487 Access resultAccess = to<Access>(getProducerAndConsumer.result );
534488 TensorVar intermediateTensor (" t_" + resultAccess.getTensorVar ().getName (), Type (Float64, temporaryDims));
535489 Access workspace (intermediateTensor, temporaryVars);
536- cout << " intermediateTensor: " << intermediateTensor << endl;
537- cout << " workspace: " << workspace << endl;
538490
539491 Assignment producerAssignment (workspace, getProducerAndConsumer.producer , getAssignment.innerAssignment .getOperator ());
540- cout << " producerAssignment: " << producerAssignment << endl;
541492
542493 Assignment consumerAssignment;
494+ // if the producer is on left, then consumer is constructed by
495+ // multiplying workspace * consumer and if the producer is on right,
496+ // then the consumer is constructed by multiplying consumer * workspace
543497 if (!getIsProducerOnLeft ()) {
544498 consumerAssignment = Assignment (to<Access>(getProducerAndConsumer.result ), getProducerAndConsumer.consumer * workspace, getAssignment.innerAssignment .getOperator ());
545499 } else {
546500 consumerAssignment = Assignment (to<Access>(getProducerAndConsumer.result ), workspace * getProducerAndConsumer.consumer , getAssignment.innerAssignment .getOperator ());
547501 }
548- cout << " consumerAssignment: " << consumerAssignment << endl;
549-
550- // check if there are common outer loops
551- // if there are common outer loops, then remove those common outer loops from the temporaryVars
552502
553503 // rewrite the index notation to use the temporary
554504 // eg: Assignment is A(i,j) += B(i,j) * C(j,k) * D(k,l)
@@ -578,6 +528,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
578528
579529 // should find the path to get to this loop to perform the rewrite
580530 void visit (const ForallNode* node) {
531+ // at the end of the path, rewrite should happen using the producer and consumer
581532 if (visited == path) {
582533 IndexStmt consumer = generateForalls (this ->consumer , consumerLoopVars);
583534 IndexStmt producer = generateForalls (this ->producer , producerLoopVars);
@@ -590,8 +541,8 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
590541
591542 void visit (const WhereNode* node) {
592543 Where where (node);
593- cout << " Where: " << where << endl;
594544
545+ // add 0 to visited if the producer is visited and 1 if the consumer is visited
595546 visited.push_back (0 );
596547 IndexStmt producer = rewrite (node->producer );
597548 visited.pop_back ();
@@ -604,14 +555,11 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
604555 else {
605556 stmt = new WhereNode (consumer, producer);
606557 }
607-
608558 }
609-
610559 };
611560
612561 ProducerConsumerRewriter rewriter (getPath (), producerAssignment, consumerAssignment, commonLoopVars, producerLoopVars, consumerLoopVars);
613562 stmt = rewriter.rewrite (stmt);
614- cout << " stmt: " << stmt << endl;
615563
616564 return stmt;
617565}
0 commit comments