#ifndef CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_
|
#define CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_
|
|
#include "caffe2/core/context.h"
|
#include "caffe2/core/operator.h"
|
#include "caffe2/utils/math.h"
|
|
namespace caffe2 {
|
|
template <class Context>
|
class SparseToDenseOp final : public Operator<Context> {
|
public:
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
USE_DISPATCH_HELPER;
|
|
template <class... Args>
|
explicit SparseToDenseOp(Args&&... args)
|
: Operator<Context>(std::forward<Args>(args)...),
|
output_first_dim_(
|
this->template GetSingleArgument<int>("output_first_dim", 0)) {}
|
|
bool RunOnDevice() override {
|
return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
|
this, Input(INDICES));
|
}
|
|
private:
|
template <typename TInd>
|
int GetOutputFirstDim(
|
const TInd* sparse_indices_vec,
|
const int32_t sparse_indices_len) {
|
if (output_first_dim_ > 0) {
|
CAFFE_ENFORCE_EQ(InputSize(), 2);
|
return output_first_dim_;
|
}
|
if (InputSize() == 3) {
|
auto& data_to_infer_dim = Input(DATA_TO_INFER_DIM);
|
CAFFE_ENFORCE_GE(data_to_infer_dim.dim(), 1);
|
return data_to_infer_dim.dim32(0);
|
}
|
if (sparse_indices_len <= 0) {
|
return 0;
|
}
|
|
// Awkward way to get the max element to make it work with both CUDA
|
// and CPU.
|
ReinitializeTensor(&max_element_, {1}, at::dtype<TInd>().device(Context::GetDeviceType()));
|
TInd* max_element_ptr = max_element_.template mutable_data<TInd>();
|
math::ReduceMax<TInd>(sparse_indices_len, sparse_indices_vec, max_element_ptr,
|
&scratch_, &context_);
|
max_element_host_.CopyFrom(max_element_);
|
return 1 + max_element_host_.template data<TInd>()[0];
|
}
|
|
template <typename TInd>
|
bool DoRunWithType() {
|
return DispatchHelper<
|
TensorTypes2<
|
float,
|
int32_t,
|
int64_t,
|
GenericTensorImplementation>,
|
TInd>::call(this, Input(VALUES));
|
}
|
|
template <typename TInd, typename TData>
|
bool DoRunWithType2() {
|
auto& sparse_indices = Input(INDICES);
|
CAFFE_ENFORCE_EQ(sparse_indices.dim(), 1);
|
auto& sparse_values = Input(VALUES);
|
CAFFE_ENFORCE_GE(sparse_values.dim(), 1);
|
CAFFE_ENFORCE_EQ(sparse_indices.numel(), sparse_values.size(0));
|
|
const TInd* sparse_indices_vec = sparse_indices.template data<TInd>();
|
const int32_t sparse_indices_len = sparse_indices.dim32(0);
|
const int output_first_dim =
|
GetOutputFirstDim(sparse_indices_vec, sparse_indices_len);
|
|
auto shape = sparse_values.sizes().vec();
|
shape[0] = output_first_dim;
|
|
auto* output = Output(0, shape, at::dtype<TData>());
|
|
TData* output_data = output->template mutable_data<TData>();
|
if (!output_first_dim) {
|
return true;
|
}
|
memset(output_data, 0, output->nbytes());
|
const auto block_nitems = sparse_values.size_from_dim(1);
|
const TData* sparse_values_vec = sparse_values.template data<TData>();
|
|
for (int32_t i = 0; i < sparse_indices_len; i++) {
|
const TInd idx = sparse_indices_vec[i];
|
CAFFE_ENFORCE_GE(idx, 0);
|
CAFFE_ENFORCE_LT(idx, output_first_dim);
|
math::Add(
|
block_nitems,
|
output_data + idx * block_nitems,
|
sparse_values_vec + i * block_nitems,
|
output_data + idx * block_nitems,
|
&context_);
|
}
|
return true;
|
}
|
|
template <typename TInd>
|
bool DoRunWithOtherType2() {
|
CAFFE_THROW(
|
"SparseToDense is not implemented on tensor of type ",
|
Input(VALUES).dtype().name(),
|
"consider adding it as a type in the DispatchHelper list or "
|
"implementing a generic version (which won't work for "
|
"duplicated indices though)");
|
}
|
|
private:
|
int output_first_dim_;
|
Tensor scratch_{Context::GetDeviceType()};
|
Tensor max_element_host_{CPU};
|
Tensor max_element_;
|
|
INPUT_TAGS(INDICES, VALUES, DATA_TO_INFER_DIM);
|
};
|
|
} // namespace caffe2
|
|
#endif // CAFFE2_OPERATORS_SPARSE_TO_DENSE_OP_H_
|