Commit 5e11004
authored
fix: add a check for int32 indices in sampling.py (#2127)
<!-- .github/pull_request_template.md -->
## 📌 Description
New function to validate that the indices type, when provided, is
`int32`. To close
#2115.
There are now two separate functions doing checking in this file. I will
move them to the C++ side later when I have some more bandwidth,
probably after Thanksgiving. Just a short fix for now. You can close if
you'd rather wait for that.
<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->
## 🔍 Related Issues
#2115
<!-- Link any related issues here -->
Relevant to the issue. Now running their code:
```
(flashinfer) raayan@uril-1:~/projects/flashinfer$ python test.py
tensor([1, 1, 0, 0], device='cuda:0', dtype=torch.int32)
Traceback (most recent call last):
File "/home/raayan/projects/flashinfer/test.py", line 15, in <module>
incorrect_samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raayan/projects/flashinfer/flashinfer/sampling.py", line 1031, in top_k_top_p_sampling_from_logits
_check_indices_dtype(indices)
File "/home/raayan/projects/flashinfer/flashinfer/sampling.py", line 487, in _check_indices_dtype
raise ValueError(f"indices must have dtype torch.int32, got {indices.dtype}")
ValueError: indices must have dtype torch.int32, got torch.int64
```
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes
<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Improvements**
* Enforced that indices passed to sampling operations must use int32,
adding runtime validation before sampling.
* **Documentation**
* Clarified docstrings to state the int32 requirement for indices
parameters.
* **Tests**
* Updated and expanded tests to cover the new dtype validation paths and
related error cases.
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com>1 parent 5acb57b commit 5e11004
2 files changed
+38
-14
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
481 | 481 | | |
482 | 482 | | |
483 | 483 | | |
| 484 | + | |
| 485 | + | |
| 486 | + | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
484 | 490 | | |
485 | 491 | | |
486 | 492 | | |
| |||
576 | 582 | | |
577 | 583 | | |
578 | 584 | | |
579 | | - | |
| 585 | + | |
580 | 586 | | |
581 | 587 | | |
582 | 588 | | |
| |||
612 | 618 | | |
613 | 619 | | |
614 | 620 | | |
| 621 | + | |
615 | 622 | | |
616 | 623 | | |
617 | 624 | | |
| |||
634 | 641 | | |
635 | 642 | | |
636 | 643 | | |
637 | | - | |
| 644 | + | |
638 | 645 | | |
639 | 646 | | |
640 | 647 | | |
| |||
676 | 683 | | |
677 | 684 | | |
678 | 685 | | |
| 686 | + | |
679 | 687 | | |
680 | 688 | | |
681 | 689 | | |
| |||
708 | 716 | | |
709 | 717 | | |
710 | 718 | | |
711 | | - | |
| 719 | + | |
712 | 720 | | |
713 | 721 | | |
714 | 722 | | |
| |||
758 | 766 | | |
759 | 767 | | |
760 | 768 | | |
| 769 | + | |
761 | 770 | | |
762 | 771 | | |
763 | 772 | | |
| |||
791 | 800 | | |
792 | 801 | | |
793 | 802 | | |
794 | | - | |
| 803 | + | |
795 | 804 | | |
796 | 805 | | |
797 | 806 | | |
| |||
841 | 850 | | |
842 | 851 | | |
843 | 852 | | |
| 853 | + | |
844 | 854 | | |
845 | 855 | | |
846 | 856 | | |
| |||
875 | 885 | | |
876 | 886 | | |
877 | 887 | | |
878 | | - | |
| 888 | + | |
879 | 889 | | |
880 | 890 | | |
881 | 891 | | |
| |||
920 | 930 | | |
921 | 931 | | |
922 | 932 | | |
| 933 | + | |
923 | 934 | | |
924 | 935 | | |
925 | 936 | | |
| |||
960 | 971 | | |
961 | 972 | | |
962 | 973 | | |
963 | | - | |
| 974 | + | |
964 | 975 | | |
965 | 976 | | |
966 | 977 | | |
| |||
1018 | 1029 | | |
1019 | 1030 | | |
1020 | 1031 | | |
| 1032 | + | |
1021 | 1033 | | |
1022 | 1034 | | |
1023 | 1035 | | |
| |||
1082 | 1094 | | |
1083 | 1095 | | |
1084 | 1096 | | |
1085 | | - | |
| 1097 | + | |
1086 | 1098 | | |
1087 | 1099 | | |
1088 | 1100 | | |
| |||
1135 | 1147 | | |
1136 | 1148 | | |
1137 | 1149 | | |
| 1150 | + | |
1138 | 1151 | | |
1139 | 1152 | | |
1140 | 1153 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
572 | 572 | | |
573 | 573 | | |
574 | 574 | | |
575 | | - | |
| 575 | + | |
576 | 576 | | |
577 | 577 | | |
578 | 578 | | |
| |||
587 | 587 | | |
588 | 588 | | |
589 | 589 | | |
590 | | - | |
| 590 | + | |
591 | 591 | | |
592 | 592 | | |
593 | 593 | | |
| |||
597 | 597 | | |
598 | 598 | | |
599 | 599 | | |
600 | | - | |
| 600 | + | |
601 | 601 | | |
602 | 602 | | |
603 | | - | |
| 603 | + | |
| 604 | + | |
| 605 | + | |
| 606 | + | |
| 607 | + | |
| 608 | + | |
| 609 | + | |
| 610 | + | |
| 611 | + | |
| 612 | + | |
| 613 | + | |
| 614 | + | |
604 | 615 | | |
605 | 616 | | |
606 | 617 | | |
607 | 618 | | |
608 | 619 | | |
609 | | - | |
| 620 | + | |
610 | 621 | | |
611 | 622 | | |
612 | | - | |
| 623 | + | |
613 | 624 | | |
614 | 625 | | |
615 | | - | |
| 626 | + | |
616 | 627 | | |
617 | 628 | | |
618 | 629 | | |
| |||
0 commit comments