Skip to content

Commit 35dc907

Browse files
committed
[NFC] Switch to new pass generation tablegen definitions.
This commit completes the migration from the deprecated GEN_PASS_CLASSES to the new GEN_PASS_DEF infrastructure across all torch-mlir passes. Changes include: 1. Remove PassDetail.h files (deprecated pattern) - Deleted lib/Conversion/PassDetail.h - Deleted lib/RefBackend/PassDetail.h - Deleted lib/Dialect/Torch/Transforms/PassDetail.h - Deleted lib/Dialect/TorchConversion/Transforms/PassDetail.h - Deleted lib/Dialect/TMTensor/Transforms/PassDetail.h 2. Migrate conversion passes to GEN_PASS_DEF - Updated all passes in lib/Conversion/ to use #define GEN_PASS_DEF_* - Removed GEN_PASS_DECL from .cpp files (move to headers where needed) - Fixed includes and namespace declarations 3. Migrate dialect transform passes - Updated Torch, TorchConversion, and TMTensor transform passes - Properly scoped GEN_PASS_DEF in namespace blocks 4. Handle passes with options (TorchToStablehlo, TorchToTosa) - Added GEN_PASS_DECL_* to headers - Implemented default and convenience create functions - Used generated constructors via `using BaseClass::BaseClass` 5. Handle passes without options (RefBackend) - Removed manual create function implementations - Let tablegen auto-generate create functions - Added using declarations for Base classes in impl namespace 6. Fix backend type conversion passes - Added missing create functions in BackendTypeConversionPasses.cpp - Fixed namespace scoping issues 7. Fix missing namespace closures - Added proper closing namespace comments in Verify*BackendContract.cpp The migration maintains full backward compatibility while adopting the recommended LLVM pass infrastructure patterns. All passes now use the generated base classes and follow consistent patterns based on whether they have options defined in tablegen. Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent 8d563af commit 35dc907

File tree

57 files changed

+557
-426
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+557
-426
lines changed

include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h

Lines changed: 0 additions & 27 deletions
This file was deleted.

include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,18 @@
1616

1717
namespace mlir {
1818
namespace torch {
19+
20+
#define GEN_PASS_DECL_CONVERTTORCHTOSTABLEHLO
21+
#include "torch-mlir/Conversion/Passes.h.inc"
22+
1923
std::unique_ptr<OperationPass<func::FuncOp>>
2024
createConvertTorchToStablehloPass();
25+
26+
// Convenience wrapper for users who want to pass options as individual
27+
// parameters
2128
std::unique_ptr<OperationPass<func::FuncOp>>
2229
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
30+
2331
} // namespace torch
2432
} // namespace mlir
2533

include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
namespace mlir {
2020
namespace torch {
2121

22+
#define GEN_PASS_DECL_CONVERTTORCHTOTOSA
23+
#include "torch-mlir/Conversion/Passes.h.inc"
24+
2225
/// Collect a set of legal/illegal ops for converting Torch operations to Tosa
2326
/// dialect.
2427
void populateTorchToTosaConversionLegalOps(ConversionTarget &target);
@@ -30,8 +33,12 @@ populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
3033
RewritePatternSet &patterns);
3134

3235
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
36+
37+
// Convenience wrapper for users who want to pass options as individual
38+
// parameters
3339
std::unique_ptr<OperationPass<func::FuncOp>>
3440
createConvertTorchToTosaPass(bool requireFullTosaConversion);
41+
3542
} // namespace torch
3643
} // namespace mlir
3744

include/torch-mlir/RefBackend/Passes.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,12 @@
1515
#include "mlir/Pass/PassManager.h"
1616

1717
namespace mlir {
18-
class ModuleOp;
19-
2018
namespace torch {
2119
namespace RefBackend {
2220

2321
/// Registers all RefBackend passes.
2422
void registerRefBackendPasses();
2523

26-
std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();
27-
28-
std::unique_ptr<OperationPass<func::FuncOp>> createExpandOpsForLLVMPass();
29-
30-
std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();
31-
32-
std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();
33-
34-
std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorConcatPass();
35-
36-
std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorPadPass();
3724
} // namespace RefBackend
3825
} // namespace torch
3926
} // namespace mlir

include/torch-mlir/RefBackend/Passes.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,29 @@ include "mlir/Pass/PassBase.td"
1414

1515
def MungeCallingConventions : Pass<"refback-munge-calling-conventions", "ModuleOp"> {
1616
let summary = "Munge calling conventions for calling via ExecutionEngine";
17-
let constructor = "mlir::torch::RefBackend::createMungeCallingConventionsPass();";
1817
let dependentDialects = ["memref::MemRefDialect"];
1918
}
2019

2120
def MLProgramBufferize: Pass<"refback-mlprogram-bufferize", "ModuleOp"> {
2221
let summary = "Bufferize the MLProgram dialect ops";
23-
let constructor = "mlir::torch::RefBackend::createMLProgramBufferizePass();";
2422
let dependentDialects = ["memref::MemRefDialect"];
2523
}
2624

2725
def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "func::FuncOp"> {
2826
let summary = "Expand ops into more primitive ops before LLVM lowering.";
29-
let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();";
3027
}
3128

3229
def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> {
3330
let summary = "Munge memref.copy to linalg.copy";
34-
let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();";
3531
let dependentDialects = ["memref::MemRefDialect"];
3632
}
3733

3834
def GeneralizeTensorConcat : Pass<"refback-generalize-tensor-concat", "func::FuncOp"> {
3935
let summary = "Convert tensor.concat to other tensor ops";
40-
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorConcatPass()";
4136
}
4237

