派生自 Algorithm/baseDetector

sunty
2022-03-21 d0a24896f95b4e060011852f80048ebfb0bf5f55
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
 
#ifndef _HARDSWISH_H_
#define _HARDSWISH_H_ 
 
#include <string>
#include <vector>
#include "NvInfer.h"
 
namespace nvinfer1
{
    template <typename T>
    void w(char*& buffer, const T& val)
    {
        *reinterpret_cast<T*>(buffer) = val;
        buffer += sizeof(T);
    }
 
    template <typename T>
    void r(const char*& buffer, T& val)
    {
        val = *reinterpret_cast<const T*>(buffer);
        buffer += sizeof(T);
    }
 
    class Hardswish :public IPluginV2
    {
    public:
        Hardswish();
        Hardswish(const void* data, size_t length);
        ~Hardswish();
        int getNbOutputs()const noexcept  override
        {
            return 1;
        }
        Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept  override
        {
            return inputs[0];
        }
        int initialize() noexcept  override
        {
            return 0;
        }
        void terminate() noexcept  override
        {
        }
        size_t getWorkspaceSize(int maxBatchSize) const noexcept  override
        {
            return 0;
        }
 
        bool supportsFormat(DataType type, PluginFormat format) const noexcept override;
        void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) noexcept override;
 
        int enqueue(int batchSize, const void* const* inputs, void* * outputs, void* workspace,
            cudaStream_t stream) noexcept override;
 
        size_t getSerializationSize() const noexcept  override;
        void serialize(void* buffer) const noexcept  override;
        const char* getPluginType() const noexcept  override
        {
            return "HARDSWISH_TRT";
        }
        const char* getPluginVersion() const noexcept  override
        {
            return "1.0";
        }
        void destroy() noexcept  override
        {
            delete this;
        }
        void setPluginNamespace(const char* pluginNamespace) noexcept  override
        {
            _s_plugin_namespace = pluginNamespace;
        }
        const char* getPluginNamespace() const noexcept  override
        {
            return _s_plugin_namespace.c_str();
        }
        DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept
        {
            return DataType::kFLOAT;
        }
        bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const noexcept
        {
            return false;
        }
        bool canBroadcastInputAcrossBatch(int inputIndex) const noexcept
        {
            return false;
        }
        void attachToContext(
            cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) noexcept
        {}
        void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) noexcept;
        void detachFromContext()  noexcept
        {}
        bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const noexcept
        {
            return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
        }
        IPluginV2* clone() const  noexcept override;
    private:
 
        uint32_t _n_max_thread_pre_block;
        uint32_t _n_output_size;
        std::string _s_plugin_namespace;
    }; //end detect
 
    class HardswishPluginCreator : public IPluginCreator
    {
    public:
        HardswishPluginCreator();
        ~HardswishPluginCreator() override = default;
        const char* getPluginName()const noexcept  override;
        const char* getPluginVersion() const  noexcept override;
        const PluginFieldCollection* getFieldNames() noexcept  override;
        IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept  override;
        IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept  override;
        void setPluginNamespace(const char* libNamespace) noexcept  override;
        const char* getPluginNamespace() const noexcept  override;
    private:
        std::string _s_name_space;
        static PluginFieldCollection _fc;
        static std::vector<PluginField> _vec_plugin_attributes;
    };//end detect creator
 
}//end namespace nvinfer1
 
 
 
#endif