Skip to content

Commit ffacd85

Browse files
nnethercoteLegNeato
authored andcommitted
Fix NvvmArch::all_target_features bugs.
It now does a single filter pass over the enum variants, which is simpler and fixes the sorting issue and the incorrect 'f' suffix results. I removed some comments in the `nvvm_arch_all_target_features` test, because they were low-value. There are now better comments within `all_target_features` that explain what's happening. I also remove the comment about PTX forward-compatibility. It was correct but confusing. This function answers the question "what features are available if I'm targeting a particular NvvmArch?" (backwards compatibility). That comment explained "what GPU CCs will this run on?" (forward compatibility). Also update the relevant section in the guide, where the 'f' details were incorrect. And make the terminology more consistent.
1 parent 2635863 commit ffacd85

File tree

2 files changed

+134
-99
lines changed

2 files changed

+134
-99
lines changed

crates/nvvm/src/lib.rs

Lines changed: 121 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -433,73 +433,62 @@ impl NvvmArch {
433433
}
434434
}
435435

436-
/// Get all target features up to and including this architecture.
436+
/// Gets all target features up to and including this architecture. This effectively answers
437+
/// the question "for a given compilation target, what architectural features can be used?"
437438
///
438-
/// # PTX Forward-Compatibility Rules (per NVIDIA documentation):
439+
/// # Examples
439440
///
440-
/// - **No suffix** (compute_XX): PTX is forward-compatible across all future architectures.
441-
/// Example: compute_70 runs on CC 7.0, 8.x, 9.x, 10.x, 12.x, and all future GPUs.
441+
/// ```
442+
/// # use nvvm::NvvmArch;
443+
/// let features = NvvmArch::Compute53.all_target_features();
444+
/// assert_eq!(
445+
/// features,
446+
/// vec!["compute_35", "compute_37", "compute_50", "compute_52", "compute_53"]
447+
/// );
448+
/// ```
442449
///
443-
/// - **Family-specific 'f' suffix** (compute_XXf): Forward-compatible within the same major
444-
/// version family. Supports devices with same major CC and equal or higher minor CC.
445-
/// Example: compute_100f runs on CC 10.0, 10.3, and future 10.x devices, but NOT on 11.x.
446-
///
447-
/// - **Architecture-specific 'a' suffix** (compute_XXa): The code only runs on GPUs of that
448-
/// specific CC and no others. No forward or backward compatibility whatsoever.
449-
/// These features are primarily related to Tensor Core programming.
450-
/// Example: compute_100a ONLY runs on CC 10.0, not on 10.3, 10.1, 9.0, or any other version.
450+
/// # External resources
451451
///
452452
/// For more details on family and architecture-specific features, see:
453453
/// <https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/>
454454
pub fn all_target_features(&self) -> Vec<String> {
455-
let mut features: Vec<String> = if self.is_architecture_variant() {
456-
// 'a' variants: include all available instructions for the architecture
457-
// This means: all base variants up to same version, all 'f' variants with same major and <= minor, plus itself
458-
let base_features: Vec<String> = NvvmArch::iter()
459-
.filter(|arch| {
460-
arch.is_base_variant() && arch.capability_value() <= self.capability_value()
461-
})
462-
.map(|arch| arch.target_feature())
463-
.collect();
464-
465-
let family_features: Vec<String> = NvvmArch::iter()
466-
.filter(|arch| {
467-
arch.is_family_variant()
468-
&& arch.major_version() == self.major_version()
469-
&& arch.minor_version() <= self.minor_version()
470-
})
471-
.map(|arch| arch.target_feature())
472-
.collect();
455+
// All lower-or-equal baseline features are included.
456+
let included_baseline = |arch: &NvvmArch| {
457+
arch.is_base_variant() && arch.capability_value() <= self.capability_value()
458+
};
473459

474-
base_features
475-
.into_iter()
476-
.chain(family_features)
477-
.chain(std::iter::once(self.target_feature()))
460+
// All lower-or-equal-with-same-major-version family features are included.
461+
let included_family = |arch: &NvvmArch| {
462+
arch.is_family_variant()
463+
&& arch.major_version() == self.major_version()
464+
&& arch.minor_version() <= self.minor_version()
465+
};
466+
467+
if self.is_architecture_variant() {
468+
// Architecture-specific ('a' suffix) features include:
469+
// - all lower-or-equal baseline features
470+
// - all lower-or-equal-with-same-major-version family features
471+
// - itself
472+
NvvmArch::iter()
473+
.filter(|arch| included_baseline(arch) || included_family(arch) || arch == self)
474+
.map(|arch| arch.target_feature())
478475
.collect()
479476
} else if self.is_family_variant() {
480-
// 'f' variants: same major version with equal or higher minor version
477+
// Family-specific ('f' suffix) features include:
478+
// - all lower-or-equal baseline features
479+
// - all lower-or-equal-with-same-major-version family features
481480
NvvmArch::iter()
482-
.filter(|arch| {
483-
// Include base variants with same major and >= minor version
484-
arch.is_base_variant()
485-
&& arch.major_version() == self.major_version()
486-
&& arch.minor_version() >= self.minor_version()
487-
})
481+
.filter(|arch| included_baseline(arch) || included_family(arch))
488482
.map(|arch| arch.target_feature())
489-
.chain(std::iter::once(self.target_feature())) // Add the 'f' variant itself
490483
.collect()
491484
} else {
492-
// Base variants: all base architectures from lower or equal versions
485+
// Baseline (no suffix) features include:
486+
// - all lower-or-equal baseline features
493487
NvvmArch::iter()
494-
.filter(|arch| {
495-
arch.is_base_variant() && arch.capability_value() <= self.capability_value()
496-
})
488+
.filter(included_baseline)
497489
.map(|arch| arch.target_feature())
498490
.collect()
499-
};
500-
501-
features.sort();
502-
features
491+
}
503492
}
504493