4338
def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> {
4439
let summary = "Convert tensor.pad to linalg ops";
45-
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";
4640
}
4741

4842
#endif // TORCHMLIR_REFBACKEND_PASSES

lib/Conversion/PassDetail.h

Lines changed: 0 additions & 26 deletions
This file was deleted.

lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
//
88
//===----------------------------------------------------------------------===//
99

10+
#include "torch-mlir/Conversion/Passes.h"
11+
#include "mlir/Pass/Pass.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
1013
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
1114

12-
#include "../PassDetail.h"
1315
#include "mlir/Dialect/Arith/IR/Arith.h"
1416
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
1517
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -20,6 +22,12 @@ using namespace mlir;
2022
using namespace mlir::torch;
2123
using namespace mlir::torch::Torch;
2224
using namespace mlir::torch::TorchConversion;
25+
namespace mlir::torch {
26+
27+
#define GEN_PASS_DEF_CONVERTTORCHCONVERSIONTOMLPROGRAM
28+
#include "torch-mlir/Conversion/Passes.h.inc"
29+
30+
2331

2432
static constexpr StringRef getSeedGobalVarName() { return "global_seed"; }
2533

@@ -102,7 +110,7 @@ class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {
102110

103111
namespace {
104112
class ConvertTorchConversionToMLProgram
105-
: public ConvertTorchConversionToMLProgramBase<
113+
: public impl::ConvertTorchConversionToMLProgramBase<
106114
ConvertTorchConversionToMLProgram> {
107115
public:
108116
void getDependentDialects(DialectRegistry &registry) const override {
@@ -138,6 +146,8 @@ class ConvertTorchConversionToMLProgram
138146
} // namespace
139147

140148
std::unique_ptr<OperationPass<ModuleOp>>
141-
mlir::torch::createConvertTorchConversionToMLProgramPass() {
149+
createConvertTorchConversionToMLProgramPass() {
142150
return std::make_unique<ConvertTorchConversionToMLProgram>();
143151
}
152+
153+
} // namespace mlir::torch

lib/Conversion/TorchOnnxToTorch/PassDetail.h

Lines changed: 0 additions & 24 deletions
This file was deleted.

lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
//
88
//===----------------------------------------------------------------------===//
99

10-
#include "./PassDetail.h"
10+
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h"
11+
#include "mlir/Pass/Pass.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
1113
#include "mlir/Support/LLVM.h"
1214
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h"
1315
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
@@ -19,6 +21,12 @@ using llvm::dbgs;
1921
using namespace mlir;
2022
using namespace mlir::torch;
2123
using namespace mlir::torch::onnx_c;
24+
namespace mlir::torch::onnx_c {
25+
26+
#define GEN_PASS_DEF_CONVERTTORCHONNXTOTORCH
27+
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc"
28+
29+
2230

2331
#define DEBUG_TYPE "torch-onnx"
2432

@@ -37,7 +45,7 @@ int64_t getDefaultOpsetVersion(Operation *containerOp) {
3745
}
3846

3947
class ConvertTorchOnnxToTorch
40-
: public ConvertTorchOnnxToTorchBase<ConvertTorchOnnxToTorch> {
48+
: public impl::ConvertTorchOnnxToTorchBase<ConvertTorchOnnxToTorch> {
4149
public:
4250
ConvertTorchOnnxToTorch() = default;
4351
void runOnOperation() override {
@@ -83,6 +91,8 @@ class ConvertTorchOnnxToTorch
8391
} // namespace
8492

8593
std::unique_ptr<OperationPass<func::FuncOp>>
86-
mlir::torch::onnx_c::createTorchOnnxToTorchPass() {
94+
createTorchOnnxToTorchPass() {
8795
return std::make_unique<ConvertTorchOnnxToTorch>();
8896
}
97+
98+
} // namespace mlir::torch::onnx_c

lib/Conversion/TorchToArith/TorchToArith.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
//
88
//===----------------------------------------------------------------------===//
99

10+
#include "torch-mlir/Conversion/Passes.h"
11+
#include "mlir/Pass/Pass.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
1013
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
1114

12-
#include "../PassDetail.h"
1315
#include "mlir/Dialect/Arith/IR/Arith.h"
1416
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1517
#include "mlir/Dialect/Math/IR/Math.h"
@@ -25,6 +27,12 @@
2527
using namespace mlir;
2628
using namespace mlir::torch;
2729
using namespace mlir::torch::Torch;
30+
namespace mlir::torch {
31+
32+
#define GEN_PASS_DEF_CONVERTTORCHTOARITH
33+
#include "torch-mlir/Conversion/Passes.h.inc"
34+
35+
2836

2937
// -----------------------------------------------------------------------------
3038
// Patterns (as this grows, it should be organized into multiple files)
@@ -407,7 +415,7 @@ class ConvertAtenBoolLikeOp : public OpConversionPattern<OpTy> {
407415

408416
namespace {
409417
class ConvertTorchToArith
410-
: public ConvertTorchToArithBase<ConvertTorchToArith> {
418+
: public impl::ConvertTorchToArithBase<ConvertTorchToArith> {
411419
public:
412420
void getDependentDialects(DialectRegistry &registry) const override {
413421
registry.insert<func::FuncDialect>();
@@ -566,6 +574,8 @@ class ConvertTorchToArith
566574
} // namespace
567575

568576
std::unique_ptr<OperationPass<func::FuncOp>>
569-
mlir::torch::createConvertTorchToArithPass() {
577+
createConvertTorchToArithPass() {
570578
return std::make_unique<ConvertTorchToArith>();
571579
}
580+
581+
} // namespace mlir::torch

0 commit comments

Comments
 (0)