// Copyright 2004-present Facebook. All Rights Reserved.
|
|
#ifndef CAFFE2_OPERATORS_UTILS_EIGEN_H_
|
#define CAFFE2_OPERATORS_UTILS_EIGEN_H_
|
|
#include "Eigen/Core"
|
#include "Eigen/Dense"
|
|
#include "caffe2/core/logging.h"
|
|
namespace caffe2 {
|
|
// Common Eigen types that we will often use
|
template <typename T>
|
using EigenMatrixMap =
|
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>;
|
template <typename T>
|
using EigenArrayMap =
|
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
|
template <typename T>
|
using EigenVectorMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>>;
|
template <typename T>
|
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
|
template <typename T>
|
using ConstEigenMatrixMap =
|
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>;
|
template <typename T>
|
using ConstEigenArrayMap =
|
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
|
template <typename T>
|
using ConstEigenVectorMap =
|
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>>;
|
template <typename T>
|
using ConstEigenVectorArrayMap =
|
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
|
|
using EigenOuterStride = Eigen::OuterStride<Eigen::Dynamic>;
|
using EigenInnerStride = Eigen::InnerStride<Eigen::Dynamic>;
|
using EigenStride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
|
template <typename T>
|
using EigenOuterStridedMatrixMap = Eigen::
|
Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>, 0, EigenOuterStride>;
|
template <typename T>
|
using EigenOuterStridedArrayMap = Eigen::
|
Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>, 0, EigenOuterStride>;
|
template <typename T>
|
using ConstEigenOuterStridedMatrixMap = Eigen::Map<
|
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>,
|
0,
|
EigenOuterStride>;
|
template <typename T>
|
using ConstEigenOuterStridedArrayMap = Eigen::Map<
|
const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>,
|
0,
|
EigenOuterStride>;
|
template <typename T>
|
using EigenStridedMatrixMap = Eigen::
|
Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>, 0, EigenStride>;
|
template <typename T>
|
using EigenStridedArrayMap =
|
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>, 0, EigenStride>;
|
template <typename T>
|
using ConstEigenStridedMatrixMap = Eigen::
|
Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>, 0, EigenStride>;
|
template <typename T>
|
using ConstEigenStridedArrayMap = Eigen::
|
Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>, 0, EigenStride>;
|
|
// 1-d array
|
template <typename T>
|
using EArrXt = Eigen::Array<T, Eigen::Dynamic, 1>;
|
using EArrXf = Eigen::ArrayXf;
|
using EArrXd = Eigen::ArrayXd;
|
using EArrXi = Eigen::ArrayXi;
|
using EArrXb = EArrXt<bool>;
|
|
// 2-d array, column major
|
template <typename T>
|
using EArrXXt = Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>;
|
using EArrXXf = Eigen::ArrayXXf;
|
|
// 2-d array, row major
|
template <typename T>
|
using ERArrXXt =
|
Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
using ERArrXXf = ERArrXXt<float>;
|
|
// 1-d vector
|
template <typename T>
|
using EVecXt = Eigen::Matrix<T, Eigen::Dynamic, 1>;
|
using EVecXd = Eigen::VectorXd;
|
using EVecXf = Eigen::VectorXf;
|
|
// 1-d row vector
|
using ERVecXd = Eigen::RowVectorXd;
|
using ERVecXf = Eigen::RowVectorXf;
|
|
// 2-d matrix, column major
|
template <typename T>
|
using EMatXt = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
|
using EMatXd = Eigen::MatrixXd;
|
using EMatXf = Eigen::MatrixXf;
|
|
// 2-d matrix, row major
|
template <typename T>
|
using ERMatXt =
|
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
using ERMatXd = ERMatXt<double>;
|
using ERMatXf = ERMatXt<float>;
|
|
namespace utils {
|
|
template <typename T>
|
Eigen::Map<const EArrXt<T>> AsEArrXt(const std::vector<T>& arr) {
|
return {arr.data(), static_cast<int>(arr.size())};
|
}
|
template <typename T>
|
Eigen::Map<EArrXt<T>> AsEArrXt(std::vector<T>& arr) {
|
return {arr.data(), static_cast<int>(arr.size())};
|
}
|
|
// return a sub array of 'array' based on indices 'indices'
|
template <class Derived, class Derived1, class Derived2>
|
void GetSubArray(
|
const Eigen::ArrayBase<Derived>& array,
|
const Eigen::ArrayBase<Derived1>& indices,
|
Eigen::ArrayBase<Derived2>* out_array) {
|
CAFFE_ENFORCE_EQ(array.cols(), 1);
|
// using T = typename Derived::Scalar;
|
|
out_array->derived().resize(indices.size());
|
for (int i = 0; i < indices.size(); i++) {
|
DCHECK_LT(indices[i], array.size());
|
(*out_array)[i] = array[indices[i]];
|
}
|
}
|
|
// return a sub array of 'array' based on indices 'indices'
|
template <class Derived, class Derived1>
|
EArrXt<typename Derived::Scalar> GetSubArray(
|
const Eigen::ArrayBase<Derived>& array,
|
const Eigen::ArrayBase<Derived1>& indices) {
|
using T = typename Derived::Scalar;
|
EArrXt<T> ret(indices.size());
|
GetSubArray(array, indices, &ret);
|
return ret;
|
}
|
|
// return a sub array of 'array' based on indices 'indices'
|
template <class Derived>
|
EArrXt<typename Derived::Scalar> GetSubArray(
|
const Eigen::ArrayBase<Derived>& array,
|
const std::vector<int>& indices) {
|
return GetSubArray(array, AsEArrXt(indices));
|
}
|
|
// return 2d sub array of 'array' based on row indices 'row_indices'
|
template <class Derived, class Derived1, class Derived2>
|
void GetSubArrayRows(
|
const Eigen::ArrayBase<Derived>& array2d,
|
const Eigen::ArrayBase<Derived1>& row_indices,
|
Eigen::ArrayBase<Derived2>* out_array) {
|
out_array->derived().resize(row_indices.size(), array2d.cols());
|
|
for (int i = 0; i < row_indices.size(); i++) {
|
DCHECK_LT(row_indices[i], array2d.size());
|
out_array->row(i) =
|
array2d.row(row_indices[i]).template cast<typename Derived2::Scalar>();
|
}
|
}
|
|
// return indices of 1d array for elements evaluated to true
|
template <class Derived>
|
std::vector<int> GetArrayIndices(const Eigen::ArrayBase<Derived>& array) {
|
std::vector<int> ret;
|
for (int i = 0; i < array.size(); i++) {
|
if (array[i]) {
|
ret.push_back(i);
|
}
|
}
|
return ret;
|
}
|
|
} // namespace utils
|
} // namespace caffe2
|
|
#endif
|