@@ -1828,13 +1828,20 @@ Stmt LowererImplImperative::lowerMergeCases(ir::Expr coordinate, IndexVar coordi
18281828 vector<Iterator> inserters;
18291829 tie (appenders, inserters) = splitAppenderAndInserters (loopLattice.results ());
18301830
1831- // If loo
1832- if (loopLattice.iterators ().size () == 1 || (loopLattice.exact () &&
1833- isa<Assignment>(stmt) && returnsTrue (stmt.as <Assignment>().getRhs ()))) {
1834- // Just one iterator so no conditional
1831+ if (loopLattice.iterators ().size () == 1 ) {
1832+ // Just one iterator, so no conditionals needed
18351833 taco_iassert (!loopLattice.points ()[0 ].isOmitter ());
1836- Stmt body = lowerForallBody (coordinate, stmt, {}, inserters,
1837- appenders, loopLattice, reducedAccesses, mergeStrategy);
1834+ Stmt body = lowerForallBody (coordinate, stmt, {}, inserters, appenders,
1835+ loopLattice, reducedAccesses, mergeStrategy);
1836+ result.push_back (body);
1837+ }
1838+ else if (loopLattice.exact () && isa<Assignment>(stmt) &&
1839+ returnsTrue (stmt.as <Assignment>().getRhs ())) {
1840+ // All cases require the same computation, so no conditionals needed
1841+ taco_iassert (!loopLattice.points ()[0 ].isOmitter ());
1842+ Stmt body = lowerForallBody (coordinate, stmt, {}, inserters, appenders,
1843+ MergeLattice ({loopLattice.points ()[0 ]}),
1844+ reducedAccesses, mergeStrategy);
18381845 result.push_back (body);
18391846 }
18401847 else if (!loopLattice.points ().empty ()) {
@@ -2028,17 +2035,26 @@ Stmt LowererImplImperative::lowerMergeCasesWithExplicitZeroChecks(ir::Expr coord
20282035 MergeStrategy mergeStrategy) {
20292036
20302037 vector<Stmt> result;
2031- if (lattice.points ().size () == 1 && lattice.iterators ().size () == 1
2032- || (lattice.exact () &&
2033- isa<Assignment>(stmt) && returnsTrue (stmt.as <Assignment>().getRhs ()))) {
2034- // Just one iterator so no conditional
2038+ if (lattice.points ().size () == 1 && lattice.iterators ().size () == 1 ) {
2039+ // Just one iterator, so no conditional needed
20352040 vector<Iterator> appenders;
20362041 vector<Iterator> inserters;
20372042 tie (appenders, inserters) = splitAppenderAndInserters (lattice.results ());
20382043 taco_iassert (!lattice.points ()[0 ].isOmitter ());
20392044 Stmt body = lowerForallBody (coordinate, stmt, {}, inserters,
20402045 appenders, lattice, reducedAccesses, mergeStrategy);
20412046 result.push_back (body);
2047+ } else if (lattice.exact () && isa<Assignment>(stmt) &&
2048+ returnsTrue (stmt.as <Assignment>().getRhs ())) {
2049+ // All cases require the same computation, so no conditionals needed
2050+ vector<Iterator> appenders;
2051+ vector<Iterator> inserters;
2052+ tie (appenders, inserters) = splitAppenderAndInserters (lattice.results ());
2053+ taco_iassert (!lattice.points ()[0 ].isOmitter ());
2054+ Stmt body = lowerForallBody (coordinate, stmt, {}, inserters, appenders,
2055+ MergeLattice ({lattice.points ()[0 ]}),
2056+ reducedAccesses, mergeStrategy);
2057+ result.push_back (body);
20422058 } else if (!lattice.points ().empty ()) {
20432059 map<Iterator, Expr> iteratorToConditionMap;
20442060
@@ -3046,7 +3062,7 @@ Stmt LowererImplImperative::initResultArrays(vector<Access> writes,
30463062 taco_iassert (!iterators.empty ());
30473063
30483064 Expr tensor = getTensorVar (write.getTensorVar ());
3049- Expr fill = GetProperty::make (tensor, TensorProperty::FillValue );
3065+ Expr fill = lower (write. getTensorVar (). getFill () );
30503066 Expr valuesArr = GetProperty::make (tensor, TensorProperty::Values);
30513067 bool clearValuesAllocation = false ;
30523068
@@ -3214,7 +3230,7 @@ Stmt LowererImplImperative::initResultArrays(IndexVar var, vector<Access> writes
32143230 vector<Stmt> result;
32153231 for (auto & write : writes) {
32163232 Expr tensor = getTensorVar (write.getTensorVar ());
3217- Expr fill = GetProperty::make (tensor, TensorProperty::FillValue );
3233+ Expr fill = lower (write. getTensorVar (). getFill () );
32183234 Expr values = GetProperty::make (tensor, TensorProperty::Values);
32193235
32203236 vector<Iterator> iterators = getIteratorsFrom (var, getIterators (write));
0 commit comments