505494
/// Create an iterator over all architectures from Compute35 up to and including this one
@@ -780,19 +769,16 @@ mod tests {
780769
fn nvvm_arch_all_target_features() {
781770
use crate::NvvmArch;
782771

783-
// Compute35 only includes itself
784772
assert_eq!(
785773
NvvmArch::Compute35.all_target_features(),
786774
vec!["compute_35"]
787775
);
788776

789-
// Compute50 includes all lower base capabilities
790777
assert_eq!(
791778
NvvmArch::Compute50.all_target_features(),
792779
vec!["compute_35", "compute_37", "compute_50"],
793780
);
794781

795-
// Compute61 includes all lower base capabilities
796782
assert_eq!(
797783
NvvmArch::Compute61.all_target_features(),
798784
vec![
@@ -806,7 +792,6 @@ mod tests {
806792
]
807793
);
808794

809-
// Compute70 includes all lower base capabilities
810795
assert_eq!(
811796
NvvmArch::Compute70.all_target_features(),
812797
vec![
@@ -822,7 +807,6 @@ mod tests {
822807
]
823808
);
824809

825-
// Compute90 includes lower base capabilities
826810
let compute90_features = NvvmArch::Compute90.all_target_features();
827811
assert_eq!(
828812
compute90_features,
@@ -846,9 +830,6 @@ mod tests {
846830
]
847831
);
848832

849-
// Test 'a' variant - includes all available instructions for the architecture.
850-
// This means: all base variants up to same version, no 'f' variants (90 has none), and the
851-
// 'a' variant.
852833
assert_eq!(
853834
NvvmArch::Compute90a.all_target_features(),
854835
vec![
@@ -872,14 +853,9 @@ mod tests {
872853
]
873854
);
874855

875-
// Test compute100a - should include base variants up to 100, and 100f, and itself,
876-
// but NOT 101f or 103f (higher minor).
877856
assert_eq!(
878857
NvvmArch::Compute100a.all_target_features(),
879858
vec![
880-
"compute_100",
881-
"compute_100a",
882-
"compute_100f",
883859
"compute_35",
884860
"compute_37",
885861
"compute_50",
@@ -896,26 +872,39 @@ mod tests {
896872
"compute_87",
897873
"compute_89",
898874
"compute_90",
875+
"compute_100",
876+
"compute_100f",
877+
"compute_100a",
899878
]
900879
);
901880

902-
// Test 'f' variant with 100f
903881
assert_eq!(
904882
NvvmArch::Compute100f.all_target_features(),
905-
// FIXME: this is wrong
906-
vec!["compute_100", "compute_100f", "compute_101", "compute_103"]
883+
vec![
884+
"compute_35",
885+
"compute_37",
886+
"compute_50",
887+
"compute_52",
888+
"compute_53",
889+
"compute_60",
890+
"compute_61",
891+
"compute_62",
892+
"compute_70",
893+
"compute_72",
894+
"compute_75",
895+
"compute_80",
896+
"compute_86",
897+
"compute_87",
898+
"compute_89",
899+
"compute_90",
900+
"compute_100",
901+
"compute_100f",
902+
]
907903
);
908904

909-
// Test compute101a - should include base variants up to 101, and 100f and 101f, and
910-
// itself, but not 103f (higher minor)
911905
assert_eq!(
912906
NvvmArch::Compute101a.all_target_features(),
913907
vec![
914-
"compute_100",
915-
"compute_100f",
916-
"compute_101",
917-
"compute_101a",
918-
"compute_101f",
919908
"compute_35",
920909
"compute_37",
921910
"compute_50",
@@ -932,22 +921,43 @@ mod tests {
932921
"compute_87",
933922
"compute_89",
934923
"compute_90",
924+
"compute_100",
925+
"compute_100f",
926+
"compute_101",
927+
"compute_101f",
928+
"compute_101a",
935929
]
936930
);
937931

938-
// Test 'f' variant with 101f
939932
assert_eq!(
940933
NvvmArch::Compute101f.all_target_features(),
941-
vec!["compute_101", "compute_101f", "compute_103"],
934+
vec![
935+
"compute_35",
936+
"compute_37",
937+
"compute_50",
938+
"compute_52",
939+
"compute_53",
940+
"compute_60",
941+
"compute_61",
942+
"compute_62",
943+
"compute_70",
944+
"compute_72",
945+
"compute_75",
946+
"compute_80",
947+
"compute_86",
948+
"compute_87",
949+
"compute_89",
950+
"compute_90",
951+
"compute_100",
952+
"compute_100f",
953+
"compute_101",
954+
"compute_101f",
955+
]
942956
);
943957

944958
assert_eq!(
945959
NvvmArch::Compute120.all_target_features(),
946960
vec![
947-
"compute_100",
948-
"compute_101",
949-
"compute_103",
950-
"compute_120",
951961
"compute_35",
952962
"compute_37",
953963
"compute_50",
@@ -964,24 +974,43 @@ mod tests {
964974
"compute_87",
965975
"compute_89",
966976
"compute_90",
977+
"compute_100",
978+
"compute_101",
979+
"compute_103",
980+
"compute_120",
967981
]
968982
);
969983

970984
assert_eq!(
971985
NvvmArch::Compute120f.all_target_features(),
972-
// FIXME: this is wrong
973-
vec!["compute_120", "compute_120f", "compute_121"]
974-
);
975-
976-
assert_eq!(
977-
NvvmArch::Compute120a.all_target_features(),
978986
vec![
987+
"compute_35",
988+
"compute_37",
989+
"compute_50",
990+
"compute_52",
991+
"compute_53",
992+
"compute_60",
993+
"compute_61",
994+
"compute_62",
995+
"compute_70",
996+
"compute_72",
997+
"compute_75",
998+
"compute_80",
999+
"compute_86",
1000+
"compute_87",
1001+
"compute_89",
1002+
"compute_90",
9791003
"compute_100",
9801004
"compute_101",
9811005
"compute_103",
9821006
"compute_120",
983-
"compute_120a",
9841007
"compute_120f",
1008+
]
1009+
);
1010+
1011+
assert_eq!(
1012+
NvvmArch::Compute120a.all_target_features(),
1013+
vec![
9851014
"compute_35",
9861015
"compute_37",
9871016
"compute_50",
@@ -998,6 +1027,12 @@ mod tests {
9981027
"compute_87",
9991028
"compute_89",
10001029
"compute_90",
1030+
"compute_100",
1031+
"compute_101",
1032+
"compute_103",
1033+
"compute_120",
1034+
"compute_120f",
1035+
"compute_120a",
10011036
]
10021037
);
10031038
}

guide/src/guide/compute_capabilities.md

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ CudaBuilder::new("kernels")
7474
.unwrap();
7575

7676
// In your kernel code:
77-
#[cfg(target_feature = "compute_60")] // ✓ Pass (older compute capability)
78-
#[cfg(target_feature = "compute_70")] // ✓ Pass (current compute capability)
79-
#[cfg(target_feature = "compute_80")] // ✗ Fail (newer compute capability)
77+
#[cfg(target_feature = "compute_60")] // ✓ Pass (lower base variant)
78+
#[cfg(target_feature = "compute_70")] // ✓ Pass (this base variant))
79+
#[cfg(target_feature = "compute_80")] // ✗ Fail (higher base variant)
8080
```
8181

8282
### Family Suffix ('f')
@@ -99,13 +99,13 @@ CudaBuilder::new("kernels")
9999
.unwrap();
100100

101101
// In your kernel code:
102-
#[cfg(target_feature = "compute_100")] // ✗ Fail (10.0 < 10.1)
103-
#[cfg(target_feature = "compute_101")] // ✓ Pass (equal major, equal minor)
104-
#[cfg(target_feature = "compute_103")] // ✓ Pass (equal major, greater minor)
102+
#[cfg(target_feature = "compute_90")] // ✓ Pass (lower base variant)
103+
#[cfg(target_feature = "compute_100")] // ✓ Pass (lower base variant)
104+
#[cfg(target_feature = "compute_100f")] // ✓ Pass (lower 'f' variant)
105+
#[cfg(target_feature = "compute_101")] // ✓ Pass (this base variant)
105106
#[cfg(target_feature = "compute_101f")] // ✓ Pass (the 'f' variant itself)
106-
#[cfg(target_feature = "compute_100f")] // ✗ Fail (other 'f' variant)
107-
#[cfg(target_feature = "compute_90")] // ✗ Fail (different major)
108-
#[cfg(target_feature = "compute_110")] // ✗ Fail (different major)
107+
#[cfg(target_feature = "compute_103")] // ✗ Fail (higher base variant)
108+
#[cfg(target_feature = "compute_110")] // ✗ Fail (higher base variant)
109109
```
110110

111111
### Architecture Suffix ('a')
@@ -130,12 +130,12 @@ CudaBuilder::new("kernels")
130130
.unwrap();
131131

132132
// In your kernel code:
133-
#[cfg(target_feature = "compute_100a")] // ✓ Pass (the 'a' variant itself)
134-
#[cfg(target_feature = "compute_100")] // ✓ Pass (base variant)
135133
#[cfg(target_feature = "compute_90")] // ✓ Pass (lower base variant)
134+
#[cfg(target_feature = "compute_100")] // ✓ Pass (base variant)
136135
#[cfg(target_feature = "compute_100f")] // ✓ Pass (family variant with same major/minor)
137-
#[cfg(target_feature = "compute_101f")] // ✗ Fail (family variant with higher minor)
138-
#[cfg(target_feature = "compute_110")] // ✗ Fail (higher major version)
136+
#[cfg(target_feature = "compute_100a")] // ✓ Pass (the 'a' variant itself)
137+
#[cfg(target_feature = "compute_101f")] // ✗ Fail (higher family variant)
138+
#[cfg(target_feature = "compute_110")] // ✗ Fail (higher base variant)
139139
```
140140

141141
Note: While the 'a' variant enables all these features during compilation (allowing you to use all available instructions), the generated PTX code will still only run on the exact GPU architecture specified.

0 commit comments

Comments
 (0)