Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
496 changes: 496 additions & 0 deletions examples/distributed_inference/llama3_model.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions examples/distributed_inference/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,20 @@ def parallel_rotary_block(rotary_block, tp_mesh):
"wk": ColwiseParallel(),
"wo": RowwiseParallel(output_layouts=Shard(0)),
}
rotary_block.n_parallel = 1 # this is for single GPU, to do remove this hardcode
rotary_block.n_parallel = tp_mesh.size()

parallelize_module(rotary_block, tp_mesh, plan)


class RotaryAttention(nn.Module):
def __init__(self, dim: int, seq_len: int):
def __init__(self, dim: int, seq_len: int, n_parallel: int = 1):
super().__init__()
self.dim = dim
self.wq = nn.Linear(dim, dim)
self.wk = nn.Linear(dim, dim)
self.wo = nn.Linear(dim, dim)
self.seq_len = seq_len
self.n_parallel = 1
self.n_parallel = n_parallel
self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
self.init_weights()

Expand Down
31 changes: 3 additions & 28 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,8 @@
from torch.distributed._tensor.device_mesh import init_device_mesh


def find_repo_root(max_depth=10):
dir_path = os.path.dirname(os.path.realpath(__file__))
for i in range(max_depth):
files = os.listdir(dir_path)
if "MODULE.bazel" in files:
return dir_path
else:
dir_path = os.path.dirname(dir_path)

raise RuntimeError("Could not find repo root")


def initialize_logger(rank, logger_file_name):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
# this is kept at the application level, when mpirun is used to run the application
def initialize_distributed_env(rank=0, world_size=1, port=29500):
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
Expand All @@ -50,9 +29,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["TRTLLM_PLUGINS_PATH"] = (
find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
)

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)
Expand All @@ -66,13 +42,12 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
logger = initialize_logger(rank, logger_file_name)
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank, logger
return device_mesh, world_size, rank


def cleanup_distributed_env():
Expand Down
72 changes: 72 additions & 0 deletions examples/distributed_inference/tensor_parallel_llama3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Taken and modified pytorch lightening
# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
import logging
import os
import time

import torch
import torch.distributed as dist
from llama3_model import ModelArgs, ParallelTransformer
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
initialize_distributed_env,
)
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)

if not dist.is_initialized():
initialize_distributed_env()

import torch_tensorrt
from torch_tensorrt.dynamo.distributed.utils import (
get_tensor_parallel_device_mesh,
initialize_distributed_logger,
)

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_distributed_logger(_rank, "tensor_parallel_llama3")

logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
_world_size % 2 == 0
), f"TP examples require even number of GPUs, but got {_world_size} gpus"

model_args = ModelArgs(
vocab_size=32000,
dim=1024,
n_layers=4,
n_heads=8,
rope_theta=500000.0,
n_kv_heads=8,
device="cuda",
)

with torch.no_grad():
model = ParallelTransformer(model_args, device_mesh)
torch.manual_seed(0)
inp = torch.randint(32000, (8, 256), device="cuda")
python_result = model(inp)
torch_tensorrt.runtime.set_multi_device_safe_mode(True)
model = torch.compile(
model,
fullgraph=True,
backend="torch_tensorrt",
options={
"use_python_runtime": True,
"use_distributed_mode_trace": True,
"debug": True,
},
dynamic=False,
)

start = time.time()
output = model(inp)
end = time.time()
logger.info(f"Compilation time is {end-start}")
assert (python_result - output).std() < 0.01, "Compilation result is not correct."

cleanup_distributed_env()
19 changes: 14 additions & 5 deletions examples/distributed_inference/tensor_parallel_rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,30 @@
import time

import torch
import torch_tensorrt
from rotary_embedding import RotaryAttention, parallel_rotary_block
import torch.distributed as dist
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
initialize_distributed_env,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_rotary_embedding"
if not dist.is_initialized():
initialize_distributed_env()

import torch_tensorrt
from torch_tensorrt.dynamo.distributed.utils import (
get_tensor_parallel_device_mesh,
initialize_distributed_logger,
)

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_distributed_logger(_rank, "tensor_parallel_rotary_embedding")

from rotary_embedding import RotaryAttention, parallel_rotary_block

"""
This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
Command to run with single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
Command to run with 2 GPUs: mpirun -n 2 --allow-run-as-root python tensor_parallel_rotary_embedding.py
"""

BATCH = 2
Expand All @@ -37,7 +46,7 @@
DIM = 128

with torch.no_grad():
model = RotaryAttention(DIM, SEQ_LEN)
model = RotaryAttention(DIM, SEQ_LEN, device_mesh.size())
parallel_rotary_block(model, device_mesh)
device = torch.device("cuda", device_mesh.get_rank())
model.to(device)
Expand Down
15 changes: 11 additions & 4 deletions examples/distributed_inference/tensor_parallel_simple_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,29 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
initialize_distributed_env,
)

if not dist.is_initialized():
initialize_distributed_env()
import torch_tensorrt
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
from torch_tensorrt.dynamo.distributed.utils import (
get_tensor_parallel_device_mesh,
initialize_distributed_logger,
)

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_distributed_logger(_rank, "tensor_parallel_simple_example")


