11using Microsoft . ML . OnnxRuntime ;
2- using OnnxStack . Core ;
32using OnnxStack . StableDiffusion . Config ;
3+ using OnnxStack . StableDiffusion . Enums ;
44using System ;
5- using System . Collections . Generic ;
65using System . Linq ;
76
87namespace OnnxStack . StableDiffusion
98{
10- internal static class Extensions
9+ public static class Extensions
1110 {
1211 /// <summary>
1312 /// Gets the first element and casts it to the specified type.
1413 /// </summary>
1514 /// <typeparam name="T">Desired return type</typeparam>
1615 /// <param name="collection">The collection.</param>
1716 /// <returns>Firts element in the collection cast as <see cref="T"/></returns>
18- public static T FirstElementAs < T > ( this IDisposableReadOnlyCollection < DisposableNamedOnnxValue > collection )
17+ internal static T FirstElementAs < T > ( this IDisposableReadOnlyCollection < DisposableNamedOnnxValue > collection )
1918 {
2019 if ( collection is null || collection . Count == 0 )
2120 return default ;
@@ -34,7 +33,7 @@ public static T FirstElementAs<T>(this IDisposableReadOnlyCollection<DisposableN
3433 /// <typeparam name="T">Desired return type</typeparam>
3534 /// <param name="collection">The collection.</param>
3635 /// <returns>Last element in the collection cast as <see cref="T"/></returns>
37- public static T LastElementAs < T > ( this IDisposableReadOnlyCollection < DisposableNamedOnnxValue > collection )
36+ internal static T LastElementAs < T > ( this IDisposableReadOnlyCollection < DisposableNamedOnnxValue > collection )
3837 {
3938 if ( collection is null || collection . Count == 0 )
4039 return default ;
@@ -53,7 +52,7 @@ public static T LastElementAs<T>(this IDisposableReadOnlyCollection<DisposableNa
5352 /// <param name="options">The options.</param>
5453 /// <returns></returns>
5554 /// <exception cref="System.ArgumentOutOfRangeException">Width must be divisible by 64</exception>
56- public static int GetScaledWidth ( this SchedulerOptions options )
55+ internal static int GetScaledWidth ( this SchedulerOptions options )
5756 {
5857 if ( options . Width % 64 > 0 )
5958 throw new ArgumentOutOfRangeException ( nameof ( options . Width ) , $ "{ nameof ( options . Width ) } must be divisible by 64") ;
@@ -68,7 +67,7 @@ public static int GetScaledWidth(this SchedulerOptions options)
6867 /// <param name="options">The options.</param>
6968 /// <returns></returns>
7069 /// <exception cref="System.ArgumentOutOfRangeException">Height must be divisible by 64</exception>
71- public static int GetScaledHeight ( this SchedulerOptions options )
70+ internal static int GetScaledHeight ( this SchedulerOptions options )
7271 {
7372 if ( options . Height % 64 > 0 )
7473 throw new ArgumentOutOfRangeException ( nameof ( options . Height ) , $ "{ nameof ( options . Height ) } must be divisible by 64") ;
@@ -84,9 +83,36 @@ public static int GetScaledHeight(this SchedulerOptions options)
8483 /// <param name="batch">The batch.</param>
8584 /// <param name="channels">The channels.</param>
8685 /// <returns>Tensor dimension of [batch, channels, (Height / 8), (Width / 8)]</returns>
87- public static int [ ] GetScaledDimension ( this SchedulerOptions options , int batch = 1 , int channels = 4 )
86+ internal static int [ ] GetScaledDimension ( this SchedulerOptions options , int batch = 1 , int channels = 4 )
8887 {
8988 return new [ ] { batch , channels , options . GetScaledHeight ( ) , options . GetScaledWidth ( ) } ;
9089 }
90+
91+
92+ /// <summary>
93+ /// Gets the pipeline schedulers.
94+ /// </summary>
95+ /// <param name="pipelineType">Type of the pipeline.</param>
96+ /// <returns></returns>
97+ public static SchedulerType [ ] GetSchedulerTypes ( this DiffuserPipelineType pipelineType )
98+ {
99+ return pipelineType switch
100+ {
101+ DiffuserPipelineType . StableDiffusion => new [ ]
102+ {
103+ SchedulerType . LMS ,
104+ SchedulerType . Euler ,
105+ SchedulerType . EulerAncestral ,
106+ SchedulerType . DDPM ,
107+ SchedulerType . DDIM ,
108+ SchedulerType . KDPM2
109+ } ,
110+ DiffuserPipelineType . LatentConsistency => new [ ]
111+ {
112+ SchedulerType . LCM
113+ } ,
114+ _ => default
115+ } ;
116+ }
91117 }
92118}
0 commit comments