@@ -2048,6 +2048,32 @@ IndexStmt IndexStmt::assemble(TensorVar result, AssembleStrategy strategy,
20482048 return transformed;
20492049}
20502050
2051+ IndexStmt IndexStmt::wsaccel (TensorVar& ws, bool shouldAccel, const std::vector<IndexVar>& accelIndexVars) {
2052+ if (accelIndexVars.size () == 0 ) {
2053+ ws.setAccelIndexVars (accelIndexVars, shouldAccel);
2054+ return *this ;
2055+ }
2056+ set<IndexVar> TempVars;
2057+ match (*this ,
2058+ std::function<void (const WhereNode*,Matcher*)>([&](const WhereNode* where,Matcher* ctx) {
2059+ auto Temp = getResultAccesses (where->producer ).first [0 ];
2060+ if (Temp.getTensorVar () == ws) {
2061+ for (auto i :getIndexVars ()){
2062+ TempVars.insert (i);
2063+ }
2064+ }
2065+ ctx->match (where->producer );
2066+ ctx->match (where->consumer );
2067+ }));
2068+ for (auto i : accelIndexVars) {
2069+ if (TempVars.find (i) == TempVars.end ()) {
2070+ taco_uerror << " No matching indexVars in the Accel" ;
2071+ }
2072+ }
2073+ ws.setAccelIndexVars (accelIndexVars, shouldAccel);
2074+ return *this ;
2075+ }
2076+
20512077std::ostream& operator <<(std::ostream& os, const IndexStmt& expr) {
20522078 if (!expr.defined ()) return os << " IndexStmt()" ;
20532079 IndexNotationPrinter printer (os);
@@ -2102,6 +2128,50 @@ std::vector<IndexVar> Assignment::getReductionVars() const {
21022128 return reductionVars;
21032129}
21042130
2131+ IndexSetRel Assignment::getIndexSetRel () const {
2132+ vector<IndexVar> freeVars = getLhs ().getIndexVars ();
2133+ set<IndexVar> lseen (freeVars.begin (), freeVars.end ());
2134+ vector<IndexVar> RVars ;
2135+ match (getRhs (),
2136+ std::function<void (const AccessNode*)>([&](const AccessNode* op) {
2137+ for (auto & var : op->indexVars ) {
2138+ RVars.push_back (var);
2139+ }
2140+ }));
2141+ set<IndexVar> rseen (RVars.begin (), RVars.end ());
2142+ IndexSetRel rel = equal;
2143+ std::vector<IndexVar> v_inter;
2144+ int lnum = lseen.size ();
2145+ int rnum = rseen.size ();
2146+ int rcl_num = 0 ;
2147+ for (auto & var : rseen){
2148+ if (util::contains (lseen, var)) {
2149+ rcl_num += 1 ;
2150+ }
2151+ }
2152+ if (rcl_num == 0 ) {
2153+ rel = none;
2154+ }
2155+ else if ((rcl_num<lnum) && (rcl_num == rnum)){
2156+ rel = lcr;
2157+ }
2158+ else if ((rcl_num<lnum) && (rcl_num<rnum)){
2159+ rel = inter;
2160+ } else if ((rcl_num == lnum) && (rcl_num == rnum)){
2161+ rel = equal;
2162+ } else if ((rcl_num == lnum) && (rcl_num<rnum)) {
2163+ rel = rcl;
2164+ }
2165+ else {
2166+ rel = none;
2167+ }
2168+
2169+ if (lnum == 0 && rel == none) {
2170+ rel = rcl;
2171+ }
2172+ return rel;
2173+ }
2174+
21052175template <> bool isa<Assignment>(IndexStmt s) {
21062176 return isa<AssignmentNode>(s.ptr );
21072177}
@@ -2476,6 +2546,8 @@ struct TensorVar::Content {
24762546 Format format;
24772547 Schedule schedule;
24782548 Literal fill;
2549+ std::vector<IndexVar> accelIndexVars;
2550+ bool shouldAccel;
24792551};
24802552
24812553TensorVar::TensorVar () : content(nullptr ) {
@@ -2508,6 +2580,8 @@ TensorVar::TensorVar(const int& id, const string& name, const Type& type, const
25082580 content->type = type;
25092581 content->format = format;
25102582 content->fill = fill.defined ()? fill : Literal::zero (type.getDataType ());
2583+ content->accelIndexVars = std::vector<IndexVar> {};
2584+ content->shouldAccel = true ;
25112585}
25122586
25132587int TensorVar::getId () const {
@@ -2551,6 +2625,19 @@ const Literal& TensorVar::getFill() const {
25512625 return content->fill ;
25522626}
25532627
2628+ const std::vector<IndexVar>& TensorVar::getAccelIndexVars () const {
2629+ return content->accelIndexVars ;
2630+ }
2631+
2632+ bool TensorVar::getShouldAccel () const {
2633+ return content->shouldAccel ;
2634+ }
2635+
2636+ void TensorVar::setAccelIndexVars (const std::vector<IndexVar>& accelIndexVars, bool shouldAccel) {
2637+ content->shouldAccel = shouldAccel;
2638+ content->accelIndexVars = accelIndexVars;
2639+ }
2640+
25542641void TensorVar::setFill (const Literal &fill) {
25552642 content->fill = fill;
25562643}
0 commit comments