Commit cf18c6a

Merge pull request #505 from guardrails-ai/hf-support
HuggingFace Support
2 parents 67529d1 + 1bf3ee5 commit cf18c6a

3 files changed: +476 −2 lines changed


docs/llm_api_wrappers.md

Lines changed: 103 additions & 1 deletion
@@ -86,7 +86,7 @@ guard = gd.Guard.from_rail(...)
 anthropic_client = Anthropic(api_key="my_api_key")

 # Wrap Anthropic API call
-raw_llm_output, guardrail_output = guard(
+raw_llm_output, guardrail_output, *rest = guard(
     anthropic_client.completions.create,
     prompt_params={
         "prompt_param_1": "value_1",
@@ -100,6 +100,108 @@ raw_llm_output, guardrail_output = guard(
 ```


+## Hugging Face
+
+### Text Generation Models
+```py
+from guardrails import Guard
+from guardrails.validators import ValidLength, ToxicLanguage
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+# Create your prompt or starting text
+prompt = "Hello, I'm a language model,"
+
+# Setup torch
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Instantiate your tokenizer
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+# Instantiate your model
+model = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id).to(torch_device)
+
+# Customize your model inputs if desired.
+# If you don't pass any inputs (`input_ids`, `input_values`, `input_features`, or `pixel_values`),
+# we'll try to do something similar to the call below using the tokenizer and the prompt.
+# We strongly suggest passing in your own inputs.
+model_inputs = tokenizer(prompt, return_tensors="pt").to(torch_device)
+
+
+# Create the Guard
+guard = Guard.from_string(
+    validators=[
+        ValidLength(
+            min=48,
+            on_fail="fix"
+        ),
+        ToxicLanguage(
+            on_fail="fix"
+        )
+    ],
+    prompt=prompt
+)
+
+# Run the Guard
+response = guard(
+    llm_api=model.generate,
+    max_new_tokens=40,
+    tokenizer=tokenizer,
+    **model_inputs,
+)
+
+# Check the output
+if response.validation_passed:
+    print("validated_output: ", response.validated_output)
+else:
+    print("error: ", response.error)
+
+```
+
+### Pipelines
+```py
+from guardrails import Guard
+from guardrails.validators import ValidLength, ToxicLanguage
+import torch
+from transformers import pipeline
+
+
+# Create your prompt or starting text
+prompt = "What are we having for dinner?"
+
+# Setup pipeline
+generator = pipeline("text-generation", model="facebook/opt-350m")
+
+
+# Create the Guard
+guard = Guard.from_string(
+    validators=[
+        ValidLength(
+            min=48,
+            on_fail="fix"
+        ),
+        ToxicLanguage(
+            on_fail="fix"
+        )
+    ],
+    prompt=prompt
+)
+
+# Run the Guard
+response = guard(
+    llm_api=generator,
+    max_new_tokens=40
+)
+
+if response.validation_passed:
+    print("validated_output: ", response.validated_output)
+else:
+    print("error: ", response.error)
+
+```
+
+
 ## Using Manifest
 [Manifest](https://github.com/HazyResearch/manifest) is a wrapper around most model APIs and supports hosting local models. It can be used as a LLM API.

guardrails/llm_providers.py

Lines changed: 156 additions & 0 deletions
@@ -2,6 +2,7 @@

 from pydantic import BaseModel

+from guardrails.utils.exception_utils import UserFacingException
 from guardrails.utils.llm_response import LLMResponse
 from guardrails.utils.openai_utils import (
     AsyncOpenAIClient,
@@ -12,6 +13,7 @@
     get_static_openai_create_func,
 )
 from guardrails.utils.pydantic_utils import convert_pydantic_model_to_openai_fn
+from guardrails.utils.safe_get import safe_get


 class PromptCallableException(Exception):
@@ -287,6 +289,124 @@ def _invoke_llm(
         return LLMResponse(output=anthropic_response.completion)


+class HuggingFaceModelCallable(PromptCallableBase):
+    def _invoke_llm(
+        self, prompt: str, model_generate: Any, *args, **kwargs
+    ) -> LLMResponse:
+        try:
+            import transformers  # noqa: F401 # type: ignore
+        except ImportError:
+            raise PromptCallableException(
+                "The `transformers` package is not installed. "
+                "Install with `pip install transformers`"
+            )
+        try:
+            import torch
+        except ImportError:
+            raise PromptCallableException(
+                "The `torch` package is not installed. "
+                "Install with `pip install torch`"
+            )
+
+        tokenizer = kwargs.pop("tokenizer")
+        if not tokenizer:
+            raise UserFacingException(
+                ValueError(
+                    "'tokenizer' must be provided in order to use Hugging Face models!"
+                )
+            )
+
+        torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        return_tensors = kwargs.pop("return_tensors", "pt")
+        skip_special_tokens = kwargs.pop("skip_special_tokens", True)
+
+        input_ids = kwargs.pop("input_ids", None)
+        input_values = kwargs.pop("input_values", None)
+        input_features = kwargs.pop("input_features", None)
+        pixel_values = kwargs.pop("pixel_values", None)
+        model_inputs = kwargs.pop("model_inputs", {})
+        if (
+            input_ids is None
+            and input_values is None
+            and input_features is None
+            and pixel_values is None
+            and not model_inputs
+        ):
+            model_inputs = tokenizer(prompt, return_tensors=return_tensors).to(
+                torch_device
+            )
+        else:
+            model_inputs["input_ids"] = input_ids
+            model_inputs["input_values"] = input_values
+            model_inputs["input_features"] = input_features
+            model_inputs["pixel_values"] = pixel_values
+
+        do_sample = kwargs.pop("do_sample", None)
+        temperature = kwargs.pop("temperature", None)
+        if not do_sample and temperature == 0:
+            temperature = None
+
+        model_inputs["do_sample"] = do_sample
+        model_inputs["temperature"] = temperature
+
+        output = model_generate(
+            **model_inputs,
+            **kwargs,
+        )
+
+        # NOTE: This is currently restricted to single outputs
+        # Should we choose to support multiple return sequences,
+        # We would need to either validate all of them
+        # and choose the one with the least failures,
+        # or accept a selection function
+        decoded_output = tokenizer.decode(
+            output[0], skip_special_tokens=skip_special_tokens
+        )
+
+        return LLMResponse(output=decoded_output)
+
+
+class HuggingFacePipelineCallable(PromptCallableBase):
+    def _invoke_llm(self, prompt: str, pipeline: Any, *args, **kwargs) -> LLMResponse:
+        try:
+            import transformers  # noqa: F401 # type: ignore
+        except ImportError:
+            raise PromptCallableException(
+                "The `transformers` package is not installed. "
+                "Install with `pip install transformers`"
+            )
+        try:
+            import torch  # noqa: F401 # type: ignore
+        except ImportError:
+            raise PromptCallableException(
+                "The `torch` package is not installed. "
+                "Install with `pip install torch`"
+            )
+
+        content_key = kwargs.pop("content_key", "generated_text")
+
+        temperature = kwargs.pop("temperature", None)
+        if temperature == 0:
+            temperature = None
+
+        output = pipeline(
+            prompt,
+            temperature=temperature,
+            *args,
+            **kwargs,
+        )
+
+        # NOTE: This is currently restricted to single outputs
+        # Should we choose to support multiple return sequences,
+        # We would need to either validate all of them
+        # and choose the one with the least failures,
+        # or accept a selection function
+        content = safe_get(output[0], content_key)
+
+        return LLMResponse(output=content)
+
+
 class ArbitraryCallable(PromptCallableBase):
     def __init__(self, llm_api: Callable, *args, **kwargs):
         self.llm_api = llm_api
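For reference, the fallback path in `HuggingFaceModelCallable` and the unwrapping in `HuggingFacePipelineCallable` reduce to the plain `transformers` calls below. This is a sketch, not part of the commit; it assumes the `gpt2` and `facebook/opt-350m` checkpoints used in the docs above and the standard text-generation pipeline output shape (a list of dicts keyed by `generated_text`).

```py
# Sketch (not from this commit): the raw transformers calls the two callables wrap.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# HuggingFaceModelCallable fallback path: tokenize the prompt, generate, decode output[0].
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id).to(device)
model_inputs = tokenizer("Hello, I'm a language model,", return_tensors="pt").to(device)
output = model.generate(**model_inputs, max_new_tokens=40)
print(tokenizer.decode(output[0], skip_special_tokens=True))

# HuggingFacePipelineCallable path: a text-generation pipeline returns a list of dicts,
# so the callable reads output[0]["generated_text"] (via safe_get) as the response text.
generator = pipeline("text-generation", model="facebook/opt-350m")
output = generator("What are we having for dinner?", max_new_tokens=40)
print(output[0]["generated_text"])
```

In both cases the callable returns an `LLMResponse` whose `output` is a plain string, which is what the Guard's validators operate on.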
@@ -364,6 +484,42 @@ def get_llm_ask(llm_api: Callable, *args, **kwargs) -> PromptCallableBase:
     except ImportError:
         pass

+    try:
+        from transformers import (  # noqa: F401 # type: ignore
+            FlaxPreTrainedModel,
+            GenerationMixin,
+            PreTrainedModel,
+            TFPreTrainedModel,
+        )
+
+        api_self = getattr(llm_api, "__self__", None)
+
+        if (
+            isinstance(api_self, PreTrainedModel)
+            or isinstance(api_self, TFPreTrainedModel)
+            or isinstance(api_self, FlaxPreTrainedModel)
+        ):
+            if (
+                hasattr(llm_api, "__func__")
+                and llm_api.__func__ == GenerationMixin.generate
+            ):
+                return HuggingFaceModelCallable(*args, model_generate=llm_api, **kwargs)
+            raise ValueError("Only text generation models are supported at this time.")
+    except ImportError:
+        pass
+    try:
+        from transformers import Pipeline  # noqa: F401 # type: ignore
+
+        if isinstance(llm_api, Pipeline):
+            # Couldn't find a constant for this
+            if llm_api.task == "text-generation":
+                return HuggingFacePipelineCallable(*args, pipeline=llm_api, **kwargs)
+            raise ValueError(
+                "Only text generation pipelines are supported at this time."
+            )
+    except ImportError:
+        pass
+
     # Let the user pass in an arbitrary callable.
     return ArbitraryCallable(*args, llm_api=llm_api, **kwargs)
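The dispatch above identifies a Hugging Face model by introspecting the bound method the user passes as `llm_api`. A minimal sketch of that mechanism, using stand-in classes rather than the real `transformers` classes so it runs without any model download:

```py
# Sketch (not from this commit): how __self__/__func__ identify a bound generate method.
class GenerationMixinStandIn:            # stands in for transformers.GenerationMixin
    def generate(self, **kwargs):
        return ["generated token ids"]

class PreTrainedModelStandIn(GenerationMixinStandIn):  # stands in for a PreTrainedModel subclass
    pass

model = PreTrainedModelStandIn()
llm_api = model.generate                 # what a user would pass to guard(...)

# __self__ recovers the instance, so the isinstance(...) checks against the
# PreTrainedModel family work; __func__ recovers the underlying function, so it
# can be compared against GenerationMixin.generate to confirm this is text generation.
assert getattr(llm_api, "__self__", None) is model
assert llm_api.__func__ is GenerationMixinStandIn.generate
```

Anything that fails these checks but is a transformers `Pipeline` falls through to the `isinstance(llm_api, Pipeline)` branch; everything else still becomes an `ArbitraryCallable`.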