@@ -487,6 +487,7 @@ def get_llm_ask(llm_api: Callable, *args, **kwargs) -> PromptCallableBase:
487487 try :
488488 from transformers import ( # noqa: F401 # type: ignore
489489 FlaxPreTrainedModel ,
490+ GenerationMixin ,
490491 PreTrainedModel ,
491492 TFPreTrainedModel ,
492493 )
@@ -498,14 +499,24 @@ def get_llm_ask(llm_api: Callable, *args, **kwargs) -> PromptCallableBase:
498499 or isinstance (api_self , TFPreTrainedModel )
499500 or isinstance (api_self , FlaxPreTrainedModel )
500501 ):
501- return HuggingFaceModelCallable (* args , model_generate = llm_api , ** kwargs )
502+ if (
503+ hasattr (llm_api , "__func__" )
504+ and llm_api .__func__ == GenerationMixin .generate
505+ ):
506+ return HuggingFaceModelCallable (* args , model_generate = llm_api , ** kwargs )
507+ raise ValueError ("Only text generation models are supported at this time." )
502508 except ImportError :
503509 pass
504510 try :
505511 from transformers import Pipeline # noqa: F401 # type: ignore
506512
507513 if isinstance (llm_api , Pipeline ):
508- return HuggingFacePipelineCallable (* args , pipeline = llm_api , ** kwargs )
514+ # Couldn't find a constant for this
515+ if llm_api .task == "text-generation" :
516+ return HuggingFacePipelineCallable (* args , pipeline = llm_api , ** kwargs )
517+ raise ValueError (
518+ "Only text generation pipelines are supported at this time."
519+ )
509520 except ImportError :
510521 pass
511522
0 commit comments