Skip to content

Commit 2b8ece4

Browse files
authored
Merge pull request #526 from zhang677/test-target
Forall Context
2 parents cb00a90 + b0788bb commit 2b8ece4

File tree

7 files changed

+435
-183
lines changed

7 files changed

+435
-183
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ lib/
1010
*cmake_install.cmake
1111
CMakeCache.txt
1212
doc
13-
13+
.idea/
1414
apps/tensor_times_vector/tensor_times_vector

include/taco/index_notation/index_notation.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,35 @@ struct SuchThatNode;
7171
class IndexExprVisitorStrict;
7272
class IndexStmtVisitorStrict;
7373

74+
/// Describe the relation between indexVar sets of lhs and rhs in an Assignment node.
75+
/// equal: lhs = rhs
76+
/// none: lhs and rhs are mutually exclusive. And lhs and rhs are not empty sets.
77+
/// lcr: rhs is a proper subset of lhs. (lhs contains rhs)
78+
/// rcl: lhs is a proper subset of rhs. (rhs contains lhs)
79+
/// inter: lhs and rhs share common elements but are not equal or empty. Some examples:
80+
/// ```
81+
/// // equal
82+
/// ws(i1) += A(i1) // i1 is a child index node
83+
/// ws(i) = A(i) // i is a parent index node
84+
///
85+
/// // none
86+
/// ws(i1) += A(i) // i1 is a child of i
87+
/// B_new(i) = B(i1)
88+
///
89+
/// // lcr
90+
/// ws(i,k) = A(i) * B(i)
91+
///
92+
/// // rcl
93+
/// ws(i) += A(i,k) * B(i,k)
94+
///
95+
/// // inter
96+
/// ws(i,j) += A(i,k) * B(k,j)
97+
/// ```
98+
///
99+
enum IndexSetRel {
100+
equal, none, lcr, rcl, inter
101+
};
102+
74103
/// Return true if the index statement is of the given subtype. The subtypes
75104
/// are Assignment, Forall, Where, Sequence, and Multi.
76105
template <typename SubType> bool isa(IndexExpr);
@@ -768,6 +797,18 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
768797
IndexStmt assemble(TensorVar result, AssembleStrategy strategy,
769798
bool separately_schedulable = false) const;
770799

