-
Notifications
You must be signed in to change notification settings - Fork 7
[ingress][torch-mlir][RFC] Initial version of fx-importer script using torch-mlir #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
1a4b755
1c6df47
370e3c0
b984314
2897e0c
fa7d1de
eaf8c9e
889314d
0a188e5
11240bd
1068bf9
aafffef
0912de1
d5c710c
029b8ad
2215980
8bec558
a1a1ddd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| """ | ||
| Example demonstrating how to load a PyTorch model to MLIR using Lighthouse | ||
| without instantiating the model on the user's side. | ||
|
|
||
| The script uses 'lighthouse.ingress.torch.import_from_file' function that | ||
| takes a path to a Python file containing the model definition, along with | ||
| the names of functions to get model init arguments and sample inputs. The function | ||
| imports the model class on its own, instantiates it, and passes it to torch_mlir | ||
| to get a MLIR module in the specified dialect. | ||
|
|
||
| The script uses the model from 'DummyMLP/model.py' as an example. | ||
| """ | ||
|
|
||
| import os | ||
| from pathlib import Path | ||
|
|
||
| # MLIR infrastructure imports (only needed if you want to manipulate the MLIR module) | ||
| import mlir.dialects.func as func | ||
| from mlir import ir, passmanager | ||
|
|
||
| # Lighthouse imports | ||
| from lighthouse.ingress.torch import import_from_file | ||
|
|
||
| # Step 1: Set up paths to locate the model definition file | ||
| script_dir = Path(os.path.dirname(os.path.abspath(__file__))) | ||
| model_path = script_dir / "DummyMLP" / "model.py" | ||
|
|
||
| ir_context = ir.Context() | ||
|
|
||
| # Step 2: Convert PyTorch model to MLIR | ||
| # Conversion step where Lighthouse: | ||
| # - Loads the DummyMLP class and instantiates it with arguments obtained from 'get_init_inputs()' | ||
| # - Calls get_sample_inputs() to get sample input tensors for shape inference | ||
| # - Converts PyTorch model to linalg-on-tensors dialect operations using torch_mlir | ||
| mlir_module_ir: ir.Module = import_from_file( | ||
| model_path, # Path to the Python file containing the model | ||
| model_class_name="DummyMLP", # Name of the PyTorch nn.Module class to convert | ||
| init_args_fn_name="get_init_inputs", # Function that returns args for model.__init__() | ||
| sample_args_fn_name="get_sample_inputs", # Function that returns sample inputs to pass to 'model(...)' | ||
| dialect="linalg-on-tensors", # Target MLIR dialect (linalg ops on tensor types) | ||
| ir_context=ir_context # MLIR context for the conversion | ||
| ) | ||
|
|
||
| # The PyTorch model is now converted to MLIR at this point. You can now convert | ||
| # the MLIR module to a text form (e.g. 'str(mlir_module_ir)') and save it to a file. | ||
| # | ||
| # The following optional MLIR-processing steps are to give you an idea of what can | ||
| # also be done with the MLIR module. | ||
|
|
||
| # Step 3: Extract the main function operation from the MLIR module and print its metadata | ||
| func_op: func.FuncOp = mlir_module_ir.operation.regions[0].blocks[0].operations[0] | ||
| print(f"entry-point name: {func_op.name}") | ||
| print(f"entry-point type: {func_op.type}") | ||
|
|
||
| # Step 4: Apply some MLIR passes using a PassManager | ||
| pm = passmanager.PassManager(context=ir_context) | ||
| pm.add("linalg-specialize-generic-ops") | ||
| pm.add("one-shot-bufferize") | ||
| pm.run(mlir_module_ir.operation) | ||
|
|
||
| # Step 5: Output the final MLIR | ||
| print("\n\nModule dump after running the pipeline:") | ||
| mlir_module_ir.dump() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| """ | ||
| Example demonstrating how to load an already instantiated PyTorch model | ||
| to MLIR using Lighthouse. | ||
|
||
|
|
||
| The script uses the 'lighthouse.ingress.torch.import_from_model' function that | ||
| takes a PyTorch model that has already been instantiated, along with its sample inputs. | ||
| The function passes the model to torch_mlir to get a MLIR module in the | ||
| specified dialect. | ||
|
|
||
| The script uses a model from 'DummyMLP/model.py' as an example. | ||
| """ | ||
|
|
||
| import torch | ||
|
|
||
| # MLIR infrastructure imports (only needed if you want to manipulate the MLIR module) | ||
| import mlir.dialects.func as func | ||
| from mlir import ir, passmanager | ||
|
|
||
| # Lighthouse imports | ||
| from lighthouse.ingress.torch import import_from_model | ||
|
|
||
| # Import a sample model definition | ||
| from DummyMLP.model import DummyMLP | ||
|
|
||
| # Step 1: Instantiate a model and prepare sample input | ||
| model = DummyMLP() | ||
| sample_input = torch.randn(1, 10) | ||
|
|
||
| ir_context = ir.Context() | ||
| # Step 2: Convert the PyTorch model to MLIR | ||
| mlir_module_ir: ir.Module = import_from_model( | ||
| model, | ||
| sample_args=(sample_input,), | ||
| ir_context=ir_context | ||
| ) | ||
|
|
||
| # The PyTorch model is now converted to MLIR at this point. You can now convert | ||
| # the MLIR module to a text form (e.g. 'str(mlir_module_ir)') and save it to a file. | ||
| # | ||
| # The following optional MLIR-processing steps are to give you an idea of what can | ||
| # also be done with the MLIR module. | ||
|
|
||
| # Step 3: Extract the main function operation from the MLIR module and print its metadata | ||
| func_op: func.FuncOp = mlir_module_ir.operation.regions[0].blocks[0].operations[0] | ||
| print(f"entry-point name: {func_op.name}") | ||
| print(f"entry-point type: {func_op.type}") | ||
|
|
||
| # Step 4: Apply some MLIR passes using a PassManager | ||
| pm = passmanager.PassManager(context=ir_context) | ||
| pm.add("linalg-specialize-generic-ops") | ||
dchigarev marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| pm.add("one-shot-bufferize") | ||
|
||
| pm.run(mlir_module_ir.operation) | ||
|
|
||
| # Step 5: Output the final MLIR | ||
| print("\n\nModule dump after running the pipeline:") | ||
| mlir_module_ir.dump() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| """Defines a simple PyTorch model to be used in lighthouse's ingress examples.""" | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| import os | ||
|
|
||
| class DummyMLP(nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.net = nn.Sequential( | ||
| nn.Linear(10, 32), | ||
| nn.ReLU(), | ||
| nn.Linear(32, 2) | ||
| ) | ||
|
|
||
| def forward(self, x): | ||
| return self.net(x) | ||
|
|
||
|
|
||
| def get_init_inputs(): | ||
| """Function to return args to pass to DummyMLP.__init__()""" | ||
| return () | ||
|
|
||
|
|
||
| def get_sample_inputs(): | ||
| """Arguments to pass to DummyMLP.forward()""" | ||
| return (torch.randn(1, 10),) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| script_dir = os.path.dirname(os.path.abspath(__file__)) | ||
| torch.save(DummyMLP().state_dict(), os.path.join(script_dir, "dummy_mlp.pth")) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| # Lighthouse Ingress | ||
|
|
||
| The `lighthouse.ingress` module converts various input formats to MLIR modules. | ||
|
|
||
| ## Supported Formats | ||
|
|
||
| #### Torch | ||
| Converts PyTorch models to MLIR using `lighthouse.ingress.torch`. | ||
|
|
||
| **Examples:** [torch examples](https://github.com/llvm/lighthouse/tree/main/python/examples/ingress/torch) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from .torch_import import import_from_file, import_from_model |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I expect we will want this (well,
linalg-specialize-generic-opsat least) to graduate to on-by-default transforms that happen on most importing/conversion interactions withlighthouse.ingress.torch(or evenlighthouse.ingress). Not a must for now though.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Going to category could be a more useful or more broadly applicable default but TBD