Skip to content

Commit c2d0925

Browse files
committed
Video Interpolation pipeline
1 parent 071d741 commit c2d0925

File tree

19 files changed

+507
-64
lines changed

19 files changed

+507
-64
lines changed

TensorStack.Common/ModelSession.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,10 +237,12 @@ protected virtual void Dispose(bool disposing)
237237

238238
if (disposing)
239239
{
240-
_allocator?.Dispose();
241-
_options?.Dispose();
242240
_session?.Dispose();
241+
_options?.Dispose();
242+
_allocator?.Dispose();
243243
_session = null;
244+
_options = null;
245+
_allocator = null;
244246
}
245247

246248
disposed = true;
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System.Collections.Generic;
2+
using System.Threading;
3+
4+
namespace TensorStack.Common.Video
5+
{
6+
public class VideoStream : IAsyncEnumerable<VideoFrame>
7+
{
8+
private readonly IAsyncEnumerable<VideoFrame> _stream;
9+
private readonly int _width;
10+
private readonly int _height;
11+
private readonly float _frameRate;
12+
private readonly int _frameCount;
13+
14+
public VideoStream(IAsyncEnumerable<VideoFrame> stream, int frameCount, float frameRate, int width, int height)
15+
{
16+
_stream = stream;
17+
_frameCount = frameCount;
18+
_frameRate = frameRate;
19+
_width = width;
20+
_height = height;
21+
}
22+
23+
public int Width => _width;
24+
public int Height => _height;
25+
public float FrameRate => _frameRate;
26+
public int FrameCount => _frameCount;
27+
28+
public IAsyncEnumerator<VideoFrame> GetAsyncEnumerator(CancellationToken cancellationToken = default)
29+
{
30+
return _stream.GetAsyncEnumerator(cancellationToken);
31+
}
32+
}
33+
}

TensorStack.Extractors/Common/BackgroundImageOptions.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3-
using TensorStack.Common;
43
using TensorStack.Common.Pipeline;
54
using TensorStack.Common.Tensor;
65

TensorStack.Extractors/Common/ExtractorImageOptions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ namespace TensorStack.Extractors.Common
66
{
77
public record ExtractorImageOptions : ExtractorOptions
88
{
9-
public ImageTensor Input { get; init; }
9+
public ImageTensor Image { get; init; }
1010
}
1111
}
Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3-
using System.Collections.Generic;
4-
using TensorStack.Common;
53
using TensorStack.Common.Video;
64

75
namespace TensorStack.Extractors.Common
86
{
97
public record ExtractorStreamOptions : ExtractorOptions
108
{
11-
public IAsyncEnumerable<VideoFrame> Input { get; }
9+
public VideoStream Stream { get; }
1210
}
1311
}

TensorStack.Extractors/Common/ExtractorVideoOptions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ namespace TensorStack.Extractors.Common
77
{
88
public record ExtractorVideoOptions : ExtractorOptions
99
{
10-
public VideoTensor Input { get; }
10+
public VideoTensor Video { get; }
1111
}
1212
}

