Skip to content

Commit 1bf3ee5

Browse files
committed
address PR comments
1 parent 7013d36 commit 1bf3ee5

File tree

2 files changed: +16 additions, −6 deletions

guardrails/llm_providers.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,7 @@ def get_llm_ask(llm_api: Callable, *args, **kwargs) -> PromptCallableBase:
487487
try:
488488
from transformers import ( # noqa: F401 # type: ignore
489489
FlaxPreTrainedModel,
490+
GenerationMixin,
490491
PreTrainedModel,
491492
TFPreTrainedModel,
492493
)
@@ -498,14 +499,24 @@ def get_llm_ask(llm_api: Callable, *args, **kwargs) -> PromptCallableBase:
498499
or isinstance(api_self, TFPreTrainedModel)
499500
or isinstance(api_self, FlaxPreTrainedModel)
500501
):
501-
return HuggingFaceModelCallable(*args, model_generate=llm_api, **kwargs)
502+
if (
503+
hasattr(llm_api, "__func__")
504+
and llm_api.__func__ == GenerationMixin.generate
505+
):
506+
return HuggingFaceModelCallable(*args, model_generate=llm_api, **kwargs)
507+
raise ValueError("Only text generation models are supported at this time.")
502508
except ImportError:
503509
pass
504510
try:
505511
from transformers import Pipeline # noqa: F401 # type: ignore
506512

507513
if isinstance(llm_api, Pipeline):
508-
return HuggingFacePipelineCallable(*args, pipeline=llm_api, **kwargs)
514+
# Couldn't find a constant for this
515+
if llm_api.task == "text-generation":
516+
return HuggingFacePipelineCallable(*args, pipeline=llm_api, **kwargs)
517+
raise ValueError(
518+
"Only text generation pipelines are supported at this time."
519+
)
509520
except ImportError:
510521
pass
511522

tests/unit_tests/test_llm_providers.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,7 @@ def test_get_llm_ask_anthropic():
650650
not importlib.util.find_spec("transformers"),
651651
reason="transformers is not installed",
652652
)
653-
def test_get_llm_ask_hugging_face_model():
653+
def test_get_llm_ask_hugging_face_model(mocker):
654654
from transformers import PreTrainedModel
655655

656656
from guardrails.llm_providers import HuggingFaceModelCallable
@@ -661,9 +661,6 @@ class MockModel(PreTrainedModel):
661661
def __init__(self, *args, **kwargs):
662662
self._modules = {}
663663

664-
def generate(self, *args, **kwargs):
665-
pass
666-
667664
mock_model = MockModel()
668665

669666
prompt_callable = get_llm_ask(mock_model.generate)
@@ -681,6 +678,8 @@ def test_get_llm_ask_hugging_face_pipeline():
681678
from guardrails.llm_providers import HuggingFacePipelineCallable
682679

683680
class MockPipeline(Pipeline):
681+
task = "text-generation"
682+
684683
def __init__(self, *args, **kwargs):
685684
pass
686685

Comments (0)