Skip to content

Commit c3caf52

Browse files
committed
add reordering for index stmt with branches
1 parent 672facf commit c3caf52

File tree

5 files changed

+185
-8
lines changed

5 files changed

+185
-8
lines changed

include/taco/index_notation/index_notation.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,9 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
668668
/// reorder takes a new ordering for a set of index variables that are directly nested in the iteration order
669669
IndexStmt reorder(std::vector<IndexVar> reorderedvars) const;
670670

671+
/// reorders the index variables in a nested structure with where clauses
672+
IndexStmt reorder(std::vector<int> path, std::vector<IndexVar> reorderedvars) const;
673+
671674
/// The mergeby transformation specifies how to merge iterators on
672675
/// the given index variable. By default, if an iterator is used for windowing
673676
/// it will be merged with the "gallop" strategy.

include/taco/index_notation/transformations.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,12 @@ class Reorder : public TransformationInterface {
6767
public:
6868
Reorder(IndexVar i, IndexVar j);
6969
Reorder(std::vector<IndexVar> replacePattern);
70+
Reorder(std::vector<int> path, std::vector<IndexVar> replacePattern);
7071

7172
IndexVar geti() const;
7273
IndexVar getj() const;
7374
const std::vector<IndexVar>& getreplacepattern() const;
75+
const std::vector<int>& getpath() const;
7476

7577
/// Apply the reorder optimization to a concrete index statement. Returns
7678
/// an undefined statement and a reason if the statement cannot be lowered.

src/index_notation/index_notation.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1927,6 +1927,17 @@ IndexStmt IndexStmt::reorder(std::vector<IndexVar> reorderedvars) const {
19271927
return transformed;
19281928
}
19291929

1930+
IndexStmt IndexStmt::reorder(std::vector<int> path, std::vector<IndexVar> reorderedvars) const {
1931+
string reason;
1932+
cout << "Index statement path: " << util::join(path) << endl;
1933+
cout << "Index statement reorderedvars: " << reorderedvars << endl;
1934+
IndexStmt transformed = Reorder(path, reorderedvars).apply(*this, &reason);
1935+
if (!transformed.defined()) {
1936+
taco_uerror << reason;
1937+
}
1938+
return transformed;
1939+
}
1940+
19301941
IndexStmt IndexStmt::mergeby(IndexVar i, MergeStrategy strategy) const {
19311942
string reason;
19321943
IndexStmt transformed = SetMergeStrategy(i, strategy).apply(*this, &reason);

src/index_notation/transformations.cpp

Lines changed: 115 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ std::ostream& operator<<(std::ostream& os, const Transformation& t) {
6464

6565
// class Reorder
6666
struct Reorder::Content {
67+
std::vector<int> path;
6768
std::vector<IndexVar> replacePattern;
6869
bool pattern_ordered; // In case of Reorder(i, j) need to change replacePattern ordering to actually reorder
6970
};
@@ -78,6 +79,12 @@ Reorder::Reorder(std::vector<taco::IndexVar> replacePattern) : content(new Conte
7879
content->pattern_ordered = true;
7980
}
8081

82+
Reorder::Reorder(std::vector<int> path, std::vector<taco::IndexVar> replacePattern) : content(new Content) {
83+
content->path = path;
84+
content->replacePattern = replacePattern;
85+
content->pattern_ordered = true;
86+
}
87+
8188
IndexVar Reorder::geti() const {
8289
return content->replacePattern[0];
8390
}
@@ -93,13 +100,66 @@ const std::vector<IndexVar>& Reorder::getreplacepattern() const {
93100
return content->replacePattern;
94101
}
95102

103+
const std::vector<int>& Reorder::getpath() const {
104+
return content->path;
105+
}
106+
96107
IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const {
97108
INIT_REASON(reason);
98109

99110
string r;
100-
if (!isConcreteNotation(stmt, &r)) {
101-
*reason = "The index statement is not valid concrete index notation: " + r;
102-
return IndexStmt();
111+
112+
// TODO - Add a different check for concrete index notation with branching
113+
// if (!isConcreteNotation(stmt, &r)) {
114+
// *reason = "The index statement is not valid concrete index notation: " + r;
115+
// return IndexStmt();
116+
// }
117+
118+
IndexStmt originalStmt = stmt;
119+
struct ReorderVisitor : public IndexNotationVisitor {
120+
using IndexNotationVisitor::visit;
121+
vector<int>& path;
122+
unsigned int pathIdx = 0;
123+
IndexStmt innerStmt;
124+
125+
ReorderVisitor(vector<int>& path) : path(path) {}
126+
127+
void visit(const ForallNode* node) {
128+
if (pathIdx == path.size()) {
129+
innerStmt = IndexStmt(node);
130+
return;
131+
}
132+
IndexNotationVisitor::visit(node);
133+
}
134+
135+
void visit(const WhereNode* node) {
136+
137+
Where where(node);
138+
139+
if (pathIdx == path.size()) {
140+
innerStmt = IndexStmt(node);
141+
return;
142+
}
143+
144+
if (!path[pathIdx]) {
145+
pathIdx++;
146+
IndexNotationVisitor::visit(node->producer);
147+
} else {
148+
pathIdx++;
149+
IndexNotationVisitor::visit(node->consumer);
150+
}
151+
}
152+
};
153+
154+
cout << "original statement: " << originalStmt << endl;
155+
ReorderVisitor reorderVisitor(content->path);
156+
157+
auto p = getpath();
158+
cout << "path: " << util::join(p) << endl;
159+
if (p.size() > 0) {
160+
originalStmt.accept(&reorderVisitor);
161+
cout << "reordering statment: " << reorderVisitor.innerStmt << endl;
162+
stmt = reorderVisitor.innerStmt;
103163
}
104164

105165
// collect current ordering of IndexVars
@@ -130,7 +190,52 @@ IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const {
130190
*reason = "The foralls of reorder pattern: " + util::join(getreplacepattern()) + " were not directly nested.";
131191
return IndexStmt();
132192
}
133-
return ForAllReplace(currentOrdering, getreplacepattern()).apply(stmt, reason);
193+
194+
cout << "replacePattern: " << util::join(getreplacepattern()) << endl;
195+
auto reorderedStmt = ForAllReplace(currentOrdering, getreplacepattern()).apply(stmt, reason);
196+
197+
198+
struct ReorderedRewriter : public IndexNotationRewriter {
199+
using IndexNotationRewriter::visit;
200+
201+
IndexStmt reorderedStmt;
202+
vector<int>& path;
203+
vector<int> visited;
204+
205+
ReorderedRewriter(IndexStmt reorderedStmt, vector<int>& path) : reorderedStmt(reorderedStmt), path(path) {}
206+
207+
void visit(const ForallNode* node) {
208+
// at the end of the path, rewrite should happen using the producer and consumer
209+
if (visited == path) {
210+
stmt = reorderedStmt;
211+
return;
212+
}
213+
IndexNotationRewriter::visit(node);
214+
}
215+
216+
void visit(const WhereNode* node) {
217+
Where where(node);
218+
219+
// add 0 to visited if the producer is visited and 1 if the consumer is visited
220+
visited.push_back(0);
221+
IndexStmt producer = rewrite(node->producer);
222+
visited.pop_back();
223+
visited.push_back(1);
224+
IndexStmt consumer = rewrite(node->consumer);
225+
visited.pop_back();
226+
if (producer == node->producer && consumer == node->consumer) {
227+
stmt = node;
228+
}
229+
else {
230+
stmt = new WhereNode(consumer, producer);
231+
}
232+
233+
}
234+
};
235+
ReorderedRewriter reorderedRewriter(reorderedStmt, content->path);
236+
stmt = reorderedRewriter.rewrite(originalStmt);
237+
238+
return stmt;
134239
}
135240

136241
void Reorder::print(std::ostream& os) const {
@@ -1068,10 +1173,12 @@ IndexStmt ForAllReplace::apply(IndexStmt stmt, string* reason) const {
10681173
INIT_REASON(reason);
10691174

10701175
string r;
1071-
if (!isConcreteNotation(stmt, &r)) {
1072-
*reason = "The index statement is not valid concrete index notation: " + r;
1073-
return IndexStmt();
1074-
}
1176+
1177+
// TODO - Add a different check for concrete index notation with branching
1178+
// if (!isConcreteNotation(stmt, &r)) {
1179+
// *reason = "The index statement is not valid concrete index notation: " + r;
1180+
// return IndexStmt();
1181+
// }
10751182

10761183
/// Since all IndexVars can only appear once, assume replacement will work and error if it doesn't
10771184
struct ForAllReplaceRewriter : public IndexNotationRewriter {

test/tests-workspaces.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,60 @@ TEST(workspaces, loopcontractfuse) {
757757
ASSERT_TENSOR_EQ(expected, A);
758758
}
759759

760+
TEST(workspaces, loopreordercontractfuse) {
761+
int N = 16;
762+
Tensor<double> A("A", {N, N, N}, Format{Dense, Dense, Dense});
763+
Tensor<double> B("B", {N, N, N}, Format{Dense, Sparse, Sparse});
764+
Tensor<double> C("C", {N, N}, Format{Dense, Dense});
765+
Tensor<double> D("D", {N, N}, Format{Dense, Dense});
766+
Tensor<double> E("E", {N, N}, Format{Dense, Dense});
767+
768+
for (int i = 0; i < N; i++) {
769+
for (int j = 0; j < N; j++) {
770+
for (int k = 0; k < N; k++) {
771+
B.insert({i, j, k}, (double) i);
772+
}
773+
C.insert({i, j}, (double) j);
774+
E.insert({i, j}, (double) i*j);
775+
D.insert({i, j}, (double) i*j);
776+
}
777+
}
778+
779+
IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n");
780+
A(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n);
781+
782+
IndexStmt stmt = A.getAssignment().concretize();
783+
784+
std::cout << stmt << endl;
785+
vector<int> path1;
786+
vector<int> path2 = {1};
787+
stmt = stmt
788+
.reorder({l,i,m, j, k, n})
789+
.loopfuse(2, true, path1)
790+
.reorder(path2, {m,k,j,n})
791+
.loopfuse(2, true, path2)
792+
;
793+
stmt = stmt
794+
.parallelize(l, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces)
795+
;
796+
797+
798+
stmt = stmt.concretize();
799+
cout << "final stmt: " << stmt << endl;
800+
printCodeToFile("loopreordercontractfuse", stmt);
801+
802+
A.compile(stmt.concretize());
803+
A.assemble();
804+
A.compute();
805+
806+
Tensor<double> expected("expected", {N, N, N}, Format{Dense, Dense, Dense});
807+
expected(l,m,n) = B(i,j,k) * C(i,l) * D(j,m) * E(k,n);
808+
expected.compile();
809+
expected.assemble();
810+
expected.compute();
811+
ASSERT_TENSOR_EQ(expected, A);
812+
}
813+
760814
TEST(workspaces, precompute2D_mul) {
761815
int N = 16;
762816
Tensor<double> A("A", {N, N}, Format{Dense, Dense});

0 commit comments

Comments
 (0)