Skip to content

Commit 7529359

Browse files
committed
initial implementation of kernel fuse directive
1 parent 2b8ece4 commit 7529359

File tree

8 files changed

+395
-0
lines changed

8 files changed

+395
-0
lines changed

include/taco/index_notation/index_notation.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,11 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
646646
IndexStmt divide(IndexVar i, IndexVar i1, IndexVar i2, size_t divideFactor) const; // TODO: TailStrategy
647647

648648

649+
/// The loopfuse transformation fuses common outer loops in
650+
/// 2 iteration graphs.
651+
IndexStmt loopfuse(int pos, bool isProducerOnLeft, std::vector<int>& path) const;
652+
653+
649654
/// The reorder transformation swaps two directly nested index
650655
/// variables in an iteration graph. This changes the order of
651656
/// iteration through the space and the order of tensor accesses.

include/taco/index_notation/transformations.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class IndexStmt;
1717
class TransformationInterface;
1818
class Reorder;
1919
class Precompute;
20+
class LoopFuse;
2021
class ForAllReplace;
2122
class AddSuchThatPredicates;
2223
class Parallelize;
@@ -32,6 +33,7 @@ class Transformation {
3233
public:
3334
Transformation(Reorder);
3435
Transformation(Precompute);
36+
Transformation(LoopFuse);
3537
Transformation(ForAllReplace);
3638
Transformation(Parallelize);
3739
Transformation(TopoReorder);
@@ -114,6 +116,28 @@ class Precompute : public TransformationInterface {
114116
/// Print a precompute command.
115117
std::ostream &operator<<(std::ostream &, const Precompute &);
116118

119+
/// The loopfuse optimization rewrite an index expression to precompute
120+
/// part of the `expr` and store it to a workspace.
121+
class LoopFuse : public TransformationInterface {
122+
public:
123+
LoopFuse();
124+
LoopFuse(int pos, bool isProducerOnLeft, std::vector<int>& path);
125+
126+
int getPos() const;
127+
bool getIsProducerOnLeft() const;
128+
std::vector<int>& getPath() const;
129+
130+
/// Apply the loopfuse optimization to a concrete index statement.
131+
IndexStmt apply(IndexStmt, std::string *reason = nullptr) const;
132+
133+
void print(std::ostream &os) const;
134+
135+
private:
136+
struct Content;
137+
std::shared_ptr<Content> content;
138+
};
139+
140+
std::ostream &operator<<(std::ostream &, const LoopFuse &);
117141

118142
/// Replaces all occurrences of directly nested forall nodes of pattern with
119143
/// directly nested loops of replacement

include/taco/parser/lexer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ enum class Token {
2222
sub,
2323
mul,
2424
div,
25+
colon,
2526
eq,
2627
eot, // End of tokens
2728
error

src/index_notation/index_notation.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1854,6 +1854,26 @@ IndexStmt IndexStmt::divide(IndexVar i, IndexVar i1, IndexVar i2, size_t splitFa
18541854
return transformed;
18551855
}
18561856

1857+
IndexStmt IndexStmt::loopfuse(int pos, bool isProducerOnLeft, vector<int>& path) const {
1858+
1859+
std::cout << "Loop fuse pos: " << pos;
1860+
std::cout << ", Loop fuse isProducerOnLeft: " << isProducerOnLeft;
1861+
for (const auto& p : path) {
1862+
std::cout << " " << p;
1863+
}
1864+
std::cout << std::endl;
1865+
1866+
string reason;
1867+
IndexStmt transformed = *this;
1868+
transformed = Transformation(LoopFuse(pos, isProducerOnLeft, path)).apply(transformed, &reason);
1869+
if (!transformed.defined()) {
1870+
taco_uerror << reason;
1871+
}
1872+
return transformed;
1873+
1874+
return *this;
1875+
}
1876+
18571877
IndexStmt IndexStmt::precompute(IndexExpr expr, std::vector<IndexVar> i_vars,
18581878
std::vector<IndexVar> iw_vars, TensorVar workspace) const {
18591879

@@ -2048,6 +2068,7 @@ IndexStmt IndexStmt::assemble(TensorVar result, AssembleStrategy strategy,
20482068
return transformed;
20492069
}
20502070

2071+
20512072
IndexStmt IndexStmt::wsaccel(TensorVar& ws, bool shouldAccel, const std::vector<IndexVar>& accelIndexVars) {
20522073
if (accelIndexVars.size() == 0) {
20532074
ws.setAccelIndexVars(accelIndexVars, shouldAccel);

0 commit comments

Comments
 (0)