11using Microsoft . Extensions . Logging ;
2- using Microsoft . ML . OnnxRuntime . Tensors ;
32using OnnxStack . Core ;
43using OnnxStack . Core . Config ;
54using OnnxStack . Core . Image ;
6- using OnnxStack . Core . Services ;
5+ using OnnxStack . Core . Video ;
76using OnnxStack . StableDiffusion . Common ;
87using OnnxStack . StableDiffusion . Config ;
98using OnnxStack . StableDiffusion . Enums ;
109using OnnxStack . StableDiffusion . Models ;
1110using OnnxStack . StableDiffusion . Pipelines ;
1211using OnnxStack . UI . Models ;
13- using SixLabors . ImageSharp ;
1412using SixLabors . ImageSharp . PixelFormats ;
1513using System ;
1614using System . Collections . Concurrent ;
1715using System . Collections . Generic ;
18- using System . IO ;
19- using System . Runtime . CompilerServices ;
2016using System . Threading ;
2117using System . Threading . Tasks ;
2218
@@ -28,7 +24,6 @@ namespace OnnxStack.UI.Services
2824 /// <seealso cref="OnnxStack.StableDiffusion.Common.IStableDiffusionService" />
2925 public sealed class StableDiffusionService : IStableDiffusionService
3026 {
31- private readonly IVideoService _videoService ;
3227 private readonly ILogger < StableDiffusionService > _logger ;
3328 private readonly OnnxStackUIConfig _configuration ;
3429 private readonly Dictionary < IOnnxModel , IPipeline > _pipelines ;
@@ -38,11 +33,10 @@ public sealed class StableDiffusionService : IStableDiffusionService
3833 /// Initializes a new instance of the <see cref="StableDiffusionService"/> class.
3934 /// </summary>
4035 /// <param name="schedulerService">The scheduler service.</param>
41- public StableDiffusionService ( OnnxStackUIConfig configuration , IVideoService videoService , ILogger < StableDiffusionService > logger )
36+ public StableDiffusionService ( OnnxStackUIConfig configuration , ILogger < StableDiffusionService > logger )
4237 {
4338 _logger = logger ;
4439 _configuration = configuration ;
45- _videoService = videoService ;
4640 _pipelines = new Dictionary < IOnnxModel , IPipeline > ( ) ;
4741 _controlNetSessions = new ConcurrentDictionary < IOnnxModel , ControlNetModel > ( ) ;
4842 }
@@ -64,8 +58,6 @@ public async Task<bool> LoadModelAsync(StableDiffusionModelSet model)
6458 }
6559
6660
67-
68-
6961 /// <summary>
7062 /// Unloads the model.
7163 /// </summary>
@@ -95,6 +87,11 @@ public bool IsModelLoaded(StableDiffusionModelSet modelOptions)
9587 }
9688
9789
90+ /// <summary>
91+ /// Loads the model.
92+ /// </summary>
93+ /// <param name="model"></param>
94+ /// <returns></returns>
9895 public async Task < bool > LoadControlNetModelAsync ( ControlNetModelSet model )
9996 {
10097 if ( _controlNetSessions . ContainsKey ( model ) )
@@ -106,6 +103,12 @@ public async Task<bool> LoadControlNetModelAsync(ControlNetModelSet model)
106103 return _controlNetSessions . TryAdd ( model , controlNet ) ;
107104 }
108105
106+
107+ /// <summary>
108+ /// Unloads the model.
109+ /// </summary>
110+ /// <param name="model"></param>
111+ /// <returns></returns>
109112 public Task < bool > UnloadControlNetModelAsync ( ControlNetModelSet model )
110113 {
111114 if ( _controlNetSessions . Remove ( model , out var controlNet ) )
@@ -115,6 +118,14 @@ public Task<bool> UnloadControlNetModelAsync(ControlNetModelSet model)
115118 return Task . FromResult ( true ) ;
116119 }
117120
121+
122+ /// <summary>
123+ /// Determines whether the specified model is loaded
124+ /// </summary>
125+ /// <param name="modelOptions">The model options.</param>
126+ /// <returns>
127+ /// <c>true</c> if the specified model is loaded; otherwise, <c>false</c>.
128+ /// </returns>
118129 public bool IsControlNetModelLoaded ( ControlNetModelSet modelOptions )
119130 {
120131 return _controlNetSessions . ContainsKey ( modelOptions ) ;
@@ -129,164 +140,55 @@ public bool IsControlNetModelLoaded(ControlNetModelSet modelOptions)
129140 /// <param name="progressCallback">The callback used to provide progess of the current InferenceSteps.</param>
130141 /// <param name="cancellationToken">The cancellation token.</param>
131142 /// <returns>The diffusion result as <see cref="DenseTensor<float>"/></returns>
132- public async Task < OnnxImage > GenerateAsync ( ModelOptions model , PromptOptions prompt , SchedulerOptions options , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
133- {
134- return await DiffuseAsync ( model , prompt , options , progressCallback , cancellationToken )
135- . ContinueWith ( t => new OnnxImage ( t . Result ) , cancellationToken )
136- . ConfigureAwait ( false ) ;
137- }
138-
139-
140-
141-
142-
143- /// <summary>
144- /// Generates a batch of StableDiffusion image using the prompt and options provided.
145- /// </summary>
146- /// <param name="modelOptions">The model options.</param>
147- /// <param name="promptOptions">The prompt options.</param>
148- /// <param name="schedulerOptions">The scheduler options.</param>
149- /// <param name="batchOptions">The batch options.</param>
150- /// <param name="progressCallback">The progress callback.</param>
151- /// <param name="cancellationToken">The cancellation token.</param>
152- /// <returns></returns>
153- public IAsyncEnumerable < BatchResult > GenerateBatchAsync ( ModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
154- {
155- return DiffuseBatchAsync ( modelOptions , promptOptions , schedulerOptions , batchOptions , progressCallback , cancellationToken ) ;
156- }
157-
158-
159-
160-
161-
162-
163-
164-
165-
166- /// <summary>
167- /// Runs the diffusion process
168- /// </summary>
169- /// <param name="modelOptions">The model options.</param>
170- /// <param name="promptOptions">The prompt options.</param>
171- /// <param name="schedulerOptions">The scheduler options.</param>
172- /// <param name="progress">The progress.</param>
173- /// <param name="cancellationToken">The cancellation token.</param>
174- /// <returns></returns>
175- /// <exception cref="System.Exception">
176- /// Pipeline not found or is unsupported
177- /// or
178- /// Diffuser not found or is unsupported
179- /// or
180- /// Scheduler '{schedulerOptions.SchedulerType}' is not compatible with the `{pipeline.PipelineType}` pipeline.
181- /// </exception>
182- private async Task < DenseTensor < float > > DiffuseAsync ( ModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , Action < DiffusionProgress > progress = null , CancellationToken cancellationToken = default )
143+ public async Task < OnnxImage > GenerateImageAsync ( ModelOptions model , PromptOptions prompt , SchedulerOptions options , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
183144 {
184- if ( ! _pipelines . TryGetValue ( modelOptions . BaseModel , out var pipeline ) )
145+ if ( ! _pipelines . TryGetValue ( model . BaseModel , out var pipeline ) )
185146 throw new Exception ( "Pipeline not found or is unsupported" ) ;
186147
187148 var controlNet = default ( ControlNetModel ) ;
188- if ( modelOptions . ControlNetModel is not null && ! _controlNetSessions . TryGetValue ( modelOptions . ControlNetModel , out controlNet ) )
149+ if ( model . ControlNetModel is not null && ! _controlNetSessions . TryGetValue ( model . ControlNetModel , out controlNet ) )
189150 throw new Exception ( "ControlNet not loaded" ) ;
190151
191- pipeline . ValidateInputs ( promptOptions , schedulerOptions ) ;
152+ pipeline . ValidateInputs ( prompt , options ) ;
192153
193- await GenerateInputVideoFrames ( promptOptions , progress ) ;
194- return await pipeline . RunAsync ( promptOptions , schedulerOptions , controlNet , progress , cancellationToken ) ;
154+ return await pipeline . GenerateImageAsync ( prompt , options , controlNet , progressCallback , cancellationToken ) ;
195155 }
196156
197157
198158 /// <summary>
199- /// Runs the batch diffusion process .
159+ /// Generates the StableDiffusion video using the prompt and options provided .
200160 /// </summary>
201- /// <param name="modelOptions">The model options.</param>
202- /// <param name="promptOptions">The prompt options.</param>
203- /// <param name="schedulerOptions">The scheduler options.</param>
204- /// <param name="batchOptions">The batch options.</param>
205- /// <param name="progress">The progress.</param>
161+ /// <param name="model">The model.</param>
162+ /// <param name="prompt">The prompt.</param>
163+ /// <param name="options">The options.</param>
164+ /// <param name="progressCallback">The progress callback.</param>
206165 /// <param name="cancellationToken">The cancellation token.</param>
207166 /// <returns></returns>
208167 /// <exception cref="System.Exception">
209168 /// Pipeline not found or is unsupported
210169 /// or
211- /// Diffuser not found or is unsupported
212- /// or
213- /// Scheduler '{schedulerOptions.SchedulerType}' is not compatible with the `{pipeline.PipelineType}` pipeline.
170+ /// ControlNet not loaded
214171 /// </exception>
215- private async IAsyncEnumerable < BatchResult > DiffuseBatchAsync ( ModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < DiffusionProgress > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
172+ public async Task < OnnxVideo > GenerateVideoAsync ( ModelOptions model , PromptOptions prompt , SchedulerOptions options , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
216173 {
217- if ( ! _pipelines . TryGetValue ( modelOptions . BaseModel , out var pipeline ) )
174+ if ( ! _pipelines . TryGetValue ( model . BaseModel , out var pipeline ) )
218175 throw new Exception ( "Pipeline not found or is unsupported" ) ;
219176
220177 var controlNet = default ( ControlNetModel ) ;
221- if ( modelOptions . ControlNetModel is not null && ! _controlNetSessions . TryGetValue ( modelOptions . ControlNetModel , out controlNet ) )
178+ if ( model . ControlNetModel is not null && ! _controlNetSessions . TryGetValue ( model . ControlNetModel , out controlNet ) )
222179 throw new Exception ( "ControlNet not loaded" ) ;
223180
224- pipeline . ValidateInputs ( promptOptions , schedulerOptions ) ;
181+ pipeline . ValidateInputs ( prompt , options ) ;
225182
226- await GenerateInputVideoFrames ( promptOptions , progressCallback ) ;
227- await foreach ( var result in pipeline . RunBatchAsync ( batchOptions , promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) )
228- {
229- yield return result ;
230- }
183+ return await pipeline . GenerateVideoAsync ( prompt , options , controlNet , progressCallback , cancellationToken ) ;
231184 }
232185
233186
234187 /// <summary>
235- /// Generates the video result as bytes .
188+ /// Creates the pipeline .
236189 /// </summary>
237- /// <param name="options">The options.</param>
238- /// <param name="videoTensor">The video tensor.</param>
239- /// <param name="progress">The progress.</param>
240- /// <param name="cancellationToken">The cancellation token.</param>
241- /// <returns></returns>
242- private async Task < byte [ ] > GenerateVideoResultAsBytesAsync ( DenseTensor < float > videoTensor , float videoFPS , Action < DiffusionProgress > progress = null , CancellationToken cancellationToken = default )
243- {
244- progress ? . Invoke ( new DiffusionProgress ( "Generating Video Result..." ) ) ;
245- var videoResult = await _videoService . CreateVideoAsync ( videoTensor , videoFPS , cancellationToken ) ;
246- return videoResult . Data ;
247- }
248-
249-
250- /// <summary>
251- /// Generates the video result as stream.
252- /// </summary>
253- /// <param name="options">The options.</param>
254- /// <param name="videoTensor">The video tensor.</param>
255- /// <param name="progress">The progress.</param>
256- /// <param name="cancellationToken">The cancellation token.</param>
190+ /// <param name="model">The model.</param>
257191 /// <returns></returns>
258- private async Task < MemoryStream > GenerateVideoResultAsStreamAsync ( DenseTensor < float > videoTensor , float videoFPS , Action < DiffusionProgress > progress = null , CancellationToken cancellationToken = default )
259- {
260- return new MemoryStream ( await GenerateVideoResultAsBytesAsync ( videoTensor , videoFPS , progress , cancellationToken ) ) ;
261- }
262-
263-
264- /// <summary>
265- /// Generates the input video frames.
266- /// </summary>
267- /// <param name="promptOptions">The prompt options.</param>
268- /// <param name="progress">The progress.</param>
269- private async Task GenerateInputVideoFrames ( PromptOptions promptOptions , Action < DiffusionProgress > progress )
270- {
271- if ( ! promptOptions . HasInputVideo || promptOptions . InputVideo . VideoFrames is not null )
272- return ;
273-
274- if ( promptOptions . VideoInputFPS == 0 || promptOptions . VideoOutputFPS == 0 )
275- {
276- var videoInfo = await _videoService . GetVideoInfoAsync ( promptOptions . InputVideo ) ;
277- if ( promptOptions . VideoInputFPS == 0 )
278- promptOptions . VideoInputFPS = videoInfo . FPS ;
279-
280- if ( promptOptions . VideoOutputFPS == 0 )
281- promptOptions . VideoOutputFPS = videoInfo . FPS ;
282- }
283-
284- var videoFrame = await _videoService . CreateFramesAsync ( promptOptions . InputVideo , promptOptions . VideoInputFPS ) ;
285- progress ? . Invoke ( new DiffusionProgress ( $ "Generating video frames @ { promptOptions . VideoInputFPS } fps") ) ;
286- promptOptions . InputVideo . VideoFrames = videoFrame ;
287- }
288-
289-
290192 private IPipeline CreatePipeline ( StableDiffusionModelSet model )
291193 {
292194 return model . PipelineType switch
0 commit comments