We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 526849d commit 0df823dCopy full SHA for 0df823d
autoparallel/api.py
@@ -288,8 +288,6 @@ def build_model_graph(self):
288
ep.module(),
289
inputs,
290
decompositions=decomp_table,
291
- fw_compiler=self.compiler_fn,
292
- bw_compiler=self.compiler_fn,
293
)
294
gm = self.joint_with_descriptors.graph_module
295
assert_has_no_collectives(gm)
@@ -454,7 +452,9 @@ def apply_placement(self, sharding_placement=None):
454
452
455
453
456
self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors(
457
- self.joint_with_descriptors
+ self.joint_with_descriptors,
+ fw_compiler=self.compiler_fn,
+ bw_compiler=self.compiler_fn,
458
459
460
# TODO: this probably belongs in the AOTAutograd API
0 commit comments