Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion crates/cuda_std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
#![allow(internal_features)]
#![cfg_attr(
target_os = "cuda",
feature(alloc_error_handler, asm_experimental_arch, link_llvm_intrinsics)
feature(
alloc_error_handler,
asm_experimental_arch,
link_llvm_intrinsics,
stdarch_nvptx
)
)]

extern crate alloc;
Expand Down
111 changes: 36 additions & 75 deletions crates/cuda_std/src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,6 @@ use glam::{UVec2, UVec3};
// different calling conventions dont exist in nvptx, so we just use C as a placeholder.
extern "C" {
// defined in libintrinsics.ll
fn __nvvm_thread_idx_x() -> u32;
fn __nvvm_thread_idx_y() -> u32;
fn __nvvm_thread_idx_z() -> u32;

fn __nvvm_block_dim_x() -> u32;
fn __nvvm_block_dim_y() -> u32;
fn __nvvm_block_dim_z() -> u32;

fn __nvvm_block_idx_x() -> u32;
fn __nvvm_block_idx_y() -> u32;
fn __nvvm_block_idx_z() -> u32;

fn __nvvm_grid_dim_x() -> u32;
fn __nvvm_grid_dim_y() -> u32;
fn __nvvm_grid_dim_z() -> u32;

fn __nvvm_warp_size() -> u32;

fn __nvvm_block_barrier();
Expand All @@ -89,26 +73,15 @@ extern "C" {
}

#[cfg(target_os = "cuda")]
macro_rules! inbounds {
// the bounds were taken mostly from the cuda C++ programming guide, i also
// double-checked with what cuda clang does by checking its emitted llvm ir's scalar metadata
($func_name:ident, $bound:expr) => {{
let val = unsafe { $func_name() };
if val > $bound {
// SAFETY: this condition is declared unreachable by compute capability max bound
macro_rules! in_range {
// The bounds were taken mostly from the cuda C++ programming guide. I also
// double-checked with what cuda clang does by checking its emitted llvm ir's scalar metadata.
($func_name:path, $range:expr) => {{
let val = unsafe { $func_name() as u32 };
if !$range.contains(&val) {
// SAFETY: this condition is declared unreachable by compute capability max bound.
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities
// we do this to potentially allow for better optimizations by LLVM
unsafe { core::hint::unreachable_unchecked() }
} else {
val
}
}};
($func_name:ident, $lower_bound:expr, $upper_bound:expr) => {{
let val = unsafe { $func_name() };
if !($lower_bound..=$upper_bound).contains(&val) {
// SAFETY: this condition is declared unreachable by compute capability max bound
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities
// we do this to potentially allow for better optimizations by LLVM
// We do this to potentially allow for better optimizations by LLVM.
unsafe { core::hint::unreachable_unchecked() }
} else {
val
Expand All @@ -119,127 +92,115 @@ macro_rules! inbounds {
#[gpu_only]
#[inline(always)]
pub fn thread_idx_x() -> u32 {
inbounds!(__nvvm_thread_idx_x, 1024)
// The range is derived from the `block_idx_x` range.
in_range!(core::arch::nvptx::_thread_idx_x, 0..1024)
}

#[gpu_only]
#[inline(always)]
pub fn thread_idx_y() -> u32 {
inbounds!(__nvvm_thread_idx_y, 1024)
// The range is derived from the `block_idx_y` range.
in_range!(core::arch::nvptx::_thread_idx_y, 0..1024)
}

#[gpu_only]
#[inline(always)]
pub fn thread_idx_z() -> u32 {
inbounds!(__nvvm_thread_idx_z, 64)
// The range is derived from the `block_idx_z` range.
in_range!(core::arch::nvptx::_thread_idx_z, 0..64)
}

#[gpu_only]
#[inline(always)]
pub fn block_idx_x() -> u32 {
inbounds!(__nvvm_block_idx_x, 2147483647)
// The range is derived from the `grid_idx_x` range.
in_range!(core::arch::nvptx::_block_idx_x, 0..2147483647)
}

#[gpu_only]
#[inline(always)]
pub fn block_idx_y() -> u32 {
inbounds!(__nvvm_block_idx_y, 65535)
// The range is derived from the `grid_idx_y` range.
in_range!(core::arch::nvptx::_block_idx_y, 0..65535)
}

#[gpu_only]
#[inline(always)]
pub fn block_idx_z() -> u32 {
inbounds!(__nvvm_block_idx_z, 65535)
// The range is derived from the `grid_idx_z` range.
in_range!(core::arch::nvptx::_block_idx_z, 0..65535)
}

#[gpu_only]
#[inline(always)]
pub fn block_dim_x() -> u32 {
inbounds!(__nvvm_block_dim_x, 1, 1025)
// CUDA Compute Capabilities: "Maximum x- or y-dimensionality of a block" is 1024.
in_range!(core::arch::nvptx::_block_dim_x, 1..=1024)
}

#[gpu_only]
#[inline(always)]
pub fn block_dim_y() -> u32 {
inbounds!(__nvvm_block_dim_y, 1, 1025)
// CUDA Compute Capabilities: "Maximum x- or y-dimensionality of a block" is 1024.
in_range!(core::arch::nvptx::_block_dim_y, 1..=1024)
}

#[gpu_only]
#[inline(always)]
pub fn block_dim_z() -> u32 {
inbounds!(__nvvm_block_dim_z, 1, 65)
// CUDA Compute Capabilities: "Maximum z-dimension of a block" is 64.
in_range!(core::arch::nvptx::_block_dim_z, 1..=64)
}

#[gpu_only]
#[inline(always)]
pub fn grid_dim_x() -> u32 {
inbounds!(__nvvm_grid_dim_x, 1, 2147483648)
// CUDA Compute Capabilities: "Maximum x-dimension of a grid of thread blocks" is 2^32 - 1.
in_range!(core::arch::nvptx::_grid_dim_x, 1..=2147483647)
}

#[gpu_only]
#[inline(always)]
pub fn grid_dim_y() -> u32 {
inbounds!(__nvvm_grid_dim_y, 1, 65536)
// CUDA Compute Capabilities: "Maximum y- or z-dimension of a grid of thread blocks" is 65535.
in_range!(core::arch::nvptx::_grid_dim_y, 1..=65535)
}

#[gpu_only]
#[inline(always)]
pub fn grid_dim_z() -> u32 {
inbounds!(__nvvm_grid_dim_z, 1, 65536)
// CUDA Compute Capabilities: "Maximum y- or z-dimension of a grid of thread blocks" is 65535.
in_range!(core::arch::nvptx::_grid_dim_z, 1..=65535)
}

/// Gets the 3d index of the thread currently executing the kernel.
#[gpu_only]
#[inline(always)]
pub fn thread_idx() -> UVec3 {
unsafe {
UVec3::new(
__nvvm_thread_idx_x(),
__nvvm_thread_idx_y(),
__nvvm_thread_idx_z(),
)
}
UVec3::new(thread_idx_x(), thread_idx_y(), thread_idx_z())
}

/// Gets the 3d index of the block that the thread currently executing the kernel is located in.
#[gpu_only]
#[inline(always)]
pub fn block_idx() -> UVec3 {
unsafe {
UVec3::new(
__nvvm_block_idx_x(),
__nvvm_block_idx_y(),
__nvvm_block_idx_z(),
)
}
UVec3::new(block_idx_x(), block_idx_y(), block_idx_z())
}

/// Gets the 3d layout of the thread blocks executing this kernel. In other words,
/// how many threads exist in each thread block in every direction.
#[gpu_only]
#[inline(always)]
pub fn block_dim() -> UVec3 {
unsafe {
UVec3::new(
__nvvm_block_dim_x(),
__nvvm_block_dim_y(),
__nvvm_block_dim_z(),
)
}
UVec3::new(block_dim_x(), block_dim_y(), block_dim_z())
}

/// Gets the 3d layout of the block grids executing this kernel. In other words,
/// how many thread blocks exist in each grid in every direction.
#[gpu_only]
#[inline(always)]
pub fn grid_dim() -> UVec3 {
unsafe {
UVec3::new(
__nvvm_grid_dim_x(),
__nvvm_grid_dim_y(),
__nvvm_grid_dim_z(),
)
}
UVec3::new(grid_dim_x(), grid_dim_y(), grid_dim_z())
}

/// Gets the overall thread index, accounting for 1d/2d/3d block/grid dimensions. This
Expand Down
Binary file modified crates/rustc_codegen_nvvm/libintrinsics.bc
Binary file not shown.
92 changes: 0 additions & 92 deletions crates/rustc_codegen_nvvm/libintrinsics.ll
Original file line number Diff line number Diff line change
Expand Up @@ -8,86 +8,6 @@ source_filename = "libintrinsics"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

; thread ----

define i32 @__nvvm_thread_idx_x() #0 {
start:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
ret i32 %0
}

define i32 @__nvvm_thread_idx_y() #0 {
start:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
ret i32 %0
}

define i32 @__nvvm_thread_idx_z() #0 {
start:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
ret i32 %0
}

; block dimension ----

define i32 @__nvvm_block_dim_x() #0 {
start:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
ret i32 %0
}

define i32 @__nvvm_block_dim_y() #0 {
start:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
ret i32 %0
}

define i32 @__nvvm_block_dim_z() #0 {
start:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
ret i32 %0
}

; block idx ----

define i32 @__nvvm_block_idx_x() #0 {
start:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
ret i32 %0
}

define i32 @__nvvm_block_idx_y() #0 {
start:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
ret i32 %0
}

define i32 @__nvvm_block_idx_z() #0 {
start:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
ret i32 %0
}

; grid dimension ----

define i32 @__nvvm_grid_dim_x() #0 {
start:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
ret i32 %0
}

define i32 @__nvvm_grid_dim_y() #0 {
start:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
ret i32 %0
}

define i32 @__nvvm_grid_dim_z() #0 {
start:
%0 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
ret i32 %0
}

; warp ----

define i32 @__nvvm_warp_size() #0 {
Expand All @@ -96,18 +16,6 @@ start:
ret i32 %0
}

declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
declare i32 @llvm.nvvm.read.ptx.sreg.tid.y()
declare i32 @llvm.nvvm.read.ptx.sreg.tid.z()
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
declare i32 @llvm.nvvm.read.ptx.sreg.warpsize()

; other ----
Expand Down
Loading