Skip to content

Commit 76b26f0

Browse files
committed
Upload tensorrt custom ops (#39)
1 parent 7e4b1c4 commit 76b26f0

File tree

4 files changed

+577
-0
lines changed

4 files changed

+577
-0
lines changed
Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
// Copyright (c) OpenMMLab. All rights reserved
2+
#include "trt_deform_conv_v3.hpp"
3+
4+
#include <assert.h>
5+
6+
#include <chrono>
7+
8+
#include "trt_deform_conv_v3_kernel.hpp"
9+
#include "trt_plugin_helper.hpp"
10+
#include "trt_serialize.hpp"
11+
using namespace nvinfer1;
12+
13+
namespace mmdeploy {
14+
namespace {
15+
static const char *PLUGIN_VERSION{"1"};
16+
static const char *PLUGIN_NAME{"TRTDCNv3"};
17+
} // namespace
18+
19+
TRTDCNv3::TRTDCNv3(const std::string &name, int kernel_h, int kernel_w, int stride_h, int stride_w,
20+
int pad_h, int pad_w, int dilation_h, int dilation_w, int group,
21+
int group_channels, float offset_scale, int im2col_step)
22+
: TRTPluginBase(name),
23+
kernel_h_(kernel_h),
24+
kernel_w_(kernel_w),
25+
stride_h_(stride_h),
26+
stride_w_(stride_w),
27+
pad_h_(pad_h),
28+
pad_w_(pad_w),
29+
dilation_h_(dilation_h),
30+
dilation_w_(dilation_w),
31+
group_(group),
32+
group_channels_(group_channels),
33+
offset_scale_(offset_scale),
34+
im2col_step_(im2col_step) {}
35+
36+
TRTDCNv3::TRTDCNv3(const std::string name, const void *data, size_t length) : TRTPluginBase(name) {
37+
deserialize_value(&data, &length, &kernel_h_);
38+
deserialize_value(&data, &length, &kernel_w_);
39+
deserialize_value(&data, &length, &stride_h_);
40+
deserialize_value(&data, &length, &stride_w_);
41+
deserialize_value(&data, &length, &pad_h_);
42+
deserialize_value(&data, &length, &pad_w_);
43+
deserialize_value(&data, &length, &dilation_h_);
44+
deserialize_value(&data, &length, &dilation_w_);
45+
deserialize_value(&data, &length, &group_);
46+
deserialize_value(&data, &length, &group_channels_);
47+
deserialize_value(&data, &length, &offset_scale_);
48+
deserialize_value(&data, &length, &im2col_step_);
49+
}
50+
51+
nvinfer1::IPluginV2DynamicExt *TRTDCNv3::clone() const TRT_NOEXCEPT {
52+
TRTDCNv3 *plugin =
53+
new TRTDCNv3(mLayerName, kernel_h_, kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_,
54+
dilation_h_, dilation_w_, group_, group_channels_, offset_scale_, im2col_step_);
55+
plugin->setPluginNamespace(getPluginNamespace());
56+
57+
return plugin;
58+
}
59+
60+
const nvinfer1::IDimensionExpr *output_size(const nvinfer1::IDimensionExpr &input, int pad,
61+
int dilation, int kernel, int stride,
62+
nvinfer1::IExprBuilder &exprBuilder) {
63+
// out_expand = 2×padding[0]−dilation[0]×(kernel_size[0]−1)+1
64+
auto out_expand = exprBuilder.constant(2 * pad - dilation * (kernel - 1) + 1);
65+
// out = out_expand + input
66+
auto out_before_div = exprBuilder.operation(DimensionOperation::kSUM, input, *out_expand);
67+
// out = out / stride
68+
auto out_before_sub = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *out_before_div,
69+
*(exprBuilder.constant(stride)));
70+
// out -=1
71+
auto out =
72+
exprBuilder.operation(DimensionOperation::kSUB, *out_before_sub, *(exprBuilder.constant(1)));
73+
return out;
74+
}
75+
76+
nvinfer1::DimsExprs TRTDCNv3::getOutputDimensions(
77+
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
78+
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
79+
nvinfer1::DimsExprs ret;
80+
ret.nbDims = 4;
81+
ret.d[0] = inputs[0].d[0];
82+
ret.d[3] = exprBuilder.constant(group_ * group_channels_);
83+
84+
ret.d[1] = output_size(*inputs[0].d[1], pad_h_, dilation_h_, kernel_h_, stride_h_, exprBuilder);
85+
ret.d[2] = output_size(*inputs[0].d[2], pad_w_, dilation_w_, kernel_w_, stride_w_, exprBuilder);
86+
87+
return ret;
88+
}
89+
90+
bool TRTDCNv3::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc,
91+
int nbInputs, int nbOutputs) TRT_NOEXCEPT {
92+
if (pos == 0) {
93+
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
94+
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
95+
96+
} else {
97+
return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
98+
}
99+
}
100+
101+
void TRTDCNv3::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
102+
const nvinfer1::DynamicPluginTensorDesc *outputs,
103+
int nbOutputs) TRT_NOEXCEPT {}
104+
105+
size_t TRTDCNv3::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
106+
const nvinfer1::PluginTensorDesc *outputs,
107+
int nbOutputs) const TRT_NOEXCEPT {
108+
int sizeof_dtype = mmdeploy::getElementSize(outputs[0].type);
109+
110+
int batch_size = inputs[0].dims.d[0];
111+
int nInputPlane = inputs[0].dims.d[3];
112+
int inputHeight = inputs[0].dims.d[1];
113+
int inputWidth = inputs[0].dims.d[2];
114+
115+
int nOutputPlane = outputs[0].dims.d[3];
116+
int outputHeight = outputs[0].dims.d[1];
117+
int outputWidth = outputs[0].dims.d[2];
118+
119+
int kW = inputs[3].dims.d[1];
120+
int kH = inputs[3].dims.d[2];
121+
int im2col_step = std::min(32, batch_size);
122+
123+
size_t col_size =
124+
mmdeploy::getAlignedSize(nInputPlane * kW * kH * outputHeight * outputWidth * sizeof_dtype);
125+
126+
return col_size;
127+
}
128+
129+
int TRTDCNv3::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
130+
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
131+
void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
132+
int batch = inputDesc[0].dims.d[0];
133+
int height = inputDesc[0].dims.d[1];
134+
int width = inputDesc[0].dims.d[2];
135+
int channels = inputDesc[0].dims.d[3];
136+
137+
int height_out = outputDesc[0].dims.d[1];
138+
int width_out = outputDesc[0].dims.d[2];
139+
int channels_out = outputDesc[0].dims.d[3];
140+
141+
const void *input = inputs[0];
142+
const void *offset = inputs[1];
143+
const void *mask = inputs[2];
144+
void *output = outputs[0];
145+
146+
// TODO: add fp16 support
147+
auto data_type = inputDesc[0].type;
148+
switch (data_type) {
149+
case nvinfer1::DataType::kFLOAT:
150+
DeformConvv3ForwardCUDAKernelLauncher<float>(
151+
(float *)input, (float *)offset, (float *)mask, (float *)output, workSpace, batch,
152+
channels, height, width, channels_out, kernel_w_, kernel_h_, stride_w_, stride_h_, pad_w_,
153+
pad_h_, dilation_w_, dilation_h_, group_, group_channels_, offset_scale_, im2col_step_,
154+
stream);
155+
break;
156+
default:
157+
return 1;
158+
break;
159+
}
160+
161+
return 0;
162+
}
163+
164+
nvinfer1::DataType TRTDCNv3::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
165+
int nbInputs) const TRT_NOEXCEPT {
166+
return inputTypes[0];
167+
}
168+
169+
// IPluginV2 Methods
170+
const char *TRTDCNv3::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }
171+
172+
const char *TRTDCNv3::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
173+
174+
int TRTDCNv3::getNbOutputs() const TRT_NOEXCEPT { return 1; }
175+
176+
size_t TRTDCNv3::getSerializationSize() const TRT_NOEXCEPT {
177+
return serialized_size(kernel_h_) + serialized_size(kernel_w_) + serialized_size(stride_h_) +
178+
serialized_size(stride_w_) + serialized_size(pad_h_) + serialized_size(pad_w_) +
179+
serialized_size(dilation_h_) + serialized_size(dilation_w_) + serialized_size(group_) +
180+
serialized_size(group_channels_) + serialized_size(offset_scale_) +
181+
serialized_size(im2col_step_);
182+
}
183+
184+
void TRTDCNv3::serialize(void *buffer) const TRT_NOEXCEPT {
185+
serialize_value(&buffer, kernel_h_);
186+
serialize_value(&buffer, kernel_w_);
187+
serialize_value(&buffer, stride_h_);
188+
serialize_value(&buffer, stride_w_);
189+
serialize_value(&buffer, pad_h_);
190+
serialize_value(&buffer, pad_w_);
191+
serialize_value(&buffer, dilation_h_);
192+
serialize_value(&buffer, dilation_w_);
193+
serialize_value(&buffer, group_);
194+
serialize_value(&buffer, group_channels_);
195+
serialize_value(&buffer, offset_scale_);
196+
serialize_value(&buffer, im2col_step_);
197+
}
198+
199+
////////////////////// creator /////////////////////////////
200+
201+
TRTDCNv3Creator::TRTDCNv3Creator() {
202+
mPluginAttributes.clear();
203+
mPluginAttributes.emplace_back(nvinfer1::PluginField("kernel_h"));
204+
mPluginAttributes.emplace_back(nvinfer1::PluginField("kernel_w"));
205+
mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_h"));
206+
mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_w"));
207+
mPluginAttributes.emplace_back(nvinfer1::PluginField("pad_h"));
208+
mPluginAttributes.emplace_back(nvinfer1::PluginField("pad_w"));
209+
mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation_h"));
210+
mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation_w"));
211+
mPluginAttributes.emplace_back(nvinfer1::PluginField("group"));
212+
mPluginAttributes.emplace_back(nvinfer1::PluginField("group_channels"));
213+
mPluginAttributes.emplace_back(nvinfer1::PluginField("offset_scale"));
214+
mPluginAttributes.emplace_back(nvinfer1::PluginField("im2col_step"));
215+
mFC.nbFields = mPluginAttributes.size();
216+
mFC.fields = mPluginAttributes.data();
217+
}
218+
219+
const char *TRTDCNv3Creator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }
220+
221+
const char *TRTDCNv3Creator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
222+
223+
nvinfer1::IPluginV2 *TRTDCNv3Creator::createPlugin(
224+
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
225+
nvinfer1::Dims size{2, {1, 1}};
226+
int kernel_h = 3;
227+
int kernel_w = 3;
228+
int stride_h = 1;
229+
int stride_w = 1;
230+
int pad_h = 1;
231+
int pad_w = 1;
232+
int dilation_h = 1;
233+
int dilation_w = 1;
234+
int group = 28;
235+
int group_channels = 16;
236+
float offset_scale = 1;
237+
int im2col_step = 256;
238+
239+
for (int i = 0; i < fc->nbFields; i++) {
240+
if (fc->fields[i].data == nullptr) {
241+
continue;
242+
}
243+
std::string field_name(fc->fields[i].name);
244+
245+
if (field_name.compare("kernel_h") == 0) {
246+
kernel_h = static_cast<const int *>(fc->fields[i].data)[0];
247+
}
248+
if (field_name.compare("kernel_w") == 0) {
249+
kernel_w = static_cast<const int *>(fc->fields[i].data)[0];
250+
}
251+
if (field_name.compare("stride_h") == 0) {
252+
stride_h = static_cast<const int *>(fc->fields[i].data)[0];
253+
}
254+
if (field_name.compare("stride_w") == 0) {
255+
stride_w = static_cast<const int *>(fc->fields[i].data)[0];
256+
}
257+
if (field_name.compare("pad_h") == 0) {
258+
pad_h = static_cast<const int *>(fc->fields[i].data)[0];
259+
}
260+
if (field_name.compare("pad_w") == 0) {
261+
pad_w = static_cast<const int *>(fc->fields[i].data)[0];
262+
}
263+
if (field_name.compare("dilation_h") == 0) {
264+
dilation_h = static_cast<const int *>(fc->fields[i].data)[0];
265+
}
266+
if (field_name.compare("dilation_w") == 0) {
267+
dilation_w = static_cast<const int *>(fc->fields[i].data)[0];
268+
}
269+
if (field_name.compare("group") == 0) {
270+
group = static_cast<const int *>(fc->fields[i].data)[0];
271+
}
272+
if (field_name.compare("group_channels") == 0) {
273+
group_channels = static_cast<const int *>(fc->fields[i].data)[0];
274+
}
275+
if (field_name.compare("offset_scale") == 0) {
276+
offset_scale = static_cast<const float *>(fc->fields[i].data)[0];
277+
}
278+
if (field_name.compare("im2col_step") == 0) {
279+
im2col_step = static_cast<const int *>(fc->fields[i].data)[0];
280+
}
281+
}
282+
283+
TRTDCNv3 *plugin =
284+
new TRTDCNv3(name, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h,
285+
dilation_w, group, group_channels, offset_scale, im2col_step);
286+
plugin->setPluginNamespace(getPluginNamespace());
287+
return plugin;
288+
}
289+
290+
nvinfer1::IPluginV2 *TRTDCNv3Creator::deserializePlugin(const char *name, const void *serialData,
291+
size_t serialLength) TRT_NOEXCEPT {
292+
auto plugin = new TRTDCNv3(name, serialData, serialLength);
293+
plugin->setPluginNamespace(getPluginNamespace());
294+
return plugin;
295+
}
296+
REGISTER_TENSORRT_PLUGIN(TRTDCNv3Creator);
297+
} // namespace mmdeploy
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#ifndef TRT_DEFORM_CONV_V3_HPP
2+
#define TRT_DEFORM_CONV_V3_HPP
3+
#include <cublas_v2.h>
4+
5+
#include <memory>
6+
#include <string>
7+
#include <vector>
8+
9+
#include "trt_plugin_base.hpp"
10+
namespace mmdeploy {
11+
class TRTDCNv3 : public TRTPluginBase {
12+
public:
13+
TRTDCNv3(const std::string &name, int kernel_h, int kernel_w, int stride_h, int stride_w,
14+
int pad_h, int pad_w, int dilation_h, int dilation_w, int group, int group_channels,
15+
float offset_scale, int im2col_step);
16+
17+
TRTDCNv3(const std::string name, const void *data, size_t length);
18+
19+
TRTDCNv3() = delete;
20+
21+
// IPluginV2DynamicExt Methods
22+
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
23+
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
24+
int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
25+
TRT_NOEXCEPT override;
26+
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
27+
int nbOutputs) TRT_NOEXCEPT override;
28+
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
29+
const nvinfer1::DynamicPluginTensorDesc *out,
30+
int nbOutputs) TRT_NOEXCEPT override;
31+
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
32+
const nvinfer1::PluginTensorDesc *outputs,
33+
int nbOutputs) const TRT_NOEXCEPT override;
34+
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
35+
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
36+
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
37+
38+
// IPluginV2Ext Methods
39+
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
40+
int nbInputs) const TRT_NOEXCEPT override;
41+
42+
// IPluginV2 Methods
43+
const char *getPluginType() const TRT_NOEXCEPT override;
44+
const char *getPluginVersion() const TRT_NOEXCEPT override;
45+
int getNbOutputs() const TRT_NOEXCEPT override;
46+
size_t getSerializationSize() const TRT_NOEXCEPT override;
47+
void serialize(void *buffer) const TRT_NOEXCEPT override;
48+
49+
private:
50+
int kernel_h_;
51+
int kernel_w_;
52+
int stride_h_;
53+
int stride_w_;
54+
int pad_h_;
55+
int pad_w_;
56+
int dilation_h_;
57+
int dilation_w_;
58+
int group_;
59+
int group_channels_;
60+
float offset_scale_;
61+
int im2col_step_;
62+
};
63+
64+
class TRTDCNv3Creator : public TRTPluginCreatorBase {
65+
public:
66+
TRTDCNv3Creator();
67+
68+
const char *getPluginName() const TRT_NOEXCEPT override;
69+
70+
const char *getPluginVersion() const TRT_NOEXCEPT override;
71+
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
72+
TRT_NOEXCEPT override;
73+
74+
nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
75+
size_t serialLength) TRT_NOEXCEPT override;
76+
};
77+
} // namespace mmdeploy
78+
#endif // TRT_DEFORM_CONV_V3_HPP

0 commit comments

Comments
 (0)