派生自 Algorithm/baseDetector

sunty
2022-03-21 d0a24896f95b4e060011852f80048ebfb0bf5f55
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#ifndef _DETECT_H_
#define _DETECT_H_
 
#include <string>
#include <vector>
#include "NvInfer.h"
 
namespace nvinfer1
{
    template <typename T>
    void write(char*& buffer, const T& val)
    {
        *reinterpret_cast<T*>(buffer) = val;
        buffer += sizeof(T);
    }
 
    template <typename T>
    void read(const char*& buffer, T& val)
    {
        val = *reinterpret_cast<const T*>(buffer);
        buffer += sizeof(T);
    }
 
    class Detect :public IPluginV2
    {
    public:
        Detect();
        Detect(const void* data, size_t length);
        Detect(const uint32_t n_anchor_, const uint32_t _n_classes_,
            const uint32_t n_grid_h_, const uint32_t n_grid_w_/*,
            const uint32_t &n_stride_h_, const uint32_t &n_stride_w_*/);
        ~Detect();
        int getNbOutputs()const noexcept override
        {
            return 1;
        }
        Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept  override
        {
            return inputs[0];
        }
        int initialize() noexcept  override
        {
            return 0;
        }
        void terminate() noexcept  override
        {
        }
        size_t getWorkspaceSize(int maxBatchSize) const noexcept  override
        {
            return 0;
        }
        int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace,
            cudaStream_t stream) noexcept override;
               
 
        bool supportsFormat(DataType type, PluginFormat format) const noexcept override;
        void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) noexcept override;
 
        size_t getSerializationSize() const noexcept  override;
        void serialize(void* buffer) const noexcept  override;
        const char* getPluginType() const noexcept  override
        {
            return "DETECT_TRT";
        }
        const char* getPluginVersion() const noexcept  override
        {
            return "1.0";
        }
        void destroy() noexcept  override
        {
            delete this;
        }
        void setPluginNamespace(const char* pluginNamespace) noexcept  override
        {
            _s_plugin_namespace = pluginNamespace;
        }
        const char* getPluginNamespace() const  noexcept override
        {
            return _s_plugin_namespace.c_str();
        }
        DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept
        {
            return DataType::kFLOAT;
        }
        bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const noexcept
        {
            return false;
        }
        bool canBroadcastInputAcrossBatch(int inputIndex) const noexcept
        {
            return false;
        }
        void attachToContext(
            cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator)
        {}
        void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) ;
        void detachFromContext()
        {}
        bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const noexcept
        {
            return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
        }
        IPluginV2* clone() const noexcept override;
    private:
        
        uint32_t _n_anchor;
        uint32_t _n_classes;
        uint32_t _n_grid_h;
        uint32_t _n_grid_w;
        //uint32_t _n_stride_h;
    //    uint32_t _n_stride_w;
        uint64_t _n_output_size;
        std::string _s_plugin_namespace;
    }; //end detect
 
    class DetectPluginCreator : public IPluginCreator
    {
    public:
        DetectPluginCreator();
        ~DetectPluginCreator() override = default;
        const char* getPluginName()const noexcept  override;
        const char* getPluginVersion() const  noexcept override;
        const PluginFieldCollection* getFieldNames() noexcept  override;
        IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept  override;
        IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept  override;
        void setPluginNamespace(const char* libNamespace)  noexcept override;
        const char* getPluginNamespace() const noexcept  override;
    private:
        std::string _s_name_space;
        static PluginFieldCollection _fc;
        static std::vector<PluginField> _vec_plugin_attributes;
    };//end detect creator
 
}//end namespace nvinfer1
 
 
 
#endif