Skip to content

Commit b929b53

Browse files
Merge pull request #530 from tensor-compiler/example_fixes
Use hard-coded fill values in generated code and don't emit condition…
2 parents d87ab62 + 82b83c5 commit b929b53

File tree

3 files changed

+33
-13
lines changed

3 files changed

+33
-13
lines changed

include/taco/util/strings.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <map>
88
#include <iomanip>
99
#include <limits>
10+
#include <cmath>
1011

1112
// To get the value of a compiler macro variable
1213
#define STRINGIFY(x) #x
@@ -30,6 +31,9 @@ toString(const T &val) {
3031
template <class T>
3132
typename std::enable_if<std::is_floating_point<T>::value, std::string>::type
3233
toString(const T &val) {
34+
if (std::isinf(val)) {
35+
return (val < 0) ? "-INFINITY" : "INFINITY";
36+
}
3337
std::stringstream sstream;
3438
sstream << std::setprecision(std::numeric_limits<T>::max_digits10) << std::showpoint << val;
3539
return sstream.str();

src/index_notation/property_pointers.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ IdentityPtr::IdentityPtr(Literal identity) : PropertyPtr(), content(new Content)
7878
}
7979

8080
IdentityPtr::IdentityPtr(Literal identity, std::vector<int> &p) : PropertyPtr(), content(new Content) {
81-
content->identity;
81+
content->identity = identity;
8282
content->positions = p;
8383
}
8484

src/lower/lowerer_impl_imperative.cpp

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,13 +1828,20 @@ Stmt LowererImplImperative::lowerMergeCases(ir::Expr coordinate, IndexVar coordi
18281828
vector<Iterator> inserters;
18291829
tie(appenders, inserters) = splitAppenderAndInserters(loopLattice.results());
18301830

1831-
// If loo
1832-
if (loopLattice.iterators().size() == 1 || (loopLattice.exact() &&
1833-
isa<Assignment>(stmt) && returnsTrue(stmt.as<Assignment>().getRhs()))) {
1834-
// Just one iterator so no conditional
1831+
if (loopLattice.iterators().size() == 1) {
1832+
// Just one iterator, so no conditionals needed
18351833
taco_iassert(!loopLattice.points()[0].isOmitter());
1836-
Stmt body = lowerForallBody(coordinate, stmt, {}, inserters,
1837-
appenders, loopLattice, reducedAccesses, mergeStrategy);
1834+
Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, appenders,
1835+
loopLattice, reducedAccesses, mergeStrategy);
1836+
result.push_back(body);
1837+
}
1838+
else if (loopLattice.exact() && isa<Assignment>(stmt) &&
1839+
returnsTrue(stmt.as<Assignment>().getRhs())) {
1840+
// All cases require the same computation, so no conditionals needed
1841+
taco_iassert(!loopLattice.points()[0].isOmitter());
1842+
Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, appenders,
1843+
MergeLattice({loopLattice.points()[0]}),
1844+
reducedAccesses, mergeStrategy);
18381845
result.push_back(body);
18391846
}
18401847
else if (!loopLattice.points().empty()) {
@@ -2028,17 +2035,26 @@ Stmt LowererImplImperative::lowerMergeCasesWithExplicitZeroChecks(ir::Expr coord
20282035
MergeStrategy mergeStrategy) {
20292036

20302037
vector<Stmt> result;
2031-
if (lattice.points().size() == 1 && lattice.iterators().size() == 1
2032-
|| (lattice.exact() &&
2033-
isa<Assignment>(stmt) && returnsTrue(stmt.as<Assignment>().getRhs()))) {
2034-
// Just one iterator so no conditional
2038+
if (lattice.points().size() == 1 && lattice.iterators().size() == 1) {
2039+
// Just one iterator, so no conditional needed
20352040
vector<Iterator> appenders;
20362041
vector<Iterator> inserters;
20372042
tie(appenders, inserters) = splitAppenderAndInserters(lattice.results());
20382043
taco_iassert(!lattice.points()[0].isOmitter());
20392044
Stmt body = lowerForallBody(coordinate, stmt, {}, inserters,
20402045
appenders, lattice, reducedAccesses, mergeStrategy);
20412046
result.push_back(body);
2047+
} else if (lattice.exact() && isa<Assignment>(stmt) &&
2048+
returnsTrue(stmt.as<Assignment>().getRhs())) {
2049+
// All cases require the same computation, so no conditionals needed
2050+
vector<Iterator> appenders;
2051+
vector<Iterator> inserters;
2052+
tie(appenders, inserters) = splitAppenderAndInserters(lattice.results());
2053+
taco_iassert(!lattice.points()[0].isOmitter());
2054+
Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, appenders,
2055+
MergeLattice({lattice.points()[0]}),
2056+
reducedAccesses, mergeStrategy);
2057+
result.push_back(body);
20422058
} else if (!lattice.points().empty()) {
20432059
map<Iterator, Expr> iteratorToConditionMap;
20442060

@@ -3046,7 +3062,7 @@ Stmt LowererImplImperative::initResultArrays(vector<Access> writes,
30463062
taco_iassert(!iterators.empty());
30473063

30483064
Expr tensor = getTensorVar(write.getTensorVar());
3049-
Expr fill = GetProperty::make(tensor, TensorProperty::FillValue);
3065+
Expr fill = lower(write.getTensorVar().getFill());
30503066
Expr valuesArr = GetProperty::make(tensor, TensorProperty::Values);
30513067
bool clearValuesAllocation = false;
30523068

@@ -3214,7 +3230,7 @@ Stmt LowererImplImperative::initResultArrays(IndexVar var, vector<Access> writes
32143230
vector<Stmt> result;
32153231
for (auto& write : writes) {
32163232
Expr tensor = getTensorVar(write.getTensorVar());
3217-
Expr fill = GetProperty::make(tensor, TensorProperty::FillValue);
3233+
Expr fill = lower(write.getTensorVar().getFill());
32183234
Expr values = GetProperty::make(tensor, TensorProperty::Values);
32193235

32203236
vector<Iterator> iterators = getIteratorsFrom(var, getIterators(write));

0 commit comments

Comments
 (0)