@@ -476,10 +476,10 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr);
476476struct AccessTensorNode : public AccessNode {
477477 AccessTensorNode (TensorBase tensor, const std::vector<IndexVar>& indices)
478478 : AccessNode(tensor.getTensorVar(), indices, {}, false ),
479- tensor (tensor) {}
479+ tensorPtr (tensor.content ) {}
480480
481481 AccessTensorNode (TensorBase tensor, const std::vector<std::shared_ptr<IndexVarInterface>>& indices)
482- : AccessNode(tensor.getTensorVar()), tensor (tensor) {
482+ : AccessNode(tensor.getTensorVar()), tensorPtr (tensor.content ) {
483483 // Create the vector of IndexVar to assign to this->indexVars.
484484 std::vector<IndexVar> ivars (indices.size ());
485485 for (size_t i = 0 ; i < indices.size (); i++) {
@@ -517,10 +517,21 @@ struct AccessTensorNode : public AccessNode {
517517 this ->indexVars = std::move (ivars);
518518 }
519519
520- TensorBase tensor;
520+ // We hold a weak_ptr to the accessed TensorBase to avoid creating a reference
521+ // cycle between the accessed TensorBase and this AccessTensorNode, since the
522+ // TensorBase will store the AccessTensorNode (as part of an IndexExpr) as a
523+ // field on itself. Not using a weak pointer results in leaking TensorBases.
524+ std::weak_ptr<TensorBase::Content> tensorPtr;
525+ TensorBase getTensor () const {
526+ TensorBase tensor;
527+ tensor.content = tensorPtr.lock ();
528+ return tensor;
529+ }
530+
521531 virtual void setAssignment (const Assignment& assignment) {
522- tensor. syncDependentTensors ();
532+ auto tensor = this -> getTensor ();
523533
534+ tensor.syncDependentTensors ();
524535 Assignment assign = makeReductionNotation (assignment);
525536
526537 tensor.setNeedsPack (false );
@@ -751,7 +762,7 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr) {
751762 taco_iassert (isa<AccessTensorNode>(node)) << " Unknown subexpression" ;
752763
753764 if (!util::contains (arguments, node->tensorVar )) {
754- arguments.insert ({node->tensorVar , to<AccessTensorNode>(node)->tensor });
765+ arguments.insert ({node->tensorVar , to<AccessTensorNode>(node)->getTensor () });
755766 }
756767
757768 // Also add any tensors backing index sets of tensor accesses.
@@ -763,7 +774,7 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr) {
763774 }
764775
765776 // TODO (rohany): This seems like dead code.
766- TensorBase tensor = to<AccessTensorNode>(node)->tensor ;
777+ TensorBase tensor = to<AccessTensorNode>(node)->getTensor () ;
767778 if (!util::contains (inserted, tensor)) {
768779 inserted.insert (tensor);
769780 operands.push_back (tensor);
0 commit comments