Commit f609a73
Fix inference mode compilation in update_joint_with_descriptors (#265)
Fixed bug where updated_flat_args was incorrectly formatted as a tuple
for inference mode, causing AOT compilation to fail. The format should
be a list for inference mode and tuple (primals, tangents) for autograd
mode, per PyTorch's AOTGraphCapture schema.
ref: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/schemas.py#L1197
Testing: New unit test test_inference_mode_compilation()
Co-authored-by: Aditya Venkataraman <avenkataraman@fb.com>1 parent 449e35e commit f609a73
2 files changed
+64
-4
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
65 | 65 | | |
66 | 66 | | |
67 | 67 | | |
68 | | - | |
69 | | - | |
70 | | - | |
71 | | - | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
72 | 79 | | |
73 | 80 | | |
74 | 81 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
305 | 305 | | |
306 | 306 | | |
307 | 307 | | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
0 commit comments