#ifndef CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_
#define CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

// Applies the Box-Cox transform column-wise to an N x D input matrix. For
// column j with parameters lambda1[j] and lambda2[j], each element x maps to
// ((x + lambda2)^lambda1 - 1) / lambda1 when lambda1 != 0, and to
// log(x + lambda2) when lambda1 == 0.
template <class Context>
class BatchBoxCoxOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit BatchBoxCoxOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        min_block_size_(
            this->template GetSingleArgument<int>("min_block_size", 256)) {}

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, double>>::call(this, Input(DATA));
  }

  template <typename T>
  bool DoRunWithType();

 protected:
  template <typename T>
  void BoxCoxNaive(
      int64_t N,
      int64_t D,
      const T* data_ptr,
      const T* lambda1_ptr,
      const T* lambda2_ptr,
      T k_eps,
      T* output_ptr);

#ifdef CAFFE2_USE_MKL
  template <typename T>
  void BoxCoxNonzeroLambda(
      int64_t D,
      const T* data_ptr,
      const T* lambda1,
      const T* lambda2,
      T k_eps,
      T* output_ptr);

  template <typename T>
  void BoxCoxZeroLambda(
      int64_t D,
      const T* data_ptr,
      const T* lambda2,
      T k_eps,
      T* output_ptr);

  template <typename T>
  void BoxCoxMixedLambda(
      const T* data_ptr,
      const vector<int>& nonzeros,
      const vector<int>& zeros,
      const T* lambda1,
      const T* lambda2,
      const T* lambda2_z,
      T k_eps,
      T* buffer,
      T* output_ptr);

  vector<int> nonzeros_, zeros_;

  // Buffers used by the MKL version are cached across calls.
  struct CachedBuffers {
    virtual ~CachedBuffers() {}
    int type_;
  };
  template <typename T>
  struct TypedCachedBuffers : public CachedBuffers {
    vector<T> lambda1_, lambda2_, lambda2_z_;
    vector<T> accumulator_;
  };
  template <typename T>
  TypedCachedBuffers<T>& GetBuffers();
  unique_ptr<CachedBuffers> buffers_;
#endif // CAFFE2_USE_MKL

  int min_block_size_;

  INPUT_TAGS(DATA, LAMBDA1, LAMBDA2);
};

} // namespace caffe2

#endif // CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_
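
// A minimal sketch (not part of the original header) of how the naive kernel
// declared above could evaluate the transform, assuming the standard Box-Cox
// form with k_eps clamping the argument of log/pow away from zero:
//
//   template <typename T>
//   void BoxCoxNaive(int64_t N, int64_t D, const T* data, const T* lambda1,
//                    const T* lambda2, T k_eps, T* out) {
//     for (int64_t i = 0; i < N; ++i) {
//       for (int64_t j = 0; j < D; ++j, ++data, ++out) {
//         const T l1 = lambda1[j];
//         const T tmp = std::max(*data + lambda2[j], k_eps);
//         *out = (l1 == 0) ? std::log(tmp) : (std::pow(tmp, l1) - 1) / l1;
//       }
//     }
//   }
//
// The MKL-specialized members above would follow the same math, but split the
// columns into zero-lambda1 and nonzero-lambda1 groups so each group can be
// processed with vectorized primitives over min_block_size_-sized blocks.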