Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions lm_eval/models/openai_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from lm_eval.models.api_models import TemplateAPI
from lm_eval.models.utils import handle_stop_sequences


eval_logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -95,7 +94,10 @@ def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
for out in outputs:
tmp = [None] * len(out["choices"])
for choices in out["choices"]:
tmp[choices["index"]] = choices["text"]
x = ""
if choices["text"] is not None:
x = choices["text"]
tmp[choices["index"]] = x
res = res + tmp
return res

Expand Down Expand Up @@ -167,7 +169,10 @@ def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
for out in outputs:
tmp = [None] * len(out["choices"])
for choices in out["choices"]:
tmp[choices["index"]] = choices["message"]["content"]
x = ""
if choices["message"]["content"] is not None:
x = choices["message"]["content"]
tmp[choices["index"]] = x
res = res + tmp
return res

Expand Down
24 changes: 24 additions & 0 deletions tests/models/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from lm_eval.api.instance import Instance
from lm_eval.models.openai_completions import LocalCompletionsAPI


Expand Down Expand Up @@ -161,6 +162,29 @@ def test_model_tokenized_call_usage(
assert result == {"result": "success"}


def test_generate_until_with_null_message_content(api):
    """Regression test: a completion choice whose ``text`` is ``None`` must be
    coerced to an empty string rather than leaking ``None`` into the results.

    Mocks the HTTP layer so ``api.generate_until`` parses a response in which
    the model returned null content for the single choice.
    """
    with patch("requests.post") as mock_post:
        # Simulate an API response where the returned text is explicitly null.
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "choices": [
                {
                    "index": 0,
                    "text": None,
                }
            ]
        }
        mock_response.ok = True
        mock_post.return_value = mock_response

        request = Instance(
            request_type="generate_until",
            doc={},
            arguments=("Test prompt", {"max_gen_toks": 10}),
            idx=0,
        )

        results = api.generate_until([request])

        # Without the null-coercion fix, a None would appear here; every
        # generation must come back as a str (the null becomes "").
        assert all(isinstance(r, str) for r in results)


class DummyAsyncContextManager:
def __init__(self, result):
self.result = result
Expand Down
Loading