Skip to content

Commit 83c17fd

Browse files
committed
add from_onnx for image models
1 parent b740d87 commit 83c17fd

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

tools/who_what_benchmark/whowhatbench/model_loaders.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,10 @@ def load_text2image_model(
272272
if 'adapters' in kwargs and kwargs['adapters'] is not None:
273273
raise ValueError("Adapters are not supported for OVPipelineForText2Image.")
274274

275+
model_kwargs = {"from_onnx": kwargs.get('from_onnx', False), "safety_checker": None}
275276
try:
276277
model = TEXT2IMAGEPipeline.from_pretrained(
277-
model_id, device=device, ov_config=ov_config, safety_checker=None,
278+
model_id, device=device, ov_config=ov_config, **model_kwargs
278279
)
279280
except ValueError:
280281
model = TEXT2IMAGEPipeline.from_pretrained(
@@ -283,7 +284,7 @@ def load_text2image_model(
283284
use_cache=True,
284285
device=device,
285286
ov_config=ov_config,
286-
safety_checker=None,
287+
**model_kwargs
287288
)
288289

289290
return model
@@ -412,7 +413,7 @@ def load_image2image_genai_pipeline(model_dir, device="CPU", ov_config=None):
412413

413414

414415
def load_imagetext2image_model(
415-
model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False
416+
model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False, **kwargs
416417
):
417418
if use_hf:
418419
from diffusers import AutoPipelineForImage2Image
@@ -426,9 +427,10 @@ def load_imagetext2image_model(
426427
else:
427428
logger.info("Using Optimum API")
428429
from optimum.intel.openvino import OVPipelineForImage2Image
430+
model_kwargs = {"from_onnx": kwargs.get('from_onnx', False), "safety_checker": None}
429431
try:
430432
model = OVPipelineForImage2Image.from_pretrained(
431-
model_id, device=device, ov_config=ov_config, safety_checker=None,
433+
model_id, device=device, ov_config=ov_config, **model_kwargs
432434
)
433435
except ValueError:
434436
model = OVPipelineForImage2Image.from_pretrained(
@@ -437,7 +439,7 @@ def load_imagetext2image_model(
437439
use_cache=True,
438440
device=device,
439441
ov_config=ov_config,
440-
safety_checker=None,
442+
**model_kwargs
441443
)
442444
return model
443445

@@ -457,7 +459,7 @@ def load_inpainting_genai_pipeline(model_dir, device="CPU", ov_config=None):
457459

458460

459461
def load_inpainting_model(
460-
model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False
462+
model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False, **kwargs
461463
):
462464
if use_hf:
463465
from diffusers import AutoPipelineForInpainting
@@ -471,9 +473,10 @@ def load_inpainting_model(
471473
else:
472474
logger.info("Using Optimum API")
473475
from optimum.intel.openvino import OVPipelineForInpainting
476+
model_kwargs = {"from_onnx": kwargs.get('from_onnx', False), "safety_checker": None}
474477
try:
475478
model = OVPipelineForInpainting.from_pretrained(
476-
model_id, device=device, ov_config=ov_config, safety_checker=None,
479+
model_id, device=device, ov_config=ov_config, **model_kwargs
477480
)
478481
except ValueError as e:
479482
logger.error("Failed to load inpainting pipeline. Details:\n", e)
@@ -483,7 +486,7 @@ def load_inpainting_model(
483486
use_cache=True,
484487
device=device,
485488
ov_config=ov_config,
486-
safety_checker=None,
489+
**model_kwargs
487490
)
488491
return model
489492

@@ -631,9 +634,9 @@ def load_model(
631634
elif model_type == "visual-text":
632635
return load_visual_text_model(model_id, device, ov_options, use_hf, use_genai, **kwargs)
633636
elif model_type == "image-to-image":
634-
return load_imagetext2image_model(model_id, device, ov_options, use_hf, use_genai)
637+
return load_imagetext2image_model(model_id, device, ov_options, use_hf, use_genai, **kwargs)
635638
elif model_type == "image-inpainting":
636-
return load_inpainting_model(model_id, device, ov_options, use_hf, use_genai)
639+
return load_inpainting_model(model_id, device, ov_options, use_hf, use_genai, **kwargs)
637640
elif model_type == "text-embedding":
638641
return load_embedding_model(model_id, device, ov_options, use_hf, use_genai, **kwargs)
639642
elif model_type == "text-reranking":

0 commit comments

Comments
 (0)