Open
Labels
bug (Something isn't working)
Description
Describe the bug
The official usage example for stabilityai/stable-diffusion-xl-refiner-1.0:
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
init_image = load_image(url).convert("RGB")
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt, image=init_image).images
breaks with the following error message:
Traceback (most recent call last):
File "/home/jiqing/HuggingFace/tests/workloads/test_sd.py", line 13, in <module>
image = pipe(prompt, image=init_image).images
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jiqing/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py", line 1483, in __call__
image = self.vae.decode(latents, return_dict=False)[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jiqing/diffusers/src/diffusers/utils/accelerate_utils.py", line 46, in wrapper
return method(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jiqing/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 294, in decode
decoded = self._decode(z).sample
^^^^^^^^^^^^^^^
File "/home/jiqing/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 265, in _decode
dec = self.decoder(z)
^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jiqing/diffusers/src/diffusers/models/autoencoders/vae.py", line 302, in forward
sample = up_block(sample, latent_embeds)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jiqing/diffusers/src/diffusers/models/unets/unet_2d_blocks.py", line 2639, in forward
hidden_states = resnet(hidden_states, temb=temb)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jiqing/diffusers/src/diffusers/models/resnet.py", line 341, in forward
hidden_states = self.conv1(hidden_states)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 553, in forward
return self._conv_forward(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 548, in _conv_forward
return F.conv2d(
^^^^^^^^^
RuntimeError: Input type (c10::Half) and bias type (float) should be the same
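The error names the bias rather than the weight, which suggests that the convolution's weight already matches the fp16 input while its bias is still float32. A small diagnostic along the following lines might help confirm which parameters are out of step. This is only a sketch: `find_dtype_mismatches` is a hypothetical helper (not part of diffusers or PyTorch), and the toy `Conv2d` stands in for the failing VAE layer. Passing `pipe.vae` instead is an untested assumption.

```python
import torch
from torch import nn


def find_dtype_mismatches(module: nn.Module, expected: torch.dtype) -> list[str]:
    """Return names of floating-point parameters whose dtype differs from `expected`.

    Hypothetical helper for debugging; not a diffusers or PyTorch API.
    """
    mismatched = []
    for name, param in module.named_parameters():
        # Skip non-floating-point tensors; only float dtypes can clash here.
        if param.is_floating_point() and param.dtype != expected:
            mismatched.append(name)
    return mismatched


# Toy stand-in for the failing layer: fp16 weight with a float32 bias.
conv = nn.Conv2d(4, 8, kernel_size=3).half()
conv.bias = nn.Parameter(conv.bias.detach().float())

print(find_dtype_mismatches(conv, torch.float16))  # ['bias']

# Re-casting the whole module brings every parameter back to one dtype.
conv.to(torch.float16)
print(find_dtype_mismatches(conv, torch.float16))  # []
```

If the same check on the real pipeline lists VAE decoder parameters, that would point at how the weights are cast on load rather than at the input latents. Whether re-casting the VAE is an acceptable workaround is not verified here.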
The regression was introduced by PR #12512. Should we revert this change? @sayakpaul
Reproduction
Same as the code snippet in the description above.
Logs
System Info
latest version
Who can help?
No response