@@ -437,10 +437,16 @@ TEST(workspaces, precompute3D_renamedIVars_TspV) {
 
 TEST(workspaces, DISABLED_tile_dotProduct_1) {
   // FIXME: Disabled because currently the precompute algorithm does not appropriately
-  // optimize = from += when rewriting a statement for BOTH the producer and consumer
-  // side of a where statement insertion.
-  // Although always using += is CORRECT functionally, this fails the GPU tests since it
-  // would result in scattering.
+  // find the correct forall substatement to nest the WhereNode in after i has
+  // been split into i0 and i1. As an example, the first precompute below is
+  // incorrect since it should transform
+  //   forall(i0, forall(i1, A() += B(i) * C(i))) -->
+  //   forall(i0, where(forall(i1, A() += ws(i1)), forall(i1, ws(i1) += B(i) * C(i))))
+  //
+  // But currently the algorithm does
+  //   forall(i0, forall(i1, A() += B(i) * C(i))) -->
+  //   where(forall(i1, A() += ws(i1)), forall(i0, forall(i1, ws(i1) += B(i) * C(i))))
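+  //
+  // i.e. in the incorrect form the where node is hoisted above the i0 loop, so
+  // the workspace is filled across all i0 tiles before the consumer drains it
+  // once, instead of being filled and drained once per i0 tile as intended.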
+
   int N = 1024;
   Tensor<double> A("A");
   Tensor<double> B("B", {N}, Format({Dense}));
@@ -470,17 +476,26 @@ TEST(workspaces, DISABLED_tile_dotProduct_1) {
   stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact)
              .split(i_bounded, i0, i1, 32);
   stmt = stmt.precompute(precomputedExpr, i1, i1, precomputed);
-
-  stmt = stmt.precompute(BExpr, i1, i1, B_new)
-             .precompute(CExpr, i1, i1, C_new);
-
-
+  stmt = stmt.precompute(BExpr, i1, i1, B_new)
+             .precompute(CExpr, i1, i1, C_new);
+
   stmt = stmt.concretize();
 
   A.compile(stmt);
   A.assemble();
   A.compute();
 
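+  // Debug output: print the scheduled statement, then lower it to imperative
+  // IR and emit the generated code, so the loop structure produced by the
+  // (currently incorrect) where placement can be inspected by hand.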
+  ir::IRPrinter irp = ir::IRPrinter(cout);
+
+  cout << stmt << endl;
+
+  std::shared_ptr<ir::CodeGen> codegen = ir::CodeGen::init_default(cout, ir::CodeGen::ImplementationGen);
+  ir::Stmt compute = lower(stmt, "compute", false, true);
+
+  irp.print(compute);
+  cout << endl;
+  codegen->compile(compute, false);
+
   Tensor<double> expected("expected");
   expected() = B(i) * C(i);
   expected.compile();
@@ -543,3 +558,51 @@ TEST(workspaces, DISABLED_tile_dotProduct_2) {
   ASSERT_TENSOR_EQ(expected, A);
 }
 
+TEST(workspaces, tile_dotProduct_3) {
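+  // Variant of DISABLED_tile_dotProduct_1 that precomputes the full expression
+  // at the outer tile variable i0 rather than at i1; this placement appears to
+  // be handled correctly by the current precompute algorithm, so the test
+  // stays enabled.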
+  int N = 1024;
+  Tensor<double> A("A");
+  Tensor<double> B("B", {N}, Format({Dense}));
+  Tensor<double> C("C", {N}, Format({Dense}));
+
+  for (int i = 0; i < N; i++) {
+    B.insert({i}, (double) i);
+    C.insert({i}, (double) i);
+  }
+
+  B.pack();
+  C.pack();
+
+  IndexVar i("i");
+  IndexVar i_bounded("i_bounded");
+  IndexVar i0("i0"), i1("i1");
+  IndexExpr BExpr = B(i);
+  IndexExpr CExpr = C(i);
+  IndexExpr precomputedExpr = (BExpr) * (CExpr);
+  A() = precomputedExpr;
+
+  IndexStmt stmt = A.getAssignment().concretize();
+  TensorVar B_new("B_new", Type(Float64, {(size_t)N}), taco::dense);
+  TensorVar C_new("C_new", Type(Float64, {(size_t)N}), taco::dense);
+  TensorVar precomputed("precomputed", Type(Float64, {(size_t)N}), taco::dense);
+
+  stmt = stmt.bound(i, i_bounded, (size_t)N, BoundType::MaxExact)
+             .split(i_bounded, i0, i1, 32);
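+  // Unlike test 1, precompute the whole dot-product expression at the outer
+  // tile variable i0.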
+  stmt = stmt.precompute(precomputedExpr, i0, i0, precomputed);
+
+  stmt = stmt.precompute(BExpr, i1, i1, B_new)
+             .precompute(CExpr, i1, i1, C_new);
+
+  stmt = stmt.concretize();
+
+  A.compile(stmt);
+  A.assemble();
+  A.compute();
+
+  Tensor<double> expected("expected");
+  expected() = B(i) * C(i);
+  expected.compile();
+  expected.assemble();
+  expected.compute();
+  ASSERT_TENSOR_EQ(expected, A);
+}