@@ -430,73 +430,62 @@ impl NvvmArch {
430430 }
431431 }
432432
433- /// Get all target features up to and including this architecture.
433+ /// Gets all target features up to and including this architecture. This effectively answers
434+ /// the question "for a given compilation target, what architectural features can be used?"
434435 ///
435- /// # PTX Forward-Compatibility Rules (per NVIDIA documentation):
436+ /// # Examples
436437 ///
437- /// - **No suffix** (compute_XX): PTX is forward-compatible across all future architectures.
438- /// Example: compute_70 runs on CC 7.0, 8.x, 9.x, 10.x, 12.x, and all future GPUs.
438+ /// ```
439+ /// # use nvvm::NvvmArch;
440+ /// let features = NvvmArch::Compute53.all_target_features();
441+ /// assert_eq!(
442+ /// features,
443+ /// vec!["compute_35", "compute_37", "compute_50", "compute_52", "compute_53"]
444+ /// );
445+ /// ```
439446 ///
440- /// - **Family-specific 'f' suffix** (compute_XXf): Forward-compatible within the same major
441- /// version family. Supports devices with same major CC and equal or higher minor CC.
442- /// Example: compute_100f runs on CC 10.0, 10.3, and future 10.x devices, but NOT on 11.x.
443- ///
444- /// - **Architecture-specific 'a' suffix** (compute_XXa): The code only runs on GPUs of that
445- /// specific CC and no others. No forward or backward compatibility whatsoever.
446- /// These features are primarily related to Tensor Core programming.
447- /// Example: compute_100a ONLY runs on CC 10.0, not on 10.3, 10.1, 9.0, or any other version.
447+ /// # External resources
448448 ///
449449 /// For more details on family and architecture-specific features, see:
450450 /// <https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/>
451451 pub fn all_target_features ( & self ) -> Vec < String > {
452- let mut features: Vec < String > = if self . is_architecture_variant ( ) {
453- // 'a' variants: include all available instructions for the architecture
454- // This means: all base variants up to same version, all 'f' variants with same major and <= minor, plus itself
455- let base_features: Vec < String > = NvvmArch :: iter ( )
456- . filter ( |arch| {
457- arch. is_base_variant ( ) && arch. capability_value ( ) <= self . capability_value ( )
458- } )
459- . map ( |arch| arch. target_feature ( ) )
460- . collect ( ) ;
461-
462- let family_features: Vec < String > = NvvmArch :: iter ( )
463- . filter ( |arch| {
464- arch. is_family_variant ( )
465- && arch. major_version ( ) == self . major_version ( )
466- && arch. minor_version ( ) <= self . minor_version ( )
467- } )
468- . map ( |arch| arch. target_feature ( ) )
469- . collect ( ) ;
452+ // All lower-or-equal baseline features are included.
453+ let included_baseline = |arch : & NvvmArch | {
454+ arch. is_base_variant ( ) && arch. capability_value ( ) <= self . capability_value ( )
455+ } ;
470456
471- base_features
472- . into_iter ( )
473- . chain ( family_features)
474- . chain ( std:: iter:: once ( self . target_feature ( ) ) )
457+ // All lower-or-equal-with-same-major-version family features are included.
458+ let included_family = |arch : & NvvmArch | {
459+ arch. is_family_variant ( )
460+ && arch. major_version ( ) == self . major_version ( )
461+ && arch. minor_version ( ) <= self . minor_version ( )
462+ } ;
463+
464+ if self . is_architecture_variant ( ) {
465+ // Architecture-specific ('a' suffix) features include:
466+ // - all lower-or-equal baseline features
467+ // - all lower-or-equal-with-same-major-version family features
468+ // - itself
469+ NvvmArch :: iter ( )
470+ . filter ( |arch| included_baseline ( arch) || included_family ( arch) || arch == self )
471+ . map ( |arch| arch. target_feature ( ) )
475472 . collect ( )
476473 } else if self . is_family_variant ( ) {
477- // 'f' variants: same major version with equal or higher minor version
474+ // Family-specific ('f' suffix) features include:
475+ // - all lower-or-equal baseline features
476+ // - all lower-or-equal-with-same-major-version family features
478477 NvvmArch :: iter ( )
479- . filter ( |arch| {
480- // Include base variants with same major and >= minor version
481- arch. is_base_variant ( )
482- && arch. major_version ( ) == self . major_version ( )
483- && arch. minor_version ( ) >= self . minor_version ( )
484- } )
478+ . filter ( |arch| included_baseline ( arch) || included_family ( arch) )
485479 . map ( |arch| arch. target_feature ( ) )
486- . chain ( std:: iter:: once ( self . target_feature ( ) ) ) // Add the 'f' variant itself
487480 . collect ( )
488481 } else {
489- // Base variants: all base architectures from lower or equal versions
482+ // Baseline (no suffix) features include:
483+ // - all lower-or-equal baseline features
490484 NvvmArch :: iter ( )
491- . filter ( |arch| {
492- arch. is_base_variant ( ) && arch. capability_value ( ) <= self . capability_value ( )
493- } )
485+ . filter ( |arch| included_baseline ( arch) )
494486 . map ( |arch| arch. target_feature ( ) )
495487 . collect ( )
496- } ;
497-
498- features. sort ( ) ;
499- features
488+ }
500489 }
501490
502491 /// Create an iterator over all architectures from Compute35 up to and including this one
@@ -777,19 +766,16 @@ mod tests {
777766 fn nvvm_arch_all_target_features ( ) {
778767 use crate :: NvvmArch ;
779768
780- // Compute35 only includes itself
781769 assert_eq ! (
782770 NvvmArch :: Compute35 . all_target_features( ) ,
783771 vec![ "compute_35" ]
784772 ) ;
785773
786- // Compute50 includes all lower base capabilities
787774 assert_eq ! (
788775 NvvmArch :: Compute50 . all_target_features( ) ,
789776 vec![ "compute_35" , "compute_37" , "compute_50" ] ,
790777 ) ;
791778
792- // Compute61 includes all lower base capabilities
793779 assert_eq ! (
794780 NvvmArch :: Compute61 . all_target_features( ) ,
795781 vec![
@@ -803,7 +789,6 @@ mod tests {
803789 ]
804790 ) ;
805791
806- // Compute70 includes all lower base capabilities
807792 assert_eq ! (
808793 NvvmArch :: Compute70 . all_target_features( ) ,
809794 vec![
@@ -819,7 +804,6 @@ mod tests {
819804 ]
820805 ) ;
821806
822- // Compute90 includes lower base capabilities
823807 let compute90_features = NvvmArch :: Compute90 . all_target_features ( ) ;
824808 assert_eq ! (
825809 compute90_features,
@@ -843,9 +827,6 @@ mod tests {
843827 ]
844828 ) ;
845829
846- // Test 'a' variant - includes all available instructions for the architecture.
847- // This means: all base variants up to same version, no 'f' variants (90 has none), and the
848- // 'a' variant.
849830 assert_eq ! (
850831 NvvmArch :: Compute90a . all_target_features( ) ,
851832 vec![
@@ -869,14 +850,9 @@ mod tests {
869850 ]
870851 ) ;
871852
872- // Test compute100a - should include base variants up to 100, and 100f, and itself,
873- // but NOT 101f or 103f (higher minor).
874853 assert_eq ! (
875854 NvvmArch :: Compute100a . all_target_features( ) ,
876855 vec![
877- "compute_100" ,
878- "compute_100a" ,
879- "compute_100f" ,
880856 "compute_35" ,
881857 "compute_37" ,
882858 "compute_50" ,
@@ -893,26 +869,39 @@ mod tests {
893869 "compute_87" ,
894870 "compute_89" ,
895871 "compute_90" ,
872+ "compute_100" ,
873+ "compute_100f" ,
874+ "compute_100a" ,
896875 ]
897876 ) ;
898877
899- // Test 'f' variant with 100f
900878 assert_eq ! (
901879 NvvmArch :: Compute100f . all_target_features( ) ,
902- // FIXME: this is wrong
903- vec![ "compute_100" , "compute_100f" , "compute_101" , "compute_103" ]
880+ vec![
881+ "compute_35" ,
882+ "compute_37" ,
883+ "compute_50" ,
884+ "compute_52" ,
885+ "compute_53" ,
886+ "compute_60" ,
887+ "compute_61" ,
888+ "compute_62" ,
889+ "compute_70" ,
890+ "compute_72" ,
891+ "compute_75" ,
892+ "compute_80" ,
893+ "compute_86" ,
894+ "compute_87" ,
895+ "compute_89" ,
896+ "compute_90" ,
897+ "compute_100" ,
898+ "compute_100f" ,
899+ ]
904900 ) ;
905901
906- // Test compute101a - should include base variants up to 101, and 100f and 101f, and
907- // itself, but not 103f (higher minor)
908902 assert_eq ! (
909903 NvvmArch :: Compute101a . all_target_features( ) ,
910904 vec![
911- "compute_100" ,
912- "compute_100f" ,
913- "compute_101" ,
914- "compute_101a" ,
915- "compute_101f" ,
916905 "compute_35" ,
917906 "compute_37" ,
918907 "compute_50" ,
@@ -929,22 +918,43 @@ mod tests {
929918 "compute_87" ,
930919 "compute_89" ,
931920 "compute_90" ,
921+ "compute_100" ,
922+ "compute_100f" ,
923+ "compute_101" ,
924+ "compute_101f" ,
925+ "compute_101a" ,
932926 ]
933927 ) ;
934928
935- // Test 'f' variant with 101f
936929 assert_eq ! (
937930 NvvmArch :: Compute101f . all_target_features( ) ,
938- vec![ "compute_101" , "compute_101f" , "compute_103" ] ,
931+ vec![
932+ "compute_35" ,
933+ "compute_37" ,
934+ "compute_50" ,
935+ "compute_52" ,
936+ "compute_53" ,
937+ "compute_60" ,
938+ "compute_61" ,
939+ "compute_62" ,
940+ "compute_70" ,
941+ "compute_72" ,
942+ "compute_75" ,
943+ "compute_80" ,
944+ "compute_86" ,
945+ "compute_87" ,
946+ "compute_89" ,
947+ "compute_90" ,
948+ "compute_100" ,
949+ "compute_100f" ,
950+ "compute_101" ,
951+ "compute_101f" ,
952+ ]
939953 ) ;
940954
941955 assert_eq ! (
942956 NvvmArch :: Compute120 . all_target_features( ) ,
943957 vec![
944- "compute_100" ,
945- "compute_101" ,
946- "compute_103" ,
947- "compute_120" ,
948958 "compute_35" ,
949959 "compute_37" ,
950960 "compute_50" ,
@@ -961,24 +971,43 @@ mod tests {
961971 "compute_87" ,
962972 "compute_89" ,
963973 "compute_90" ,
974+ "compute_100" ,
975+ "compute_101" ,
976+ "compute_103" ,
977+ "compute_120" ,
964978 ]
965979 ) ;
966980
967981 assert_eq ! (
968982 NvvmArch :: Compute120f . all_target_features( ) ,
969- // FIXME: this is wrong
970- vec![ "compute_120" , "compute_120f" , "compute_121" ]
971- ) ;
972-
973- assert_eq ! (
974- NvvmArch :: Compute120a . all_target_features( ) ,
975983 vec![
984+ "compute_35" ,
985+ "compute_37" ,
986+ "compute_50" ,
987+ "compute_52" ,
988+ "compute_53" ,
989+ "compute_60" ,
990+ "compute_61" ,
991+ "compute_62" ,
992+ "compute_70" ,
993+ "compute_72" ,
994+ "compute_75" ,
995+ "compute_80" ,
996+ "compute_86" ,
997+ "compute_87" ,
998+ "compute_89" ,
999+ "compute_90" ,
9761000 "compute_100" ,
9771001 "compute_101" ,
9781002 "compute_103" ,
9791003 "compute_120" ,
980- "compute_120a" ,
9811004 "compute_120f" ,
1005+ ]
1006+ ) ;
1007+
1008+ assert_eq ! (
1009+ NvvmArch :: Compute120a . all_target_features( ) ,
1010+ vec![
9821011 "compute_35" ,
9831012 "compute_37" ,
9841013 "compute_50" ,
@@ -995,6 +1024,12 @@ mod tests {
9951024 "compute_87" ,
9961025 "compute_89" ,
9971026 "compute_90" ,
1027+ "compute_100" ,
1028+ "compute_101" ,
1029+ "compute_103" ,
1030+ "compute_120" ,
1031+ "compute_120f" ,
1032+ "compute_120a" ,
9981033 ]
9991034 ) ;
10001035 }
0 commit comments