@@ -81,6 +81,70 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst,
8181 dst[index] = result;
8282}
8383
namespace bicubic_interpolation {
// Cubic convolution kernel, as described in
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
// The free parameter of the kernel is fixed at alpha = -0.75 (same as PyTorch).
__device__ const float a = -0.75f;

// Kernel polynomial (a+2)*x^3 - (a+3)*x^2 + 1, written in nested form.
// bicubic() below calls it with x and with 1 - x.
static __device__ float weight1(float x) {
    const float slope = (a + 2) * x - (a + 3);
    return slope * x * x + 1;
}

// Kernel polynomial a*x^3 - 5a*x^2 + 8a*x - 4a, written in Horner form.
// bicubic() below calls it with x + 1 and with 2 - x.
static __device__ float weight2(float x) {
    const float horner = (a * x - 5 * a) * x + 8 * a;
    return horner * x - 4 * a;
}

// One-dimensional cubic interpolation across four consecutive samples.
//   p0..p3 : the samples, in order
//   x      : argument passed to the weight polynomials; the upscale kernel in
//            this file passes the fractional offset of the sampling position
//            past p1
// Returns the weighted sum of the four samples.
static __device__ float bicubic(float p0, float p1, float p2, float p3, float x) {
    const float w_outer_lo = weight2(x + 1);
    const float w_inner_lo = weight1(x + 0);
    const float w_inner_hi = weight1(1 - x);
    const float w_outer_hi = weight2(2 - x);
    return p0 * w_outer_lo + p1 * w_inner_lo + p2 * w_inner_hi + p3 * w_outer_hi;
}
} // namespace bicubic_interpolation
99+
// Bicubic upscale of an f32 tensor; each thread produces one destination element.
//
//   x, dst            : source / destination buffers; dst is written at the flat
//                       element index, so it is treated as densely packed
//   nb00..nb03        : source strides in BYTES for dims 0..3 (applied to a char pointer)
//   ne00_src, ne01_src: source extent along dims 0 and 1, used to clamp the taps
//   ne10_dst..ne13_dst: destination extents along dims 0..3
//   sf0..sf3          : per-dimension scale factors; destination coordinates are
//                       divided by them to obtain source coordinates
//   pixel_offset      : offset added before and subtracted after the division for
//                       dims 0 and 1 (the dispatcher in this file passes 0.5f, or
//                       0.0f when GGML_SCALE_FLAG_ALIGN_CORNERS is set)
//
// Expects a 1D launch whose total thread count is at least the number of destination
// elements; surplus threads exit immediately.
// Dims 0 and 1 are interpolated over a 4x4 neighbourhood whose coordinates are
// clamped to the valid source range; dims 2 and 3 use a truncated (nearest) index.
static __global__ void upscale_f32_bicubic(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne00_src, const int ne01_src,
        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
        const float sf0, const float sf1, const float sf2, const float sf3,
        const float pixel_offset) {
    using bicubic_interpolation::bicubic;

    // Widen before multiplying: the products below can exceed the 32-bit range
    // even though every individual extent fits in an int.
    const int64_t index              = threadIdx.x + (int64_t) blockIdx.x * blockDim.x;
    const int64_t ne_plane_dst       = (int64_t) ne10_dst * ne11_dst;
    const int64_t ne_volume_dst      = ne_plane_dst * ne12_dst;
    const int64_t dst_total_elements = ne_volume_dst * ne13_dst;

    if (index >= dst_total_elements) {
        return;
    }

    // Decompose the flat index into 4D destination coordinates (dim 0 fastest).
    const int i10_dst = index % ne10_dst;
    const int i11_dst = (index / ne10_dst) % ne11_dst;
    const int i12_dst = (index / ne_plane_dst) % ne12_dst;
    const int i13_dst = index / ne_volume_dst;

    // Dims 2 and 3 are not interpolated: pick the source index by truncation.
    const int i02_src = (int)(i12_dst / sf2);
    const int i03_src = (int)(i13_dst / sf3);

    // Continuous source coordinate along dim 1, split into integer base and fraction.
    const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
    const int   y0_src  = (int)floorf(y_src_f);
    const float dy      = y_src_f - (float)y0_src;

    // Same for dim 0.
    const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
    const int   x0_src  = (int)floorf(x_src_f);
    const float dx      = x_src_f - (float)x0_src;

    // Byte address of the selected (dim 2, dim 3) source slice.
    const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03;

    // Fetch the source sample at (x0_src + x_off, y0_src + y_off), clamping both
    // coordinates to the valid range so taps outside the image repeat the edge value.
    auto load = [=](int x_off, int y_off) -> float {
        int i00_src = max(0, min(x0_src + x_off, ne00_src - 1));
        int i01_src = max(0, min(y0_src + y_off, ne01_src - 1));
        return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01);
    };

    // Separable evaluation: interpolate each of the four rows along dim 0,
    // then interpolate the four row results along dim 1.
    const float result = bicubic(
        bicubic(load(-1,-1), load(0,-1), load(1,-1), load(2,-1), dx),
        bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx),
        bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx),
        bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx), dy);

    dst[index] = result;
}
147+
84148static void upscale_f32_cuda (const float * x, float * dst,
85149 const int nb00, const int nb01, const int nb02, const int nb03,
86150 const int ne10, const int ne11, const int ne12, const int ne13,
@@ -104,6 +168,18 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst,
104168 upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0 ,stream>>> (x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
105169}
106170
// Host-side launcher for the upscale_f32_bicubic kernel.
//
// Launches a 1D grid of CUDA_UPSCALE_BLOCK_SIZE-thread blocks on `stream`, with
// enough blocks to cover every destination element. All tensor arguments are
// forwarded to the kernel unchanged. The launch is asynchronous with respect to
// the host; this function does not synchronize or check for launch errors.
static void upscale_f32_bicubic_cuda(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne00_src, const int ne01_src,
        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
        const float sf0, const float sf1, const float sf2, const float sf3,
        const float pixel_offset, cudaStream_t stream) {
    // Widen the first factor so the whole product is evaluated in 64 bits;
    // a plain int product overflows once the tensor exceeds 2^31 - 1 elements.
    const int64_t dst_size = (int64_t) ne10_dst * ne11_dst * ne12_dst * ne13_dst;
    if (dst_size <= 0) {
        // Nothing to write; a grid of zero blocks is not a valid launch configuration.
        return;
    }

    // Ceiling division: the last block may be only partially used.
    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;

    upscale_f32_bicubic<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(
        x, dst, nb00, nb01, nb02, nb03,
        ne00_src, ne01_src,
        ne10_dst, ne11_dst, ne12_dst, ne13_dst,
        sf0, sf1, sf2, sf3, pixel_offset);
}
182+
107183void ggml_cuda_op_upscale (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
108184 const ggml_tensor * src0 = dst->src [0 ];
109185 const float * src0_d = (const float *)src0->data ;
@@ -121,17 +197,22 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
121197 float sf2 = (float )dst->ne [2 ]/src0->ne [2 ];
122198 const float sf3 = (float )dst->ne [3 ]/src0->ne [3 ];
123199
200+ float pixel_offset = 0 .5f ;
201+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
202+ sf0 = dst->ne [0 ] > 1 && src0->ne [0 ] > 1 ? (float )(dst->ne [0 ] - 1 ) / (src0->ne [0 ] - 1 ) : sf0;
203+ sf1 = dst->ne [1 ] > 1 && src0->ne [1 ] > 1 ? (float )(dst->ne [1 ] - 1 ) / (src0->ne [1 ] - 1 ) : sf1;
204+ pixel_offset = 0 .0f ;
205+ }
206+
124207 if (mode == GGML_SCALE_MODE_NEAREST) {
125208 upscale_f32_cuda (src0_d, dst_d, src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ], dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ], sf0, sf1, sf2, sf3, stream);
126209 } else if (mode == GGML_SCALE_MODE_BILINEAR) {
127- float pixel_offset = 0 .5f ;
128- if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
129- sf0 = dst->ne [0 ] > 1 && src0->ne [0 ] > 1 ? (float )(dst->ne [0 ] - 1 ) / (src0->ne [0 ] - 1 ) : sf0;
130- sf1 = dst->ne [1 ] > 1 && src0->ne [1 ] > 1 ? (float )(dst->ne [1 ] - 1 ) / (src0->ne [1 ] - 1 ) : sf1;
131- pixel_offset = 0 .0f ;
132- }
133210 upscale_f32_bilinear_cuda (src0_d, dst_d, src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ],
134211 src0->ne [0 ], src0->ne [1 ], dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ],
135212 sf0, sf1, sf2, sf3, pixel_offset, stream);
213+ } else if (mode == GGML_SCALE_MODE_BICUBIC) {
214+ upscale_f32_bicubic_cuda (src0_d, dst_d, src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ],
215+ src0->ne [0 ], src0->ne [1 ], dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ],
216+ sf0, sf1, sf2, sf3, pixel_offset, stream);
136217 }
137218}
0 commit comments