#ifndef CAFFE2_MPI_MPI_OPS_H_
#define CAFFE2_MPI_MPI_OPS_H_

#include <mpi.h>

#include "caffe2/core/operator.h"
#include "caffe2/mpi/mpi_common.h"

namespace caffe2 {

// TODO(jiayq): if needed, write up the use of color and key with MPI split.
// Currently, the operator simply creates a communicator that has the
// same topology as the Caffe2 global communicator.
template <class Context>
class MPICreateCommonWorldOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  MPICreateCommonWorldOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  bool RunOnDevice() override {
    OperatorBase::Outputs()[0]->Reset(new MPICommonWorldWrapper());
    return true;
  }
};

template <class Context>
class MPIBroadcastOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  MPIBroadcastOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        root_(OperatorBase::template GetSingleArgument<int>("root", 0)) {}
  ~MPIBroadcastOp() {}

  bool RunOnDevice() override {
    MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
    CAFFE_ENFORCE(
        OperatorBase::OutputIsTensorType(0, Context::GetDeviceType()),
        "Output is of wrong type.");
    auto* output = Output(0);
    // Make sure that output is already allocated.
    CAFFE_ENFORCE(
        output->numel() > 0,
        "Broadcast op uses in-place operation so the output "
        "should be already allocated.");
    MPI_CHECK(MPI_Bcast(
        output->raw_mutable_data(),
        output->nbytes(),
        MPIDataTypeWrapper<char>::type(),
        root_,
        comm));
    return true;
  }

 protected:
  int root_;
};

// MPIReduceOp does Reduce using MPI. Currently, only SUM is supported.
template <typename T, class Context>
class MPIReduceOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  MPIReduceOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        root_(OperatorBase::template GetSingleArgument<int>("root", 0)) {}
  ~MPIReduceOp() {}

  bool RunOnDevice() override {
    MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
    auto& input = Input(1);
    auto* output = Output(0, input.sizes(), at::dtype<T>());
    MPI_CHECK(MPI_Reduce(
        const_cast<T*>(input.template data<T>()),
        output->template mutable_data<T>(),
        input.numel(),
        MPIDataTypeWrapper<T>::type(),
        MPI_SUM,
        root_,
        comm));
    return true;
  }

 protected:
  int root_;
};

// MPIAllgatherOp does MPIAllgather using MPI.
template <typename T, class Context>
class MPIAllgatherOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(MPIAllgatherOp);

  bool RunOnDevice() override {
    MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
    auto& input = Input(1);
    auto* output = Output(0);
    vector<int64_t> output_dims = input.sizes().vec();
    output_dims[0] *= OperatorBase::Input<MPICommonWorldWrapper>(0).size();
    output->Resize(output_dims);
    MPI_CHECK(MPI_Allgather(
        const_cast<T*>(input.template data<T>()),
        input.numel(),
        MPIDataTypeWrapper<T>::type(),
        output->template mutable_data<T>(),
        input.numel(),
        MPIDataTypeWrapper<T>::type(),
        comm));
    return true;
  }
};

// MPIAllreduceOp does MPIAllreduce using MPI. Currently, only SUM is supported.
template <typename T, class Context>
class MPIAllreduceOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(MPIAllreduceOp);

  bool RunOnDevice() override {
    MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
    auto& input = Input(1);
    auto* output = Output(0, input.sizes(), at::dtype<T>());
    void* source;
    if (output->template mutable_data<T>() == input.template data<T>()) {
      // We are doing an in-place call. Special case handling.
      source = MPI_IN_PLACE;
    } else {
      // Normal allreduce takes the source from the input.
      source = const_cast<T*>(input.template data<T>());
    }
    MPI_CHECK(MPI_Allreduce(
        source,
        output->template mutable_data<T>(),
        input.numel(),
        MPIDataTypeWrapper<T>::type(),
        MPI_SUM,
        comm));
    return true;
  }
};
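
// The collective ops above all take the common world blob as their first
// input and the data tensor as their second. Below is a minimal usage sketch
// from C++, assuming the conventional registration names
// ("MPICreateCommonWorld", "MPIBroadcast", "MPIAllreduce") and a workspace
// that already holds a tensor blob "X"; it is illustrative only, not part of
// this header.
//
//   Workspace ws;
//   // Create the communicator blob "comm" once on every rank.
//   OperatorDef create_def =
//       CreateOperatorDef("MPICreateCommonWorld", "", {}, {"comm"});
//   CreateOperator(create_def, &ws)->Run();
//   // Broadcast "X" in place from rank 0; the output must alias the input,
//   // since MPIBroadcastOp requires a pre-allocated in-place output.
//   OperatorDef bcast_def = CreateOperatorDef(
//       "MPIBroadcast", "", {"comm", "X"}, {"X"},
//       {MakeArgument<int>("root", 0)});
//   CreateOperator(bcast_def, &ws)->Run();
//   // Sum "X" across all ranks into "X_sum".
//   OperatorDef allreduce_def =
//       CreateOperatorDef("MPIAllreduce", "", {"comm", "X"}, {"X_sum"});
//   CreateOperator(allreduce_def, &ws)->Run();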

template <class Context>
class MPISendTensorOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  MPISendTensorOp(const OperatorDef& def, Workspace* ws)
      : Operator<Context>(def, ws),
        OP_SINGLE_ARG(int, "dst", dst_, MPI_ANY_SOURCE),
        OP_SINGLE_ARG(int, "tag", tag_, MPI_ANY_TAG),
        OP_SINGLE_ARG(bool, "raw_buffer", raw_buffer_, false) {
    CAFFE_ENFORCE(raw_buffer_, "non-raw-buffer transfer not supported yet.");
    CAFFE_ENFORCE(
        dst_ != MPI_ANY_SOURCE || def.input_size() == 4,
        "You should explicitly specify the destination rank either via "
        "argument or via input blobs.");
    CAFFE_ENFORCE(
        tag_ != MPI_ANY_TAG || def.input_size() == 4,
        "You should explicitly specify the tag either via "
        "argument or via input blobs.");
  }

  bool RunOnDevice() override {
    MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(COMM).comm();
    auto& input = Input(INPUT);
    if (InputSize() == 4) {
      dst_ = OperatorBase::Input<Tensor>(DST, CPU).template data<int>()[0];
      tag_ = OperatorBase::Input<Tensor>(TAG, CPU).template data<int>()[0];
    }
    if (raw_buffer_) {
      // We need to do a const cast to cope with the fact that, before OpenMPI
      // 1.7, MPI_Send expects a non-const pointer although it uses it in a
      // const way.
      MPI_CHECK(MPI_Send(
          const_cast<void*>(input.raw_data()),
          input.nbytes(),
          MPI_CHAR,
          dst_,
          tag_,
          comm));
    } else {
      CAFFE_NOT_IMPLEMENTED;
    }
    return true;
  }

 protected:
  int dst_;
  int tag_;
  bool raw_buffer_;

  INPUT_TAGS(COMM, INPUT, DST, TAG);
};

template <class Context>
class MPIReceiveTensorOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  MPIReceiveTensorOp(const OperatorDef& def, Workspace* ws)
      : Operator<Context>(def, ws),
        OP_SINGLE_ARG(int, "src", src_, MPI_ANY_SOURCE),
        OP_SINGLE_ARG(int, "tag", tag_, MPI_ANY_TAG),
        OP_SINGLE_ARG(bool, "raw_buffer", raw_buffer_, false) {
    CAFFE_ENFORCE(raw_buffer_, "non-raw-buffer transfer not supported yet.");
  }

  bool RunOnDevice() override {
    MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(COMM).comm();
    if (InputSize() == 4) {
      src_ = OperatorBase::Input<Tensor>(SRC_IN, CPU).template data<int>()[0];
      tag_ = OperatorBase::Input<Tensor>(TAG_IN, CPU).template data<int>()[0];
    }
    MPI_Status status;
    if (raw_buffer_) {
      auto* output = Output(OUTPUT);
      MPI_CHECK(MPI_Recv(
          output->raw_mutable_data(),
          output->nbytes(),
          MPI_CHAR,
          src_,
          tag_,
          comm,
          &status));
    } else {
      CAFFE_NOT_IMPLEMENTED;
    }
    auto* src_out = OperatorBase::Output<Tensor>(SRC_OUT, CPU);
    src_out->Resize();
    src_out->template mutable_data<int>()[0] = status.MPI_SOURCE;
    auto* tag_out = OperatorBase::Output<Tensor>(TAG_OUT, CPU);
    tag_out->Resize();
    tag_out->template mutable_data<int>()[0] = status.MPI_TAG;
    return true;
  }

 protected:
  int src_;
  int tag_;
  bool raw_buffer_;

  INPUT_TAGS(COMM, INPUT, SRC_IN, TAG_IN);
  OUTPUT_TAGS(OUTPUT, SRC_OUT, TAG_OUT);
};

} // namespace caffe2

#endif // CAFFE2_MPI_MPI_OPS_H_