@@ -64,6 +64,7 @@ std::ostream& operator<<(std::ostream& os, const Transformation& t) {
6464
6565// class Reorder
6666struct Reorder ::Content {
67+ std::vector<int > path;
6768 std::vector<IndexVar> replacePattern;
6869 bool pattern_ordered; // In case of Reorder(i, j) need to change replacePattern ordering to actually reorder
6970};
@@ -78,6 +79,12 @@ Reorder::Reorder(std::vector<taco::IndexVar> replacePattern) : content(new Conte
7879 content->pattern_ordered = true ;
7980}
8081
82+ Reorder::Reorder (std::vector<int > path, std::vector<taco::IndexVar> replacePattern) : content(new Content) {
83+ content->path = path;
84+ content->replacePattern = replacePattern;
85+ content->pattern_ordered = true ;
86+ }
87+
8188IndexVar Reorder::geti () const {
8289 return content->replacePattern [0 ];
8390}
@@ -93,13 +100,66 @@ const std::vector<IndexVar>& Reorder::getreplacepattern() const {
93100 return content->replacePattern ;
94101}
95102
103+ const std::vector<int >& Reorder::getpath () const {
104+ return content->path ;
105+ }
106+
96107IndexStmt Reorder::apply (IndexStmt stmt, string* reason) const {
97108 INIT_REASON (reason);
98109
99110 string r;
100- if (!isConcreteNotation (stmt, &r)) {
101- *reason = " The index statement is not valid concrete index notation: " + r;
102- return IndexStmt ();
111+
112+ // TODO - Add a different check for concrete index notation with branching
113+ // if (!isConcreteNotation(stmt, &r)) {
114+ // *reason = "The index statement is not valid concrete index notation: " + r;
115+ // return IndexStmt();
116+ // }
117+
118+ IndexStmt originalStmt = stmt;
119+ struct ReorderVisitor : public IndexNotationVisitor {
120+ using IndexNotationVisitor::visit;
121+ vector<int >& path;
122+ unsigned int pathIdx = 0 ;
123+ IndexStmt innerStmt;
124+
125+ ReorderVisitor (vector<int >& path) : path(path) {}
126+
127+ void visit (const ForallNode* node) {
128+ if (pathIdx == path.size ()) {
129+ innerStmt = IndexStmt (node);
130+ return ;
131+ }
132+ IndexNotationVisitor::visit (node);
133+ }
134+
135+ void visit (const WhereNode* node) {
136+
137+ Where where (node);
138+
139+ if (pathIdx == path.size ()) {
140+ innerStmt = IndexStmt (node);
141+ return ;
142+ }
143+
144+ if (!path[pathIdx]) {
145+ pathIdx++;
146+ IndexNotationVisitor::visit (node->producer );
147+ } else {
148+ pathIdx++;
149+ IndexNotationVisitor::visit (node->consumer );
150+ }
151+ }
152+ };
153+
154+ cout << " original statement: " << originalStmt << endl;
155+ ReorderVisitor reorderVisitor (content->path );
156+
157+ auto p = getpath ();
158+ cout << " path: " << util::join (p) << endl;
159+ if (p.size () > 0 ) {
160+ originalStmt.accept (&reorderVisitor);
161+ cout << " reordering statment: " << reorderVisitor.innerStmt << endl;
162+ stmt = reorderVisitor.innerStmt ;
103163 }
104164
105165 // collect current ordering of IndexVars
@@ -130,7 +190,52 @@ IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const {
130190 *reason = " The foralls of reorder pattern: " + util::join (getreplacepattern ()) + " were not directly nested." ;
131191 return IndexStmt ();
132192 }
133- return ForAllReplace (currentOrdering, getreplacepattern ()).apply (stmt, reason);
193+
194+ cout << " replacePattern: " << util::join (getreplacepattern ()) << endl;
195+ auto reorderedStmt = ForAllReplace (currentOrdering, getreplacepattern ()).apply (stmt, reason);
196+
197+
198+ struct ReorderedRewriter : public IndexNotationRewriter {
199+ using IndexNotationRewriter::visit;
200+
201+ IndexStmt reorderedStmt;
202+ vector<int >& path;
203+ vector<int > visited;
204+
205+ ReorderedRewriter (IndexStmt reorderedStmt, vector<int >& path) : reorderedStmt(reorderedStmt), path(path) {}
206+
207+ void visit (const ForallNode* node) {
208+ // at the end of the path, rewrite should happen using the producer and consumer
209+ if (visited == path) {
210+ stmt = reorderedStmt;
211+ return ;
212+ }
213+ IndexNotationRewriter::visit (node);
214+ }
215+
216+ void visit (const WhereNode* node) {
217+ Where where (node);
218+
219+ // add 0 to visited if the producer is visited and 1 if the consumer is visited
220+ visited.push_back (0 );
221+ IndexStmt producer = rewrite (node->producer );
222+ visited.pop_back ();
223+ visited.push_back (1 );
224+ IndexStmt consumer = rewrite (node->consumer );
225+ visited.pop_back ();
226+ if (producer == node->producer && consumer == node->consumer ) {
227+ stmt = node;
228+ }
229+ else {
230+ stmt = new WhereNode (consumer, producer);
231+ }
232+
233+ }
234+ };
235+ ReorderedRewriter reorderedRewriter (reorderedStmt, content->path );
236+ stmt = reorderedRewriter.rewrite (originalStmt);
237+
238+ return stmt;
134239}
135240
136241void Reorder::print (std::ostream& os) const {
@@ -1068,10 +1173,12 @@ IndexStmt ForAllReplace::apply(IndexStmt stmt, string* reason) const {
10681173 INIT_REASON (reason);
10691174
10701175 string r;
1071- if (!isConcreteNotation (stmt, &r)) {
1072- *reason = " The index statement is not valid concrete index notation: " + r;
1073- return IndexStmt ();
1074- }
1176+
1177+ // TODO - Add a different check for concrete index notation with branching
1178+ // if (!isConcreteNotation(stmt, &r)) {
1179+ // *reason = "The index statement is not valid concrete index notation: " + r;
1180+ // return IndexStmt();
1181+ // }
10751182
10761183 // / Since all IndexVars can only appear once, assume replacement will work and error if it doesn't
10771184 struct ForAllReplaceRewriter : public IndexNotationRewriter {
0 commit comments