diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 868ca51b3331..12c73f2d79f7 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -56,15 +56,22 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): ) expected_lora_output = [ - "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501 - "give_opinion(name[SpellForce 3], developer[Grimlore Games], release_year[2017], rating[poor])", # noqa: E501 - "inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])", # noqa: E501 + [ + "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])" # noqa: E501 + ], + [ + "give_opinion(name[SpellForce 3], developer[Grimlore Games], release_year[2017], rating[poor])", # noqa: E501 + "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501 + ], + [ + "inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])" # noqa: E501 + ], ] - assert ( - do_sample(llm, mixtral_lora_files, lora_id=1, prompts=prompts) - == expected_lora_output - ) - assert ( - do_sample(llm, mixtral_lora_files, lora_id=2, prompts=prompts) - == expected_lora_output - ) + + def check_outputs(generated: list[str]): + assert len(generated) == len(expected_lora_output) + for gen, gt_choices in zip(generated, expected_lora_output): + assert gen in gt_choices + + check_outputs(do_sample(llm, mixtral_lora_files, lora_id=1, prompts=prompts)) + check_outputs(do_sample(llm, mixtral_lora_files, lora_id=2, prompts=prompts))