Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,18 @@

namespace mlir {
namespace torch {

#define GEN_PASS_DECL_CONVERTTORCHTOSTABLEHLO
#include "torch-mlir/Conversion/Passes.h.inc"

std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass();

// Convenience wrapper for users who want to pass options as individual
// parameters
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);

} // namespace torch
} // namespace mlir

Expand Down
7 changes: 7 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
namespace mlir {
namespace torch {

#define GEN_PASS_DECL_CONVERTTORCHTOTOSA
#include "torch-mlir/Conversion/Passes.h.inc"

/// Collect a set of legal/illegal ops for converting Torch operations to Tosa
/// dialect.
void populateTorchToTosaConversionLegalOps(ConversionTarget &target);
Expand All @@ -30,8 +33,12 @@ populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
RewritePatternSet &patterns);

std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();

// Convenience wrapper for users who want to pass options as individual
// parameters
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToTosaPass(bool requireFullTosaConversion);

} // namespace torch
} // namespace mlir

Expand Down
13 changes: 0 additions & 13 deletions include/torch-mlir/RefBackend/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,12 @@
#include "mlir/Pass/PassManager.h"

namespace mlir {
class ModuleOp;

namespace torch {
namespace RefBackend {

/// Registers all RefBackend passes.
void registerRefBackendPasses();

std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();

std::unique_ptr<OperationPass<func::FuncOp>> createExpandOpsForLLVMPass();

std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();

std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();

std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorConcatPass();

std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorPadPass();
Comment on lines -26 to -36
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these removed because they are unused elsewhere in this project? E.g. we still have createTorchToLinalgPass, but is that only because it is used in some pass pipeline?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edit: this was addressed in your PR description.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think they are handled by tablegen now. They were only used because we have let constructor = in tablegen, and the wrapper is not needed at all. I searched one of the pass creation on github, and there are no users at all: https://github.com/search?q=createExpandOpsForLLVM&type=code&p=1

I'd expect people to use the default (e.g., createGeneralizeTensorPad) to add a pass. To me, the pass name in td file should end up with Pass, which matches other MLIR project convention. E.g., def GeneralizeTensorPad -> def GeneralizeTensorPadPass.

} // namespace RefBackend
} // namespace torch
} // namespace mlir
Expand Down
6 changes: 0 additions & 6 deletions include/torch-mlir/RefBackend/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,29 @@ include "mlir/Pass/PassBase.td"

def MungeCallingConventions : Pass<"refback-munge-calling-conventions", "ModuleOp"> {
let summary = "Munge calling conventions for calling via ExecutionEngine";
let constructor = "mlir::torch::RefBackend::createMungeCallingConventionsPass();";
let dependentDialects = ["memref::MemRefDialect"];
}

def MLProgramBufferize: Pass<"refback-mlprogram-bufferize", "ModuleOp"> {
let summary = "Bufferize the MLProgram dialect ops";
let constructor = "mlir::torch::RefBackend::createMLProgramBufferizePass();";
let dependentDialects = ["memref::MemRefDialect"];
}

def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "func::FuncOp"> {
let summary = "Expand ops into more primitive ops before LLVM lowering.";
let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();";
}

def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> {
let summary = "Munge memref.copy to linalg.copy";
let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();";
let dependentDialects = ["memref::MemRefDialect"];
}

def GeneralizeTensorConcat : Pass<"refback-generalize-tensor-concat", "func::FuncOp"> {
let summary = "Convert tensor.concat to other tensor ops";
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorConcatPass()";
}

def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> {
let summary = "Convert tensor.pad to linalg ops";
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";
}

#endif // TORCHMLIR_REFBACKEND_PASSES
26 changes: 0 additions & 26 deletions lib/Conversion/PassDetail.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,23 @@

