Commit f51758c

[None][fix] Correct virtual memory allocation alignment
Signed-off-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
1 parent b10137f commit f51758c

3 files changed (+57, -11 lines)

cpp/include/tensorrt_llm/runtime/virtualMemory.h

Lines changed: 36 additions & 5 deletions
@@ -25,6 +25,7 @@
 #include <cuda.h>
 #include <map>
 #include <mutex>
+#include <numeric>
 #include <unistd.h>
 #include <utility>

@@ -214,7 +215,7 @@ struct LocalCreator : CUDAVirtualMemoryChunk::Creator
     CUmemGenericAllocationHandle create() override
     {
         CUmemGenericAllocationHandle handle{};
-        TLLM_CU_CHECK(cuMemCreate(&handle, mSize, &mProp, 0));
+        TLLM_CU_CHECK_WITH_INFO(cuMemCreate(&handle, mSize, &mProp, 0), "allocating %zu bytes of memory", mSize);
         if constexpr (count)
         {
             MemoryCounters::getInstance().allocate(
@@ -466,7 +467,7 @@ class CudaVirtualMemoryAllocator
     CudaVirtualMemoryManager& mManager;
     std::string mTag;
     CudaStreamPtr mBackStream;
-    std::size_t mPageSize;
+    std::atomic<std::size_t> mAlignment;
     RestoreMode mMode;
     bool mBackground{};

@@ -487,14 +488,44 @@ class CudaVirtualMemoryAllocator
         : mManager(manager)
         , mTag(std::move(tag))
         , mBackStream(std::move(backStream))
-        , mPageSize(getpagesize())
+        , mAlignment(0)
         , mMode(mode)
     {
     }

-    [[nodiscard]] std::size_t pageAligned(std::size_t n) const noexcept
+    [[nodiscard]] std::size_t pageAligned(std::size_t n, int device = 0)
     {
-        return (n + mPageSize - 1) & ~(mPageSize - 1);
+        // Lazy loading the alignment, since CUDA driver may yet to be initialized when Configuration is
+        // constructed.
+        constexpr std::size_t loading = std::numeric_limits<std::size_t>::max();
+        std::size_t alignment = 0;
+        if (mAlignment.compare_exchange_strong(alignment, loading, std::memory_order_relaxed))
+        {
+            std::size_t gpuAlignment = 1;
+            CUmemAllocationProp const prop{CU_MEM_ALLOCATION_TYPE_PINNED, CU_MEM_HANDLE_TYPE_NONE,
+                {
+                    CU_MEM_LOCATION_TYPE_DEVICE,
+                    device,
+                }};
+            TLLM_CU_CHECK(
+                cuMemGetAllocationGranularity(&gpuAlignment, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
+            alignment = std::lcm(getpagesize(), gpuAlignment);
+            mAlignment.store(alignment, std::memory_order_relaxed);
+        }
+        else
+        {
+            // spin wait
+            while (alignment == loading)
+            {
+#if defined(__x86_64__)
+                asm volatile("pause");
+#elif defined(__aarch64__)
+                asm volatile("yield");
+#endif
+                alignment = mAlignment.load(std::memory_order_relaxed);
+            }
+        }
+        return (n + alignment - 1) / alignment * alignment;
     }

    // Background configuration, used to indicate no virtual memory allocator is explicitly configured by the user.

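Note on the alignment change in virtualMemory.h: the old pageAligned rounded up to getpagesize() with a power-of-two bit mask, but the reservation must also satisfy the device's CUDA allocation granularity, and the lcm of the two need not be a power of two, hence the divide-and-multiply round-up. The header computes the value lazily behind an atomic compare-exchange (with a spin wait for concurrent callers) because the CUDA driver may not be initialized yet when the object is constructed. Below is a minimal standalone sketch of the same size computation, assuming cuInit has already been called; the function name and the assert-only error handling are illustrative, not the project's code:

// Sketch (not the project's implementation): compute the alignment as the lcm of the
// OS page size and the device's recommended CUDA allocation granularity, then round
// the requested size up to it. Assumes cuInit(0) has already been called.
#include <cuda.h>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <unistd.h>

std::size_t alignedSize(std::size_t n, int device)
{
    CUmemAllocationProp prop{};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_NONE;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = device;

    std::size_t gpuAlignment = 1;
    [[maybe_unused]] CUresult const rc
        = cuMemGetAllocationGranularity(&gpuAlignment, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED);
    assert(rc == CUDA_SUCCESS);

    // The size must satisfy both the CPU page size and the GPU granularity, and their
    // lcm need not be a power of two, so round up by divide-and-multiply rather than
    // with a bit mask.
    std::size_t const alignment = std::lcm(static_cast<std::size_t>(getpagesize()), gpuAlignment);
    return (n + alignment - 1) / alignment * alignment;
}

On most systems the recommended granularity (commonly 2 MiB) is already a multiple of the 4 KiB page size, so the lcm usually equals the GPU granularity.
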
cpp/tensorrt_llm/common/cudaDriverWrapper.h

Lines changed: 16 additions & 2 deletions
@@ -141,15 +141,22 @@ class CUDADriverWrapper
 };

 template <typename T>
-void checkDriver(
-    T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
+void checkDriver(T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file,
+    int const line, char const* info = nullptr)
 {
     if (result)
     {
         char const* errorName = nullptr;
         char const* errorString = nullptr;
         wrap.cuGetErrorName(result, &errorName);
         wrap.cuGetErrorString(result, &errorString);
+        if (info != nullptr)
+        {
+            throw TllmException(file, line,
+                fmtstr(
+                    "[TensorRT-LLM][ERROR] CUDA driver error in %s (%s): %s: %s.", func, info, errorName, errorString)
+                    .c_str());
+        }
         throw TllmException(file, line,
             fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s.", func, errorName, errorString).c_str());
     }
@@ -177,6 +184,13 @@ void checkDriverExitSafe(T result, char const* const func, char const* const fil
         (stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
     } while (0)

+#define TLLM_CU_CHECK_WITH_INFO(stat, info, ...) \
+    do \
+    { \
+        tensorrt_llm::common::checkDriver((stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, \
+            __FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__).c_str()); \
+    } while (0)
+
 // Avoid using CUDADriverWrapper when freeing resource, during which the global instance may already be freed.
 #define TLLM_CU_CHECK_FREE_RESOURCE(stat) \
     do \

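About the cudaDriverWrapper.h change: checkDriver now takes an optional info string and, when it is non-null, throws with the context appended to the usual message before reaching the original throw, so existing TLLM_CU_CHECK call sites behave exactly as before. A hedged usage sketch of the new TLLM_CU_CHECK_WITH_INFO macro next to the existing one follows; the call site and include path are hypothetical, not part of this commit:

// Hypothetical call site. Both macros throw TllmException on a non-zero CUresult; the
// _WITH_INFO form appends a printf-style context string, yielding a message like
// "CUDA driver error in cuMemAddressReserve(...) (reserving 2097152 bytes of address space): ...".
#include "tensorrt_llm/common/cudaDriverWrapper.h"
#include <cstddef>

void reserveExample(std::size_t bytes)
{
    CUdeviceptr reserved{};
    TLLM_CU_CHECK_WITH_INFO(
        cuMemAddressReserve(&reserved, bytes, 0, {}, 0), "reserving %zu bytes of address space", bytes);
    // The pre-existing macro, unchanged by this commit: the message names only the failing expression.
    TLLM_CU_CHECK(cuMemAddressFree(reserved, bytes));
}
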
cpp/tensorrt_llm/runtime/virtualMemory.cpp

Lines changed: 5 additions & 4 deletions
@@ -339,11 +339,12 @@ static void* deviceptr_cast(CUdeviceptr ptr)
 void CudaVirtualMemoryAllocator::allocate(Pointer* ptr, std::size_t n, int device) const
 {
     CUdeviceptr address{};
-    std::size_t const pageAlignedSize = mConfig->pageAligned(n);
-    TLLM_CU_CHECK(cuMemAddressReserve(&address, pageAlignedSize, 0, {}, 0));
+    std::size_t const pageAlignedSize = mConfig->pageAligned(n, device);
+    TLLM_CU_CHECK_WITH_INFO(cuMemAddressReserve(&address, pageAlignedSize, 0, {}, 0),
+        "allocating %zu bytes of address space", pageAlignedSize);

     CUDAVirtualMemoryChunk::Configurators configurators;
-    configurators.push_back(std::make_unique<UnicastConfigurator>(address, n,
+    configurators.push_back(std::make_unique<UnicastConfigurator>(address, pageAlignedSize,
         CUmemAccessDesc{{
             CU_MEM_LOCATION_TYPE_DEVICE,
             device,
@@ -372,7 +373,7 @@ void CudaVirtualMemoryAllocator::allocate(Pointer* ptr, std::size_t n, int devic
             CU_MEM_LOCATION_TYPE_DEVICE,
             device,
         }},
-        n),
+        pageAlignedSize),
         std::move(configurators));

     *ptr = deviceptr_cast(address);

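For context on the virtualMemory.cpp change: the CUDA virtual memory APIs expect the reserved, created, and mapped sizes to all be multiples of the allocation granularity, so the configurators (which presumably map the backing allocation and grant access over the reserved range) must be given pageAlignedSize rather than the raw request n; mixing the two is what this commit corrects. Below is a minimal standalone sketch of that reserve/create/map/set-access sequence using a single aligned size; the helper names and reduced error handling are illustrative, not TensorRT-LLM code:

#include <cuda.h>
#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Illustrative check helper (not TensorRT-LLM code): abort on any driver error.
static void cuCheck(CUresult rc, char const* what)
{
    if (rc != CUDA_SUCCESS)
    {
        std::fprintf(stderr, "%s failed with CUresult %d\n", what, static_cast<int>(rc));
        std::exit(1);
    }
}

// Illustrative helper: reserve, back, and map `alignedSize` bytes on `device`.
// `alignedSize` must already be a multiple of the allocation granularity; every
// driver call below uses that same size, which is the invariant this commit restores.
CUdeviceptr mapAligned(std::size_t alignedSize, int device)
{
    CUmemAllocationProp prop{};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = device;

    CUdeviceptr address{};
    cuCheck(cuMemAddressReserve(&address, alignedSize, 0, 0, 0), "cuMemAddressReserve");

    CUmemGenericAllocationHandle handle{};
    cuCheck(cuMemCreate(&handle, alignedSize, &prop, 0), "cuMemCreate");
    cuCheck(cuMemMap(address, alignedSize, 0, handle, 0), "cuMemMap");

    CUmemAccessDesc access{};
    access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access.location.id = device;
    access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    cuCheck(cuMemSetAccess(address, alignedSize, &access, 1), "cuMemSetAccess");
    return address;
}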