Skip to content

Commit eeefdf4

Browse files
committed
feat(webgl): add slice op
1 parent c0b8f6a commit eeefdf4

File tree

5 files changed

+196
-7
lines changed

5 files changed

+196
-7
lines changed

packages/paddlejs-backend-webgl/src/ops/index.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ import pool2d_avg from './shader/pool2d_avg';
5757
import density_prior_box from './shader/density_prior_box';
5858
import box_coder from './shader/box_coder';
5959
import prior_box from './shader/prior_box';
60+
import stack from './shader/stack';
61+
import slice from './shader/slice';
6062

6163
import {
6264
imgFeed, pack_out, nhwc_2_nchw, unpacked_2_packed,
@@ -125,6 +127,7 @@ const ops = {
125127
pow: dynamic('pow'),
126128
sqrt: dynamic('sqrt'),
127129
tanh: dynamic('tanh'),
130+
exp: dynamic('exp'),
128131
squeeze2,
129132
pad3d,
130133
bilinear_interp_v2,
@@ -135,7 +138,9 @@ const ops = {
135138
imgFeed,
136139
box_coder,
137140
density_prior_box,
138-
prior_box
141+
prior_box,
142+
stack,
143+
slice
139144
};
140145
export {
141146
ops
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/**
2+
* @file slice
3+
* @description slice op,目前 case 很单一,暂时仅支持遇到的 case:slice final dim
4+
* @example x = [[1,2,3,4],[5,6,7,8]] axes=[1] starts=[2] ends=[3] => out [3,7]
5+
*/
6+
7+
function mainFunc(
8+
{
9+
out, origin
10+
},
11+
{
12+
axes,
13+
starts,
14+
ends,
15+
decrease_axis
16+
}
17+
) {
18+
19+
if (
20+
axes.length > 1
21+
|| starts.length > 1
22+
|| ends.length > 1
23+
|| (decrease_axis && decrease_axis.length === 0)
24+
) {
25+
throw Error('[slice op feature]: current support one dim, support decrease_axis');
26+
}
27+
const {
28+
width_shape,
29+
height_shape,
30+
channel,
31+
total_shape,
32+
length_unformatted_shape
33+
} = origin;
34+
const batch = total_shape / (width_shape * height_shape * channel);
35+
const tensor_shape = [batch, channel, height_shape, width_shape];
36+
37+
let axis = axes[0];
38+
39+
if (axis < 0) {
40+
axis = axis + length_unformatted_shape + 1;
41+
}
42+
axis = 4 - length_unformatted_shape + axis;
43+
44+
if (axis !== 4) {
45+
throw Error('[slice op feature]: unsupport axis value');
46+
}
47+
48+
const start = starts[0];
49+
const end = ends[0];
50+
51+
const [
52+
batch_num,
53+
channel_num,
54+
height_num,
55+
width_num
56+
] = tensor_shape;
57+
58+
// 计算 output tensor value 对应的 origin index
59+
const res_pos = [];
60+
for (let index = start; index < end; index++) {
61+
for (let batch = 0; batch < batch_num; batch++) {
62+
for (let channel = 0; channel < channel_num; channel++) {
63+
for (let height = 0; height < height_num; height++) {
64+
res_pos.push(
65+
batch * channel_num * height_num * width_num
66+
+ channel * height_num * width_num
67+
+ height * width_num + index
68+
);
69+
}
70+
}
71+
}
72+
}
73+
// 生成 glsl arr => int arr = int[](x, x, x,... x);
74+
const glslIndexArr = res_pos.reduce((acc, cur, index) => {
75+
const tmp = index < res_pos.length - 1 ? `${cur}, ` : `${cur});`;
76+
return acc + tmp;
77+
}, 'int arr[] = int[](');
78+
79+
return `
80+
void main(void) {
81+
ivec4 oPos = getOutputTensorPos();
82+
${glslIndexArr}
83+
// 输出坐标转换为输入坐标
84+
int sumVal = oPos.a
85+
+ oPos.b * ${out.width_shape}
86+
+ oPos.g * ${out.height_shape} * ${out.width_shape}
87+
+ oPos.r * ${out.channel} * ${out.width_shape} * ${out.height_shape};
88+
89+
int index = arr[sumVal];
90+
91+
float res = 0.0;
92+
ivec4 co = getTensorPosFromArrayIndex_origin(index);
93+
res = getValueFromTensorPos_origin(co.r, co.g, co.b, co.a);
94+
setOutput(float(res));
95+
}
96+
`;
97+
}
98+
export default {
99+
mainFunc,
100+
textureFuncConf: {
101+
origin: ['getValueFromTensorPos', 'getTensorPosFromArrayIndex']
102+
}
103+
};

packages/paddlejs-backend-webgl/src/ops/shader/stack.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* @detail https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/stack_cn.html#stack
66
*/
77

8-
/* eslint-disable max-lines */
8+
99
function mainFunc(
1010
{ out, ...inputs },
1111
attrs
@@ -21,7 +21,7 @@ function mainFunc(
2121
const batch = total_shape / (width_shape * height_shape * channel);
2222

2323
const tensor_shape = [batch, channel, height_shape, width_shape];
24-
const origin_shape = tensor_shape.slice(length_unformatted_shape);
24+
const origin_shape = tensor_shape.slice(4 - length_unformatted_shape);
2525

2626
const inputs_num = Object.keys(inputs).length;
2727

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
{
2+
"chunkNum": 0,
3+
"ops": [
4+
{
5+
"attrs": {},
6+
"inputs": {
7+
"X": [
8+
"feed"
9+
]
10+
},
11+
"outputs": {
12+
"Out": [
13+
"image"
14+
]
15+
},
16+
"type": "feed"
17+
},
18+
{
19+
"attrs": {
20+
"axes": [-1],
21+
"starts": [1],
22+
"ends": [2]
23+
},
24+
"inputs": {
25+
"X": [
26+
"slice_in"
27+
]
28+
},
29+
"outputs": {
30+
"Out": [
31+
"slice_out"
32+
]
33+
},
34+
"type": "slice"
35+
},
36+
{
37+
"attrs": {
38+
"op_device": ""
39+
},
40+
"inputs": {
41+
"X": [
42+
"slice_out"
43+
]
44+
},
45+
"outputs": {
46+
"Out": [
47+
"fetch"
48+
]
49+
},
50+
"type": "fetch"
51+
}
52+
],
53+
"vars": [
54+
{
55+
"name": "slice_in",
56+
"shape": [
57+
1,
58+
6,
59+
4
60+
],
61+
"data": [
62+
0.1924618, -1.1302841,-0.991921,-0.41128427,
63+
0.99852794, -0.16766216, -0.8959643, -0.63337845,
64+
0.1924618, -12.1302841,-0.991921,-0.41128427,
65+
0.99852794, -20.16766216, -0.8959643, -0.63337845,
66+
0.1924618, -31.1302841,-0.991921,-0.41128427,
67+
0.99852794, -30.16766216, -0.8959643, -0.63337845
68+
],
69+
"persistable": true
70+
},
71+
{
72+
"name": "slice_out",
73+
"shape": [1, 6],
74+
"data": [
75+
-1.130284070968628, -0.16766215860843658,
76+
-12.130284309387207, -20.167661666870117,
77+
-31.13028335571289, -30.167661666870117
78+
]
79+
}
80+
]
81+
}

packages/paddlejs-backend-webgl/test/op/opTest.js

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { Runner } from '@paddlejs/paddlejs-core';
22
import { glInstance } from '../../src/index';
33

4-
const opName = 'prior_box';
4+
const opName = 'slice';
55
const modelDir = '/test/op/data/';
66
const modelPath = `${modelDir}${opName}.json`;
77

@@ -10,13 +10,13 @@ async function run() {
1010
const runner = new Runner({
1111
modelPath,
1212
feedShape: {
13-
fw: 6,
14-
fh: 9
13+
fw: 3,
14+
fh: 3
1515
},
1616
needPreheat: false
1717
});
1818
await runner.init();
19-
const executeOP = runner.weightMap[0];
19+
const executeOP = runner.weightMap[1];
2020
runner.executeOp(executeOP);
2121
const res = await glInstance.read();
2222
console.log(res);

0 commit comments

Comments
 (0)