800+
/// The wsaccel primitive specifies the dimensions of a workspace that will be accelerated.
801+
/// Acceleration means adding compressed acceleration datastructures (bitmap, coordinate list) to a dense workspace.
802+
/// shouldAccel controls whether acceleration will be applied.
803+
/// When shouldAccel is true, if accelIndexVars is empty, then all dimensions should be accelerated.
804+
/// When shouldAccel is true, if accelIndexVars is not empty, then dimensions in accelIndexVars will be accelerated.
805+
/// When shouldAccel is false, accelIndexVars is ignored.
806+
/// Currently, it only supports one-dimension acceleration. Acceleration is used by default.
807+
///
808+
/// Precondition:
809+
/// Workspace can be accessed by the IndexVars in the accelIndexVars.
810+
IndexStmt wsaccel(TensorVar& ws, bool shouldAccel = true,const std::vector<IndexVar>& accelIndexVars ={});
811+
771812
/// Casts index statement to specified subtype.
772813
template <typename SubType>
773814
SubType as() {
@@ -820,6 +861,9 @@ class Assignment : public IndexStmt {
820861
/// Return the reduction index variables i nthe assign
821862
std::vector<IndexVar> getReductionVars() const;
822863

864+
/// Return the set relation of indexVars in lhs and rhs
865+
IndexSetRel getIndexSetRel() const;
866+
823867
typedef AssignmentNode Node;
824868
};
825869

@@ -1143,6 +1187,15 @@ class TensorVar : public util::Comparable<TensorVar> {
11431187
/// Gets the fill value of the tensor variable. May be left undefined.
11441188
const Literal& getFill() const;
11451189

1190+
/// Gets the acceleration dimensions
1191+
const std::vector<IndexVar>& getAccelIndexVars() const;
1192+
1193+
/// Gets the acceleration flag
1194+
bool getShouldAccel() const;
1195+
1196+
/// Set the acceleration dimensions
1197+
void setAccelIndexVars(const std::vector<IndexVar>& accelIndexVars, bool shouldAccel);
1198+
11461199
/// Set the fill value of the tensor variable
11471200
void setFill(const Literal& fill);
11481201

src/index_notation/index_notation.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,6 +2048,32 @@ IndexStmt IndexStmt::assemble(TensorVar result, AssembleStrategy strategy,
20482048
return transformed;
20492049
}
20502050

2051+
IndexStmt IndexStmt::wsaccel(TensorVar& ws, bool shouldAccel, const std::vector<IndexVar>& accelIndexVars) {
2052+
if (accelIndexVars.size() == 0) {
2053+
ws.setAccelIndexVars(accelIndexVars, shouldAccel);
2054+
return *this;
2055+
}
2056+
set<IndexVar> TempVars;
2057+
match(*this,
2058+
std::function<void(const WhereNode*,Matcher*)>([&](const WhereNode* where,Matcher* ctx) {
2059+
auto Temp = getResultAccesses(where->producer).first[0];
2060+
if (Temp.getTensorVar() == ws) {
2061+
for (auto i :getIndexVars()){
2062+
TempVars.insert(i);
2063+
}
2064+
}
2065+
ctx->match(where->producer);
2066+
ctx->match(where->consumer);
2067+
}));
2068+
for (auto i : accelIndexVars) {
2069+
if (TempVars.find(i) == TempVars.end()) {
2070+
taco_uerror << "No matching indexVars in the Accel";
2071+
}
2072+
}
2073+
ws.setAccelIndexVars(accelIndexVars, shouldAccel);
2074+
return *this;
2075+
}
2076+
20512077
std::ostream& operator<<(std::ostream& os, const IndexStmt& expr) {
20522078
if (!expr.defined()) return os << "IndexStmt()";
20532079
IndexNotationPrinter printer(os);
@@ -2102,6 +2128,50 @@ std::vector<IndexVar> Assignment::getReductionVars() const {
21022128
return reductionVars;
21032129
}
21042130

2131+
IndexSetRel Assignment::getIndexSetRel() const {
2132+
vector<IndexVar> freeVars = getLhs().getIndexVars();
2133+
set<IndexVar> lseen(freeVars.begin(), freeVars.end());
2134+
vector<IndexVar> RVars ;
2135+
match(getRhs(),
2136+
std::function<void(const AccessNode*)>([&](const AccessNode* op) {
2137+
for (auto& var : op->indexVars) {
2138+
RVars.push_back(var);
2139+
}
2140+
}));
2141+
set<IndexVar> rseen(RVars.begin(), RVars.end());
2142+
IndexSetRel rel = equal;
2143+
std::vector<IndexVar> v_inter;
2144+
int lnum = lseen.size();
2145+
int rnum = rseen.size();
2146+
int rcl_num = 0;
2147+
for (auto & var : rseen){
2148+
if (util::contains(lseen, var)) {
2149+
rcl_num += 1;
2150+
}
2151+
}
2152+
if (rcl_num == 0) {
2153+
rel = none;
2154+
}
2155+
else if ((rcl_num<lnum) && (rcl_num == rnum)){
2156+
rel = lcr;
2157+
}
2158+
else if ((rcl_num<lnum) && (rcl_num<rnum)){
2159+
rel = inter;
2160+
} else if ((rcl_num == lnum) && (rcl_num == rnum)){
2161+
rel = equal;
2162+
} else if ((rcl_num == lnum) && (rcl_num<rnum)) {
2163+
rel = rcl;
2164+
}
2165+
else {
2166+
rel = none;
2167+
}
2168+
2169+
if (lnum == 0 && rel == none) {
2170+
rel = rcl;
2171+
}
2172+
return rel;
2173+
}
2174+
21052175
template <> bool isa<Assignment>(IndexStmt s) {
21062176
return isa<AssignmentNode>(s.ptr);
21072177
}
@@ -2476,6 +2546,8 @@ struct TensorVar::Content {
24762546
Format format;
24772547
Schedule schedule;
24782548
Literal fill;
2549+
std::vector<IndexVar> accelIndexVars;
2550+
bool shouldAccel;
24792551
};
24802552

24812553
TensorVar::TensorVar() : content(nullptr) {
@@ -2508,6 +2580,8 @@ TensorVar::TensorVar(const int& id, const string& name, const Type& type, const
25082580
content->type = type;
25092581
content->format = format;
25102582
content->fill = fill.defined()? fill : Literal::zero(type.getDataType());
2583+
content->accelIndexVars = std::vector<IndexVar> {};
2584+
content->shouldAccel = true;
25112585
}
25122586

25132587
int TensorVar::getId() const {
@@ -2551,6 +2625,19 @@ const Literal& TensorVar::getFill() const {
25512625
return content->fill;
25522626
}
25532627

2628+
const std::vector<IndexVar>& TensorVar::getAccelIndexVars() const {
2629+
return content->accelIndexVars;
2630+
}
2631+
2632+
bool TensorVar::getShouldAccel() const {
2633+
return content->shouldAccel;
2634+
}
2635+
2636+
void TensorVar::setAccelIndexVars(const std::vector<IndexVar>& accelIndexVars, bool shouldAccel) {
2637+
content->shouldAccel = shouldAccel;
2638+
content->accelIndexVars = accelIndexVars;
2639+
}
2640+
25542641
void TensorVar::setFill(const Literal &fill) {
25552642
content->fill = fill;
25562643
}

0 commit comments

Comments
 (0)