"""
This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""
Expand Down
76 changes: 45 additions & 31 deletions py/torch_tensorrt/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,55 @@ def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path:
)


def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None:
# this will not be encountered in case of platforms not supporting torch distributed/nccl/TRT-LLM
from torch.distributed import barrier, get_rank, is_initialized

if not is_initialized():
# Single process case, just unzip
is_master = True
else:
is_master = get_rank() == 0 # only rank 0 does the unzip

if is_master:
try:
import zipfile
except ImportError as e:
raise ImportError(
"zipfile module is required but not found. Please install zipfile"
)
try:
with zipfile.ZipFile(wheel_path) as zip_ref:
zip_ref.extractall(extract_dir)
logger.debug(f"Extracted wheel to {extract_dir}")

except FileNotFoundError as e:
# This should capture the errors in the download failure above
logger.error(f"Wheel file not found at {wheel_path}: {e}")
raise RuntimeError(
f"Failed to find downloaded wheel file at {wheel_path}"
) from e
except zipfile.BadZipFile as e:
logger.error(f"Invalid or corrupted wheel file: {e}")
raise RuntimeError(
"Downloaded wheel file is corrupted or not a valid zip archive"
) from e
except Exception as e:
logger.error(f"Unexpected error while extracting wheel: {e}")
raise RuntimeError(
"Unexpected error during extraction of TensorRT-LLM wheel"
) from e

# Make sure others wait until unzip is done
if is_initialized():
barrier()


def download_and_get_plugin_lib_path() -> Optional[str]:
"""
Returns the path to the TensorRT‑LLM shared library, downloading and extracting if necessary.

Args:
platform (str): Platform identifier (e.g., 'linux_x86_64')

Returns:
Optional[str]: Path to shared library or None if operation fails.
"""
Expand Down Expand Up @@ -194,32 +236,7 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
except OSError as e:
logger.error(f"Local file write error: {e}")

try:
import zipfile
except ImportError as e:
raise ImportError(
"zipfile module is required but not found. Please install zipfile"
)
try:
with zipfile.ZipFile(wheel_path) as zip_ref:
zip_ref.extractall(extract_dir)
logger.debug(f"Extracted wheel to {extract_dir}")
except FileNotFoundError as e:
# This should capture the errors in the download failure above
logger.error(f"Wheel file not found at {wheel_path}: {e}")
raise RuntimeError(
f"Failed to find downloaded wheel file at {wheel_path}"
) from e
except zipfile.BadZipFile as e:
logger.error(f"Invalid or corrupted wheel file: {e}")
raise RuntimeError(
"Downloaded wheel file is corrupted or not a valid zip archive"
) from e
except Exception as e:
logger.error(f"Unexpected error while extracting wheel: {e}")
raise RuntimeError(
"Unexpected error during extraction of TensorRT-LLM wheel"
) from e
extract_wheel_file(wheel_path, extract_dir)

try:
wheel_path.unlink(missing_ok=True)
Expand All @@ -238,10 +255,8 @@ def download_and_get_plugin_lib_path() -> Optional[str]:
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
"""
Loads and initializes the TensorRT-LLM plugin from the given shared library path.

Args:
plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.

Returns:
bool: True if successful, False otherwise.
"""
Expand Down Expand Up @@ -293,7 +308,6 @@ def load_tensorrt_llm_for_nccl() -> bool:
Attempts to load the TensorRT-LLM plugin and initialize it.
Either the env variable TRTLLM_PLUGINS_PATH can specify the path
Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it

Returns:
bool: True if the plugin was successfully loaded and initialized, False otherwise.
"""
Expand Down
9 changes: 1 addition & 8 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt._features import needs_cross_compile
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults, partitioning
from torch_tensorrt.dynamo._DryRunTracker import (
DryRunTracker,
Expand Down Expand Up @@ -287,7 +286,6 @@ def cross_compile_for_windows(
arg_inputs = [arg_inputs] # type: ignore

# Prepare torch_trt inputs
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
device = to_torch_tensorrt_device(device)
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
Expand Down Expand Up @@ -377,7 +375,6 @@ def cross_compile_for_windows(
)
trt_gm = compile_module(
gm,
trt_arg_inputs,
trt_kwarg_inputs,
settings,
)
Expand Down Expand Up @@ -623,7 +620,6 @@ def compile(
arg_inputs = [arg_inputs] # type: ignore

# Prepare torch_trt inputs
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
device = to_torch_tensorrt_device(device)
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
Expand Down Expand Up @@ -709,16 +705,13 @@ def compile(
logger.warning(
"Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
)
trt_gm = compile_module(
gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
)
trt_gm = compile_module(gm, trt_kwarg_inputs, settings, engine_cache)
return trt_gm


@fn_supports_debugger # type: ignore[misc]
def compile_module(
gm: torch.fx.GraphModule,
sample_arg_inputs: Sequence[Input],
sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
settings: CompilationSettings = CompilationSettings(),
engine_cache: Optional[BaseEngineCache] = None,
Expand Down
Loading
Loading