Skip to content

Commit 5e11004

Browse files
authored
fix: add a check for int32 indices in sampling.py (#2127)
<!-- .github/pull_request_template.md --> ## 📌 Description New function to validate that the indices type, when provided, is `int32`. To close #2115. There are now two separate functions doing checking in this file. I will move them to the C++ side later when I have some more bandwidth, probably after Thanksgiving. Just a short fix for now. You can close if you'd rather wait for that. <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues #2115 <!-- Link any related issues here --> Relevant to the issue. Now running their code: ``` (flashinfer) raayan@uril-1:~/projects/flashinfer$ python test.py tensor([1, 1, 0, 0], device='cuda:0', dtype=torch.int32) Traceback (most recent call last): File "/home/raayan/projects/flashinfer/test.py", line 15, in <module> incorrect_samples = flashinfer.sampling.top_k_top_p_sampling_from_logits( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/raayan/projects/flashinfer/flashinfer/sampling.py", line 1031, in top_k_top_p_sampling_from_logits _check_indices_dtype(indices) File "/home/raayan/projects/flashinfer/flashinfer/sampling.py", line 487, in _check_indices_dtype raise ValueError(f"indices must have dtype torch.int32, got {indices.dtype}") ValueError: indices must have dtype torch.int32, got torch.int64 ``` ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). 
## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Improvements** * Enforced that indices passed to sampling operations must use int32, adding runtime validation before sampling. * **Documentation** * Clarified docstrings to state the int32 requirement for indices parameters. * **Tests** * Updated and expanded tests to cover the new dtype validation paths and related error cases. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Raayan Dhar <raayan.dhar@gmail.com>
1 parent 5acb57b commit 5e11004

File tree

2 files changed

+38
-14
lines changed

2 files changed

+38
-14
lines changed

flashinfer/sampling.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,12 @@ def _to_tensor_scalar_tuple(x):
481481
return (None, x)
482482

483483

484+
def _check_indices_dtype(indices: Optional[torch.Tensor]) -> None:
485+
"""Validate indices dtype."""
486+
if indices is not None and indices.dtype != torch.int32:
487+
raise ValueError(f"indices must have dtype torch.int32, got {indices.dtype}")
488+
489+
484490
def _check_tensor_param(param: Any, tensor: torch.Tensor) -> None:
485491
"""Validate sampling parameters."""
486492
if isinstance(param, torch.Tensor):
@@ -576,7 +582,7 @@ def sampling_from_logits(
576582
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
577583
probability distributions.
578584
indices: Optional[torch.Tensor]
579-
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in logits.
585+
Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in logits.
580586
For example, if indices[i] = j, then the i-th output will be sampled from logits[j].
581587
This allows reusing the same probability distribution for multiple outputs.
582588
If indices is not provided, the i-th output will be sampled from the i-th row of logits.
@@ -612,6 +618,7 @@ def sampling_from_logits(
612618
if check_nan:
613619
if torch.any(torch.isnan(logits)):
614620
raise ValueError("Input logits contains NaN.")
621+
_check_indices_dtype(indices)
615622
return get_sampling_module().sampling_from_logits(
616623
logits, indices, deterministic, generator
617624
)
@@ -634,7 +641,7 @@ def sampling_from_probs(
634641
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
635642
probability distributions.
636643
indices: Optional[torch.Tensor]
637-
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
644+
Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in probs.
638645
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
639646
This allows reusing the same probability distribution for multiple outputs.
640647
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
@@ -676,6 +683,7 @@ def sampling_from_probs(
676683
if check_nan:
677684
if torch.any(torch.isnan(probs)):
678685
raise ValueError("Input probs contains NaN.")
686+
_check_indices_dtype(indices)
679687
return get_sampling_module().sampling_from_probs(
680688
probs, indices, deterministic, generator
681689
)
@@ -708,7 +716,7 @@ def top_p_sampling_from_probs(
708716
If a float, the same threshold is used for all requests.
709717
If a tensor, each request has its own threshold.
710718
indices: Optional[torch.Tensor]
711-
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
719+
Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in probs.
712720
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
713721
This allows reusing the same probability distribution for multiple outputs.
714722
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
@@ -758,6 +766,7 @@ def top_p_sampling_from_probs(
758766
if check_nan:
759767
if torch.any(torch.isnan(probs)):
760768
raise ValueError("Input probs contains NaN.")
769+
_check_indices_dtype(indices)
761770
_check_tensor_param(top_p, probs)
762771
return get_sampling_module().top_p_sampling_from_probs(
763772
probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator
@@ -791,7 +800,7 @@ def top_k_sampling_from_probs(
791800
If a scalar, the same threshold is used for all requests.
792801
If a tensor, each request has its own threshold.
793802
indices: Optional[torch.Tensor]
794-
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
803+
Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in probs.
795804
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
796805
This allows reusing the same probability distribution for multiple outputs.
797806
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
@@ -841,6 +850,7 @@ def top_k_sampling_from_probs(
841850
if check_nan:
842851
if torch.any(torch.isnan(probs)):
843852
raise ValueError("Input probs contains NaN.")
853+
_check_indices_dtype(indices)
844854
_check_tensor_param(top_k, probs)
845855
return get_sampling_module().top_k_sampling_from_probs(
846856
probs, indices, *_to_tensor_scalar_tuple(top_k), deterministic, generator
@@ -875,7 +885,7 @@ def min_p_sampling_from_probs(
875885
If a scalar, the same threshold is used for all requests.
876886
If a tensor, each request has its own threshold.
877887
indices: Optional[torch.Tensor]
878-
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
888+
Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in probs.
879889
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
880890
This allows reusing the same probability distribution for multiple outputs.
881891
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
@@ -920,6 +930,7 @@ def min_p_sampling_from_probs(
920930
if check_nan:
921931
if torch.any(torch.isnan(probs)):
922932
raise ValueError("Input probs contains NaN.")
933+
_check_indices_dtype(indices)
923934
_check_tensor_param(min_p, probs)
924935
return get_sampling_module().min_p_sampling_from_probs(
925936
probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator
@@ -960,7 +971,7 @@ def top_k_top_p_sampling_from_logits(
960971
If a scalar, the same threshold is used for all requests.
961972
If a tensor, each request has its own threshold.
962973
indices: Optional[torch.Tensor]
963-
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
974+
Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in probs.
964975
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
965976
This allows reusing the same probability distribution for multiple outputs.
966977
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
@@ -1018,6 +1029,7 @@ def top_k_top_p_sampling_from_logits(
10181029
top_k_mask_logits
10191030
top_p_sampling_from_probs
10201031
"""
1032+
_check_indices_dtype(indices)
10211033
_check_tensor_param(top_k, logits)
10221034
_check_tensor_param(top_p, logits)
10231035
if filter_apply_order == "top_k_first":
@@ -1082,7 +1094,7 @@ def top_k_top_p_sampling_from_probs(
10821094
If a scalar, the same threshold is used for all requests.
10831095
If a tensor, each request has its own threshold.
10841096
indices: Optional[torch.Tensor]
1085-
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
1097+
Optional indices tensor of shape ``(batch_size,)``, dtype ``torch.int32`` that maps each output to a row in probs.
10861098
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
10871099
This allows reusing the same probability distribution for multiple outputs.
10881100
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
@@ -1135,6 +1147,7 @@ def top_k_top_p_sampling_from_probs(
11351147
top_p_renorm_probs
11361148
top_k_mask_logits
11371149
"""
1150+
_check_indices_dtype(indices)
11381151
_check_tensor_param(top_k, probs)
11391152
_check_tensor_param(top_p, probs)
11401153
if filter_apply_order == "top_k_first":

tests/utils/test_sampling.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ def test_chain_speculative_sampling(
572572
@pytest.mark.parametrize("batch_size", [1, 99, 989])
573573
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
574574
@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
575-
def test_check_tensor_param_min_p(batch_size, vocab_size, p):
575+
def test_tensor_validation_min_p(batch_size, vocab_size, p):
576576
pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
577577
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
578578

@@ -587,7 +587,7 @@ def test_check_tensor_param_min_p(batch_size, vocab_size, p):
587587
flashinfer.sampling.min_p_sampling_from_probs(
588588
normalized_prob,
589589
torch.tensor(
590-
[[p] * vocab_size] * batch_size, dtype=torch.int, device="cuda:0"
590+
[[p] * vocab_size] * batch_size, dtype=torch.float32, device="cuda:0"
591591
),
592592
)
593593

@@ -597,22 +597,33 @@ def test_check_tensor_param_min_p(batch_size, vocab_size, p):
597597
match=r"Expected a 1D tensor of shape \(batch_size,\) or scalar.*got a 0-dimensional tensor",
598598
):
599599
flashinfer.sampling.min_p_sampling_from_probs(
600-
normalized_prob, torch.tensor(p, dtype=torch.int, device="cuda:0")
600+
normalized_prob, torch.tensor(p, dtype=torch.float32, device="cuda:0")
601601
)
602602

603-
# 4: 1D tensor with a broken batch size raises error (only when batch_size > 1).
603+
# 4: non-int32 indices raises error.
604+
with pytest.raises(
605+
ValueError,
606+
match=r"indices must have dtype torch\.int32, got torch\.int64",
607+
):
608+
flashinfer.sampling.min_p_sampling_from_probs(
609+
normalized_prob,
610+
torch.tensor([p] * batch_size, dtype=torch.float32, device="cuda:0"),
611+
torch.tensor([p] * batch_size, dtype=torch.int64, device="cuda:0"),
612+
)
613+
614+
# 5: 1D tensor with a broken batch size raises error (only when batch_size > 1).
604615
if batch_size > 1:
605616
with pytest.raises(
606617
ValueError, match="Sampling parameter tensor batch size mismatch"
607618
):
608619
flashinfer.sampling.min_p_sampling_from_probs(
609-
normalized_prob, torch.tensor([p], dtype=torch.int, device="cuda:0")
620+
normalized_prob, torch.tensor([p], dtype=torch.float32, device="cuda:0")
610621
)
611622

612-
# 5: 1D tensor with the correct batch size works.
623+
# 6: 1D tensor with the correct batch size works.
613624
samples = flashinfer.sampling.min_p_sampling_from_probs(
614625
normalized_prob,
615-
torch.tensor([p] * batch_size, dtype=torch.int, device="cuda:0"),
626+
torch.tensor([p] * batch_size, dtype=torch.float32, device="cuda:0"),
616627
)
617628
assert samples.shape == (batch_size,)
618629

0 commit comments

Comments
 (0)