1515
1616import random
1717import tempfile
18- import unittest
1918
2019import numpy as np
2120import PIL
3433from ..test_modular_pipelines_common import ModularPipelineTesterMixin
3534
3635
37- class FluxModularTests :
36+ class TestFluxModularPipelineFast ( ModularPipelineTesterMixin ) :
3837 pipeline_class = FluxModularPipeline
3938 pipeline_blocks_class = FluxAutoBlocks
4039 repo = "hf-internal-testing/tiny-flux-modular"
4140
42- def get_pipeline (self , components_manager = None , torch_dtype = torch .float32 ):
43- pipeline = self .pipeline_blocks_class ().init_pipeline (self .repo , components_manager = components_manager )
44- pipeline .load_components (torch_dtype = torch_dtype )
45- return pipeline
41+ params = frozenset (["prompt" , "height" , "width" , "guidance_scale" ])
42+ batch_params = frozenset (["prompt" ])
4643
47- def get_dummy_inputs (self , device , seed = 0 ):
48- if str (device ).startswith ("mps" ):
49- generator = torch .manual_seed (seed )
50- else :
51- generator = torch .Generator (device = device ).manual_seed (seed )
44+ def get_dummy_inputs (self , seed = 0 ):
45+ generator = self .get_generator (seed )
5246 inputs = {
5347 "prompt" : "A painting of a squirrel eating a burger" ,
5448 "generator" : generator ,
@@ -57,36 +51,47 @@ def get_dummy_inputs(self, device, seed=0):
5751 "height" : 8 ,
5852 "width" : 8 ,
5953 "max_sequence_length" : 48 ,
60- "output_type" : "np " ,
54+ "output_type" : "pt " ,
6155 }
6256 return inputs
6357
6458
65- class FluxModularPipelineFastTests ( FluxModularTests , ModularPipelineTesterMixin , unittest . TestCase ):
66- params = frozenset ([ "prompt" , "height" , "width" , "guidance_scale" ])
67- batch_params = frozenset ([ "prompt" ])
68-
59+ class TestFluxImg2ImgModularPipelineFast ( ModularPipelineTesterMixin ):
60+ pipeline_class = FluxModularPipeline
61+ pipeline_blocks_class = FluxAutoBlocks
62+ repo = "hf-internal-testing/tiny-flux-modular"
6963
70- class FluxImg2ImgModularPipelineFastTests (FluxModularTests , ModularPipelineTesterMixin , unittest .TestCase ):
7164 params = frozenset (["prompt" , "height" , "width" , "guidance_scale" , "image" ])
7265 batch_params = frozenset (["prompt" , "image" ])
7366
7467 def get_pipeline (self , components_manager = None , torch_dtype = torch .float32 ):
7568 pipeline = super ().get_pipeline (components_manager , torch_dtype )
69+
7670 # Override `vae_scale_factor` here as currently, `image_processor` is initialized with
7771 # fixed constants instead of
7872 # https://github.com/huggingface/diffusers/blob/d54622c2679d700b425ad61abce9b80fc36212c0/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L230C9-L232C10
7973 pipeline .image_processor = VaeImageProcessor (vae_scale_factor = 2 )
8074 return pipeline
8175
82- def get_dummy_inputs (self , device , seed = 0 ):
83- inputs = super ().get_dummy_inputs (device , seed )
84- image = floats_tensor ((1 , 3 , 32 , 32 ), rng = random .Random (seed )).to (device )
85- image = image / 2 + 0.5
86- inputs ["image" ] = image
87- inputs ["strength" ] = 0.8
88- inputs ["height" ] = 8
89- inputs ["width" ] = 8
76+ def get_dummy_inputs (self , seed = 0 ):
77+ generator = self .get_generator (seed )
78+ inputs = {
79+ "prompt" : "A painting of a squirrel eating a burger" ,
80+ "generator" : generator ,
81+ "num_inference_steps" : 4 ,
82+ "guidance_scale" : 5.0 ,
83+ "height" : 8 ,
84+ "width" : 8 ,
85+ "max_sequence_length" : 48 ,
86+ "output_type" : "pt" ,
87+ }
88+ image = floats_tensor ((1 , 3 , 32 , 32 ), rng = random .Random (seed )).to (torch_device )
89+ image = image .cpu ().permute (0 , 2 , 3 , 1 )[0 ]
90+ init_image = PIL .Image .fromarray (np .uint8 (image )).convert ("RGB" )
91+
92+ inputs ["image" ] = init_image
93+ inputs ["strength" ] = 0.5
94+
9095 return inputs
9196
9297 def test_save_from_pretrained (self ):
@@ -96,6 +101,7 @@ def test_save_from_pretrained(self):
96101
97102 with tempfile .TemporaryDirectory () as tmpdirname :
98103 base_pipe .save_pretrained (tmpdirname )
104+
99105 pipe = ModularPipeline .from_pretrained (tmpdirname ).to (torch_device )
100106 pipe .load_components (torch_dtype = torch .float32 )
101107 pipe .to (torch_device )
@@ -105,26 +111,62 @@ def test_save_from_pretrained(self):
105111
106112 image_slices = []
107113 for pipe in pipes :
108- inputs = self .get_dummy_inputs (torch_device )
114+ inputs = self .get_dummy_inputs ()
109115 image = pipe (** inputs , output = "images" )
110116
111117 image_slices .append (image [0 , - 3 :, - 3 :, - 1 ].flatten ())
112118
113- assert np .abs (image_slices [0 ] - image_slices [1 ]).max () < 1e-3
119+ assert torch .abs (image_slices [0 ] - image_slices [1 ]).max () < 1e-3
114120
115121
116- class FluxKontextModularPipelineFastTests ( FluxImg2ImgModularPipelineFastTests ):
122+ class TestFluxKontextModularPipelineFast ( ModularPipelineTesterMixin ):
117123 pipeline_class = FluxKontextModularPipeline
118124 pipeline_blocks_class = FluxKontextAutoBlocks
119125 repo = "hf-internal-testing/tiny-flux-kontext-pipe"
120126
121- def get_dummy_inputs (self , device , seed = 0 ):
122- inputs = super ().get_dummy_inputs (device , seed )
127+ params = frozenset (["prompt" , "height" , "width" , "guidance_scale" , "image" ])
128+ batch_params = frozenset (["prompt" , "image" ])
129+
130+ def get_dummy_inputs (self , seed = 0 ):
131+ generator = self .get_generator (seed )
132+ inputs = {
133+ "prompt" : "A painting of a squirrel eating a burger" ,
134+ "generator" : generator ,
135+ "num_inference_steps" : 2 ,
136+ "guidance_scale" : 5.0 ,
137+ "height" : 8 ,
138+ "width" : 8 ,
139+ "max_sequence_length" : 48 ,
140+ "output_type" : "pt" ,
141+ }
123142 image = PIL .Image .new ("RGB" , (32 , 32 ), 0 )
124- _ = inputs . pop ( "strength" )
143+
125144 inputs ["image" ] = image
126- inputs ["height" ] = 8
127- inputs ["width" ] = 8
128- inputs ["max_area" ] = 8 * 8
145+ inputs ["max_area" ] = inputs ["height" ] * inputs ["width" ]
129146 inputs ["_auto_resize" ] = False
147+
130148 return inputs
149+
150+ def test_save_from_pretrained (self ):
151+ pipes = []
152+ base_pipe = self .get_pipeline ().to (torch_device )
153+ pipes .append (base_pipe )
154+
155+ with tempfile .TemporaryDirectory () as tmpdirname :
156+ base_pipe .save_pretrained (tmpdirname )
157+
158+ pipe = ModularPipeline .from_pretrained (tmpdirname ).to (torch_device )
159+ pipe .load_components (torch_dtype = torch .float32 )
160+ pipe .to (torch_device )
161+ pipe .image_processor = VaeImageProcessor (vae_scale_factor = 2 )
162+
163+ pipes .append (pipe )
164+
165+ image_slices = []
166+ for pipe in pipes :
167+ inputs = self .get_dummy_inputs ()
168+ image = pipe (** inputs , output = "images" )
169+
170+ image_slices .append (image [0 , - 3 :, - 3 :, - 1 ].flatten ())
171+
172+ assert torch .abs (image_slices [0 ] - image_slices [1 ]).max () < 1e-3
0 commit comments