Skip to content

Commit 0df823d

Browse files
authored
fix for export refactor (#189)
stack-info: PR: #189, branch: xmfan/stack/12
1 parent 526849d commit 0df823d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

autoparallel/api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,6 @@ def build_model_graph(self):
288288
ep.module(),
289289
inputs,
290290
decompositions=decomp_table,
291-
fw_compiler=self.compiler_fn,
292-
bw_compiler=self.compiler_fn,
293291
)
294292
gm = self.joint_with_descriptors.graph_module
295293
assert_has_no_collectives(gm)
@@ -454,7 +452,9 @@ def apply_placement(self, sharding_placement=None):
454452
)
455453

456454
self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors(
457-
self.joint_with_descriptors
455+
self.joint_with_descriptors,
456+
fw_compiler=self.compiler_fn,
457+
bw_compiler=self.compiler_fn,
458458
)
459459

460460
# TODO: this probably belongs in the AOTAutograd API

0 commit comments

Comments
 (0)