#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "torch-mlir/Conversion/Passes.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
namespace mlir::torch {

#define GEN_PASS_DEF_CONVERTTORCHCONVERSIONTOMLPROGRAM
#include "torch-mlir/Conversion/Passes.h.inc"

static constexpr StringRef getSeedGobalVarName() { return "global_seed"; }

Expand Down Expand Up @@ -102,7 +108,7 @@ class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {

namespace {
class ConvertTorchConversionToMLProgram
: public ConvertTorchConversionToMLProgramBase<
: public impl::ConvertTorchConversionToMLProgramBase<
ConvertTorchConversionToMLProgram> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
Expand Down Expand Up @@ -138,6 +144,8 @@ class ConvertTorchConversionToMLProgram
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::createConvertTorchConversionToMLProgramPass() {
createConvertTorchConversionToMLProgramPass() {
return std::make_unique<ConvertTorchConversionToMLProgram>();
}

} // namespace mlir::torch
24 changes: 0 additions & 24 deletions lib/Conversion/TorchOnnxToTorch/PassDetail.h

This file was deleted.

14 changes: 10 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
//
//===----------------------------------------------------------------------===//

#include "./PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
Expand All @@ -19,6 +20,10 @@ using llvm::dbgs;
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;
namespace mlir::torch::onnx_c {

#define GEN_PASS_DEF_CONVERTTORCHONNXTOTORCH
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc"

#define DEBUG_TYPE "torch-onnx"

Expand All @@ -37,7 +42,7 @@ int64_t getDefaultOpsetVersion(Operation *containerOp) {
}

class ConvertTorchOnnxToTorch
: public ConvertTorchOnnxToTorchBase<ConvertTorchOnnxToTorch> {
: public impl::ConvertTorchOnnxToTorchBase<ConvertTorchOnnxToTorch> {
public:
ConvertTorchOnnxToTorch() = default;
void runOnOperation() override {
Expand Down Expand Up @@ -82,7 +87,8 @@ class ConvertTorchOnnxToTorch

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::onnx_c::createTorchOnnxToTorchPass() {
std::unique_ptr<OperationPass<func::FuncOp>> createTorchOnnxToTorchPass() {
return std::make_unique<ConvertTorchOnnxToTorch>();
}

} // namespace mlir::torch::onnx_c
15 changes: 11 additions & 4 deletions lib/Conversion/TorchToArith/TorchToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@

#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/Passes.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
Expand All @@ -25,6 +27,10 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace mlir::torch {

#define GEN_PASS_DEF_CONVERTTORCHTOARITH
#include "torch-mlir/Conversion/Passes.h.inc"

// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
Expand Down Expand Up @@ -407,7 +413,7 @@ class ConvertAtenBoolLikeOp : public OpConversionPattern<OpTy> {

namespace {
class ConvertTorchToArith
: public ConvertTorchToArithBase<ConvertTorchToArith> {
: public impl::ConvertTorchToArithBase<ConvertTorchToArith> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect>();
Expand Down Expand Up @@ -565,7 +571,8 @@ class ConvertTorchToArith
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToArithPass() {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToArithPass() {
return std::make_unique<ConvertTorchToArith>();
}

} // namespace mlir::torch
15 changes: 11 additions & 4 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,27 @@

#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Pass/Pass.h"
#include "torch-mlir/Conversion/Passes.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace mlir::torch {

#define GEN_PASS_DEF_CONVERTTORCHTOLINALG
#include "torch-mlir/Conversion/Passes.h.inc"

// -----------------------------------------------------------------------------
// The pass
Expand All @@ -34,7 +40,7 @@ using namespace mlir::torch::Torch;

namespace {
class ConvertTorchToLinalg
: public ConvertTorchToLinalgBase<ConvertTorchToLinalg> {
: public impl::ConvertTorchToLinalgBase<ConvertTorchToLinalg> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
Expand Down Expand Up @@ -89,7 +95,8 @@ class ConvertTorchToLinalg
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToLinalgPass() {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass() {
return std::make_unique<ConvertTorchToLinalg>();
}

} // namespace mlir::torch
1 change: 0 additions & 1 deletion lib/Conversion/TorchToLinalg/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
Expand Down
Loading
Loading