33using Microsoft . ML . OnnxRuntime . Tensors ;
44using OnnxStack . Core ;
55using OnnxStack . Core . Config ;
6+ using OnnxStack . Core . Image ;
67using OnnxStack . Core . Model ;
78using OnnxStack . Core . Services ;
89using OnnxStack . StableDiffusion . Common ;
@@ -113,15 +114,38 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(StableDiffusionModelS
113114 // Process prompts
114115 var promptEmbeddings = await _promptService . CreatePromptAsync ( modelOptions , promptOptions , performGuidance ) ;
115116
117+ // If video input, process frames
118+ if ( promptOptions . HasInputVideo )
119+ {
120+ var frameIndex = 0 ;
121+ DenseTensor < float > videoTensor = null ;
122+ var videoFrames = promptOptions . InputVideo . VideoFrames . Frames ;
123+ var schedulerFrameCallback = CreateBatchCallback ( progressCallback , videoFrames . Count , ( ) => frameIndex ) ;
124+ foreach ( var videoFrame in videoFrames )
125+ {
126+ frameIndex ++ ;
127+ promptOptions . InputImage = new InputImage ( videoFrame ) ;
128+ var frameResultTensor = await SchedulerStepAsync ( modelOptions , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , schedulerFrameCallback , cancellationToken ) ;
129+
130+ // Frame Progress
131+ ReportBatchProgress ( progressCallback , frameIndex , videoFrames . Count , frameResultTensor ) ;
132+
133+ // Concatenate frame
134+ videoTensor = videoTensor . Concatenate ( frameResultTensor ) ;
135+ }
136+
137+ _logger ? . LogEnd ( $ "Diffuse complete", diffuseTime ) ;
138+ return videoTensor ;
139+ }
140+
116141 // Run Scheduler steps
117142 var schedulerResult = await SchedulerStepAsync ( modelOptions , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
118-
119143 _logger ? . LogEnd ( $ "Diffuse complete", diffuseTime ) ;
120-
121144 return schedulerResult ;
122145 }
123146
124147
148+
125149 /// <summary>
126150 /// Runs the stable diffusion batch loop
127151 /// </summary>
@@ -152,15 +176,11 @@ public virtual async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffu
152176 var batchSchedulerOptions = BatchGenerator . GenerateBatch ( modelOptions , batchOptions , schedulerOptions ) ;
153177
154178 var batchIndex = 1 ;
155- var schedulerCallback = ( DiffusionProgress progress ) => progressCallback ? . Invoke ( new DiffusionProgress ( batchIndex , batchSchedulerOptions . Count , progress . ProgressTensor )
156- {
157- SubProgressMax = progress . ProgressMax ,
158- SubProgressValue = progress . ProgressValue ,
159- } ) ;
179+ var batchSchedulerCallback = CreateBatchCallback ( progressCallback , batchSchedulerOptions . Count , ( ) => batchIndex ) ;
160180 foreach ( var batchSchedulerOption in batchSchedulerOptions )
161181 {
162182 var diffuseTime = _logger ? . LogBegin ( "Diffuse starting..." ) ;
163- yield return new BatchResult ( batchSchedulerOption , await SchedulerStepAsync ( modelOptions , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , schedulerCallback , cancellationToken ) ) ;
183+ yield return new BatchResult ( batchSchedulerOption , await SchedulerStepAsync ( modelOptions , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , batchSchedulerCallback , cancellationToken ) ) ;
164184 _logger ? . LogEnd ( $ "Diffuse complete", diffuseTime ) ;
165185 batchIndex ++ ;
166186 }
@@ -264,9 +284,14 @@ protected static IReadOnlyList<NamedOnnxValue> CreateInputParameters(params Name
/// <summary>
/// Reports step-level diffusion progress to the supplied callback, if any.
/// </summary>
/// <param name="progressCallback">The progress callback; when null, no report is made.</param>
/// <param name="progress">The current step value.</param>
/// <param name="progressMax">The maximum step value.</param>
/// <param name="progressTensor">The intermediate result tensor for the current step.</param>
protected void ReportProgress(Action<DiffusionProgress> progressCallback, int progress, int progressMax, DenseTensor<float> progressTensor)
{
    // Nothing to do when no callback was supplied.
    if (progressCallback == null)
        return;

    var stepProgress = new DiffusionProgress
    {
        StepMax = progressMax,
        StepValue = progress,
        StepTensor = progressTensor
    };
    progressCallback(stepProgress);
}
271296
272297
@@ -279,13 +304,31 @@ protected void ReportProgress(Action<DiffusionProgress> progressCallback, int pr
/// <summary>
/// Reports batch-level diffusion progress to the supplied callback, if any.
/// </summary>
/// <param name="progressCallback">The progress callback; when null, no report is made.</param>
/// <param name="progress">The current batch value (e.g. frame or batch index).</param>
/// <param name="progressMax">The maximum batch value (total frames/batches).</param>
/// <param name="progressTensor">The result tensor for the completed batch item.</param>
protected void ReportBatchProgress(Action<DiffusionProgress> progressCallback, int progress, int progressMax, DenseTensor<float> progressTensor)
{
    // Nothing to do when no callback was supplied.
    if (progressCallback == null)
        return;

    var batchProgress = new DiffusionProgress
    {
        BatchMax = progressMax,
        BatchValue = progress,
        BatchTensor = progressTensor
    };
    progressCallback(batchProgress);
}
316+
317+
/// <summary>
/// Wraps an outer progress callback so that step progress reported by the scheduler
/// is augmented with batch progress (current batch index and total batch count).
/// </summary>
/// <param name="progressCallback">The outer progress callback; may be null.</param>
/// <param name="batchCount">The total number of items in the batch.</param>
/// <param name="batchIndex">Deferred accessor returning the current batch index, evaluated at report time so the wrapper tracks the caller's loop counter.</param>
/// <returns>A callback suitable for passing to the scheduler, or null when <paramref name="progressCallback"/> is null.</returns>
private static Action<DiffusionProgress> CreateBatchCallback(Action<DiffusionProgress> progressCallback, int batchCount, Func<int> batchIndex)
{
    // Explicit null: callers skip reporting entirely when no outer callback exists.
    // (Original returned `progressCallback` here, which was null but obscured intent.)
    if (progressCallback == null)
        return null;

    // progressCallback is known non-null past the guard, so invoke it directly;
    // the original's `?.Invoke` was redundant.
    return progress => progressCallback(new DiffusionProgress
    {
        StepMax = progress.StepMax,
        StepValue = progress.StepValue,
        StepTensor = progress.StepTensor,
        BatchMax = batchCount,
        BatchValue = batchIndex()
    });
}
332+
290333 }
291334}
0 commit comments