Skip to content

Commit 2479221

Browse files
committed
remove comments
1 parent 6a31a08 commit 2479221

File tree

1 file changed

+6
-58
lines changed

1 file changed

+6
-58
lines changed

src/index_notation/transformations.cpp

Lines changed: 6 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -283,14 +283,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
283283
cout << endl;
284284
};
285285

286-
cout << "pos: " << getPos() << std::endl;
287-
cout << "isProducerOnLeft: " << getIsProducerOnLeft() << endl;
288-
cout << "path: ";
289-
for (const auto& p : getPath()) {
290-
cout << p << " " << std::endl;
291-
}
292-
cout << endl;
293-
294286
struct GetAssignment : public IndexNotationVisitor {
295287
using IndexNotationVisitor::visit;
296288
Assignment innerAssignment;
@@ -304,14 +296,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
304296

305297
void visit(const ForallNode* node) {
306298
Forall forall(node);
307-
cout << "Forall: " << forall << endl;
308-
cout << "pathIdx: " << pathIdx << endl;
309-
// print path
310-
cout << "path: ";
311-
for (const auto& p : path) {
312-
cout << p << " " << std::endl;
313-
}
314-
cout << endl;
299+
315300
indexAccessVars.push_back(forall.getIndexVar());
316301
if (pathIdx < path.size()) {
317302
indexVarsUntilBranch.push_back(forall.getIndexVar());
@@ -327,7 +312,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
327312

328313
void visit(const WhereNode* node) {
329314
Where where(node);
330-
cout << "Where: " << where << endl;
331315

332316
if (!path[pathIdx]) { // if path[pathIdx] == 0, go to the producer
333317
pathIdx++;
@@ -342,9 +326,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
342326
GetAssignment getAssignment(getPath());
343327
stmt.accept(&getAssignment);
344328

345-
std::cout << getAssignment.innerAssignment << std::endl;
346-
cout << "Index access order: "; printVector(getAssignment.indexAccessVars);
347-
348329
// saves the result, producer and consumer of the assignment
349330
// result = producer * consumer
350331
// eg: Assignment is A(i,j) += B(i,j) * C(j,k) * D(k,l)
@@ -364,7 +345,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
364345
set<IndexVar> producerVars;
365346
set<IndexVar> consumerVars;
366347
map<IndexVar, pair<Type, Dimension>> varTypes;
367-
IndexExpr op;
368348

369349
GetProducerAndConsumer(int _pos, int _isProducerOnLeft, vector<int>& _path) : pos(_pos), isProducerOnLeft(_isProducerOnLeft), path(_path), result(nullptr), producer(nullptr), consumer(nullptr), varTypes({}) {}
370350

@@ -383,12 +363,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
383363
// result is stored in the left hand side of the assignment
384364
result = assignment.getLhs();
385365
resultVars = assignment.getLhs().getIndexVars();
386-
std::cout << "result: " << result
387-
<< ", rhs: " << assignment.getRhs()
388-
<< ", freeVars: " << assignment.getFreeVars()
389-
<< ", indexVars: " << assignment.getIndexVars()
390-
<< ", indexSetRelation: " << assignment.getIndexSetRel()
391-
<< std::endl;
392366

393367
// add the index variables of the result to the map
394368
addIndexVar(to<Access>(assignment.getLhs()));
@@ -420,7 +394,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
420394

421395
void visit(const AccessNode* node) {
422396
Access access(node);
423-
cout << "pos: " << pos << ", access: " << access << endl;
424397
IndexExpr* it;
425398
set<IndexVar>* vars;
426399
if ((pos > 0 && isProducerOnLeft) || (pos <= 0 && !isProducerOnLeft)) { it = &producer; vars = &producerVars; }
@@ -441,13 +414,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
441414
GetProducerAndConsumer getProducerAndConsumer(getPos(), getIsProducerOnLeft(), getPath());
442415
stmt.accept(&getProducerAndConsumer);
443416

444-
std::cout << "result: " << getProducerAndConsumer.result << std::endl;
445-
std::cout << "producer: " << getProducerAndConsumer.producer << std::endl;
446-
std::cout << "consumer: " << getProducerAndConsumer.consumer << std::endl;
447-
std::cout << "resultVars: " << getProducerAndConsumer.resultVars << std::endl;
448-
cout << "producerVars: "; printSet(getProducerAndConsumer.producerVars);
449-
cout << "consumerVars: "; printSet(getProducerAndConsumer.consumerVars);
450-
451417
// indices in the temporary comes from the producer indices (IndexVars)
452418
// that are either in result indices or in consumer indices
453419
// indices in the producer that are neither in producer indices nor in consumer indices
@@ -461,7 +427,6 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
461427
temporaryVars.push_back(var);
462428
}
463429
}
464-
cout << "temporaryVars: "; printVector(temporaryVars);
465430

466431
// get the producer index access pattern
467432
// get the consumer index access pattern
@@ -480,20 +445,13 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
480445
consumerLoopVars.push_back(var);
481446
}
482447
}
483-
cout << "producerLoopVars2: "; printVector(producerLoopVars);
484-
cout << "consumerLoopVars2: "; printVector(consumerLoopVars);
485448

486449
// remove indices from producerLoops and consumerLoops that are in getAssignment.indexVarsUntilBranch
487-
cout << "indexVarsUntilBranch: "; printVector(getAssignment.indexVarsUntilBranch);
488450
for (auto& var : getAssignment.indexVarsUntilBranch) {
489451
producerLoopVars.erase(remove(producerLoopVars.begin(), producerLoopVars.end(), var), producerLoopVars.end());
490452
consumerLoopVars.erase(remove(consumerLoopVars.begin(), consumerLoopVars.end(), var), consumerLoopVars.end());
491453
}
492454

493-
cout << "producerLoopVars3: "; printVector(producerLoopVars);
494-
cout << "consumerLoopVars3: "; printVector(consumerLoopVars);
495-
496-
497455
// check if there are common outer loops in producerAccessOrder and consumerAccessOrder
498456
vector<IndexVar> commonLoopVars;
499457
for (auto& var : getAssignment.indexAccessVars) {
@@ -507,16 +465,12 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
507465
break;
508466
}
509467
}
510-
cout << "commonOuterLoops: "; printVector(commonLoopVars);
511-
cout << "temporaryVars: "; printVector(temporaryVars);
512468

513469
// remove commonLoopVars from producerLoopVars and consumerLoopVars
514470
for (auto& var : commonLoopVars) {
515471
producerLoopVars.erase(remove(producerLoopVars.begin(), producerLoopVars.end(), var), producerLoopVars.end());
516472
consumerLoopVars.erase(remove(consumerLoopVars.begin(), consumerLoopVars.end(), var), consumerLoopVars.end());
517473
}
518-
cout << "producerLoopVars: "; printVector(producerLoopVars);
519-
cout << "consumerLoopVars: "; printVector(consumerLoopVars);
520474

521475
// create the intermediate tensor
522476
vector<Dimension> temporaryDims;
@@ -533,22 +487,18 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
533487
Access resultAccess = to<Access>(getProducerAndConsumer.result);
534488
TensorVar intermediateTensor("t_" + resultAccess.getTensorVar().getName(), Type(Float64, temporaryDims));
535489
Access workspace(intermediateTensor, temporaryVars);
536-
cout << "intermediateTensor: " << intermediateTensor << endl;
537-
cout << "workspace: " << workspace << endl;
538490

539491
Assignment producerAssignment(workspace, getProducerAndConsumer.producer, getAssignment.innerAssignment.getOperator());
540-
cout << "producerAssignment: " << producerAssignment << endl;
541492

542493
Assignment consumerAssignment;
494+
// if the producer is on left, then consumer is constructed by
495+
// multiplying workspace * consumer and if the producer is on right,
496+
// then the consumer is constructed by multiplying consumer * workspace
543497
if (!getIsProducerOnLeft()) {
544498
consumerAssignment = Assignment(to<Access>(getProducerAndConsumer.result), getProducerAndConsumer.consumer * workspace, getAssignment.innerAssignment.getOperator());
545499
} else {
546500
consumerAssignment = Assignment(to<Access>(getProducerAndConsumer.result), workspace * getProducerAndConsumer.consumer, getAssignment.innerAssignment.getOperator());
547501
}
548-
cout << "consumerAssignment: " << consumerAssignment << endl;
549-
550-
// check if there are common outer loops
551-
// if there are common outer loops, then remove those common outer loops from the temporaryVars
552502

553503
// rewrite the index notation to use the temporary
554504
// eg: Assignment is A(i,j) += B(i,j) * C(j,k) * D(k,l)
@@ -578,6 +528,7 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
578528

579529
// should find the path to get to this loop to perform the rewrite
580530
void visit(const ForallNode* node) {
531+
// at the end of the path, rewrite should happen using the producer and consumer
581532
if (visited == path) {
582533
IndexStmt consumer = generateForalls(this->consumer, consumerLoopVars);
583534
IndexStmt producer = generateForalls(this->producer, producerLoopVars);
@@ -590,8 +541,8 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
590541

591542
void visit(const WhereNode* node) {
592543
Where where(node);
593-
cout << "Where: " << where << endl;
594544

545+
// add 0 to visited if the producer is visited and 1 if the consumer is visited
595546
visited.push_back(0);
596547
IndexStmt producer = rewrite(node->producer);
597548
visited.pop_back();
@@ -604,14 +555,11 @@ IndexStmt LoopFuse::apply(IndexStmt stmt, std::string* reason) const {
604555
else {
605556
stmt = new WhereNode(consumer, producer);
606557
}
607-
608558
}
609-
610559
};
611560

612561
ProducerConsumerRewriter rewriter(getPath(), producerAssignment, consumerAssignment, commonLoopVars, producerLoopVars, consumerLoopVars);
613562
stmt = rewriter.rewrite(stmt);
614-
cout << "stmt: " << stmt << endl;
615563

616564
return stmt;
617565
}

0 commit comments

Comments
 (0)