TensorStack.Extractors/Pipelines/ExtractorPipeline.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,16 @@ public async Task<ImageTensor> RunAsync(ExtractorImageOptions options, IProgress
6565
{
6666
var timestamp = RunProgress.GetTimestamp();
6767
if (_extractorModel.Normalization == Normalization.ZeroToOne)
68-
options.Input.NormalizeZeroToOne();
68+
options.Image.NormalizeZeroToOne();
6969

70-
var resultTensor = await ExtractInternalAsync(options.Input, options, cancellationToken);
70+
var resultTensor = await ExtractInternalAsync(options.Image, options, cancellationToken);
7171
NormalizeResult(resultTensor, options.IsInverted);
7272

7373
if (_extractorModel.Normalization == Normalization.ZeroToOne)
74-
options.Input.NormalizeOneToOne();
74+
options.Image.NormalizeOneToOne();
7575

7676
if (options.MergeInput)
77-
resultTensor = MergeResult(options.Input, resultTensor);
77+
resultTensor = MergeResult(options.Image, resultTensor);
7878

7979
progressCallback?.Report(new RunProgress(timestamp));
8080
return resultTensor;
@@ -93,10 +93,10 @@ public async Task<VideoTensor> RunAsync(ExtractorVideoOptions options, IProgress
9393
{
9494
var timestamp = RunProgress.GetTimestamp();
9595
if (_extractorModel.Normalization == Normalization.ZeroToOne)
96-
options.Input.NormalizeZeroToOne();
96+
options.Video.NormalizeZeroToOne();
9797

9898
var results = new List<ImageTensor>();
99-
foreach (var frame in options.Input.GetFrames())
99+
foreach (var frame in options.Video.GetFrames())
100100
{
101101
var frameTime = Stopwatch.GetTimestamp();
102102
var resultTensor = await ExtractInternalAsync(frame, options, cancellationToken);
@@ -107,10 +107,10 @@ public async Task<VideoTensor> RunAsync(ExtractorVideoOptions options, IProgress
107107
resultTensor = MergeResult(frame, resultTensor);
108108

109109
results.Add(resultTensor);
110-
progressCallback?.Report(new RunProgress(results.Count, options.Input.Frames, frameTime));
110+
progressCallback?.Report(new RunProgress(results.Count, options.Video.Frames, frameTime));
111111
}
112112

113-
var resultVideoTensor = new VideoTensor(results.Join(), options.Input.FrameRate);
113+
var resultVideoTensor = new VideoTensor(results.Join(), options.Video.FrameRate);
114114
progressCallback?.Report(new RunProgress(timestamp));
115115
return resultVideoTensor;
116116
}
@@ -129,7 +129,7 @@ public async IAsyncEnumerable<VideoFrame> RunAsync(ExtractorStreamOptions option
129129
{
130130
var frameCount = 0;
131131
var timestamp = RunProgress.GetTimestamp();
132-
await foreach (var videoFrame in options.Input)
132+
await foreach (var videoFrame in options.Stream)
133133
{
134134
var frameTime = Stopwatch.GetTimestamp();
135135
if (_extractorModel.Normalization == Normalization.ZeroToOne)

TensorStack.Upscaler/Common/UpscaleImageOptions.cs

Lines changed: 0 additions & 14 deletions
This file was deleted.

TensorStack.Upscaler/Common/UpscaleOptions.cs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// Licensed under the Apache 2.0 License.
33
using TensorStack.Common;
44
using TensorStack.Common.Pipeline;
5+
using TensorStack.Common.Tensor;
6+
using TensorStack.Common.Video;
57

68
namespace TensorStack.Upscaler.Common
79
{
@@ -25,4 +27,43 @@ public abstract record UpscaleOptions : IRunOptions
2527
/// </summary>
2628
public int TileOverlap { get; init; }
2729
}
30+
31+
32+
33+
/// <summary>
34+
/// Image UpscaleOptions.
35+
/// </summary>
36+
public sealed record UpscaleImageOptions : UpscaleOptions
37+
{
38+
/// <summary>
39+
/// Gets the image input.
40+
/// </summary>
41+
public ImageTensor Image { get; init; }
42+
}
43+
44+
45+
46+
/// <summary>
47+
/// Video UpscaleOptions.
48+
/// </summary>
49+
public sealed record UpscaleVideoOptions : UpscaleOptions
50+
{
51+
/// <summary>
52+
/// Gets the video input.
53+
/// </summary>
54+
public VideoTensor Video { get; init; }
55+
}
56+
57+
58+
59+
/// <summary>
60+
/// Stream UpscaleOptions.
61+
/// </summary>
62+
public sealed record UpscaleStreamOptions : UpscaleOptions
63+
{
64+
/// <summary>
65+
/// Gets the stream input.
66+
/// </summary>
67+
public VideoStream Stream { get; init; }
68+
}
2869
}

TensorStack.Upscaler/Common/UpscaleStreamOptions.cs

Lines changed: 0 additions & 15 deletions
This file was deleted.

0 commit comments

Comments
 (0)