#ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
#define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H

namespace Eigen {
namespace internal {
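// Overview: TensorIndexTupleOp pairs every coefficient of an expression with its
// linear index, producing Tuple<Index, Scalar> values, and TensorTupleReducerOp
// reduces those pairs and keeps only the index of the surviving pair, optionally
// converted to a coordinate along one chosen dimension. Together they provide the
// machinery that argmax/argmin-style reductions are built on.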
template<typename XprType>
struct traits<TensorIndexTupleOp<XprType> > : public traits<XprType>
{
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef Tuple<Index, typename XprTraits::Scalar> Scalar;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};
template<typename XprType>
struct eval<TensorIndexTupleOp<XprType>, Eigen::Dense>
{
  typedef const TensorIndexTupleOp<XprType>& type;
};
template<typename XprType>
struct nested<TensorIndexTupleOp<XprType>, 1,
              typename eval<TensorIndexTupleOp<XprType> >::type>
{
  typedef TensorIndexTupleOp<XprType> type;
};

}  // end namespace internal
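// TensorIndexTupleOp: an expression whose coefficients are Tuple<Index, Scalar>
// pairs, i.e. (linear index, value) for each coefficient of the wrapped
// expression. A minimal usage sketch, assuming the op is obtained through
// TensorBase::index_tuples(), the accessor used by the reducer evaluator below:
//
//   Eigen::Tensor<float, 2> t(3, 4);
//   t.setRandom();
//   auto pairs = t.index_tuples();  // coefficient i evaluates to the pair (i, i-th value of t)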
template<typename XprType>
class TensorIndexTupleOp : public TensorBase<TensorIndexTupleOp<XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename Eigen::internal::nested<TensorIndexTupleOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorIndexTupleOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Index Index;
    typedef Tuple<Index, typename XprType::CoeffReturnType> CoeffReturnType;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexTupleOp(const XprType& expr)
        : m_xpr(expr) {}

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
};
template<typename ArgType, typename Device>
struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>
{
  typedef TensorIndexTupleOp<ArgType> XprType;
  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;

  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
  static const int NumDims = internal::array_size<Dimensions>::value;

  enum {
    IsAligned = false,
    PacketAccess = false,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,
    RawAccess = false
  };
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device) { }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    return m_impl.dimensions();
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return CoeffReturnType(index, m_impl.coeff(index));
  }
  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

 protected:
  TensorEvaluator<ArgType, Device> m_impl;
};
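// The evaluator above turns an expression of values into an expression of
// (linear index, value) pairs. The tuple reducer defined next reduces those
// pairs over a set of dimensions (for an argmax-style reduction, keeping the
// pair holding the largest value) and then discards the value, leaving only
// the winning index. A rough sketch of the composition it evaluates, matching
// the reducer evaluator's constructor below; reduce_dims and reduce_op stand
// in for the members stored on the op:
//
//   auto winners = expr.index_tuples().reduce(reduce_dims, reduce_op);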
namespace internal {

template<typename ReduceOp, typename Dims, typename XprType>
struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<XprType>
{
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef Index Scalar;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};
template<typename ReduceOp, typename Dims, typename XprType>
struct eval<TensorTupleReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense>
{
  typedef const TensorTupleReducerOp<ReduceOp, Dims, XprType>& type;
};
template<typename ReduceOp, typename Dims, typename XprType>
struct nested<TensorTupleReducerOp<ReduceOp, Dims, XprType>, 1,
              typename eval<TensorTupleReducerOp<ReduceOp, Dims, XprType> >::type>
{
  typedef TensorTupleReducerOp<ReduceOp, Dims, XprType> type;
};

}  // end namespace internal
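// TensorTupleReducerOp: reduces the (index, value) tuples produced by
// TensorIndexTupleOp with reduce_op over reduce_dims, and for each output
// coefficient returns only the index of the winning tuple: the flat index when
// return_dim < 0, otherwise the coordinate along return_dim (see the evaluator's
// coeff() below).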
template<typename ReduceOp, typename Dims, typename XprType>
class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename Eigen::internal::nested<TensorTupleReducerOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorTupleReducerOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Index Index;
    typedef Index CoeffReturnType;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerOp(const XprType& expr,
                                                               const ReduceOp& reduce_op,
                                                               const int return_dim,
                                                               const Dims& reduce_dims)
        : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}
    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

    EIGEN_DEVICE_FUNC
    const ReduceOp& reduce_op() const { return m_reduce_op; }

    EIGEN_DEVICE_FUNC
    const Dims& reduce_dims() const { return m_reduce_dims; }

    EIGEN_DEVICE_FUNC
    int return_dim() const { return m_return_dim; }
  protected:
    typename XprType::Nested m_xpr;
    const ReduceOp m_reduce_op;
    const int m_return_dim;
    const Dims m_reduce_dims;
};
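// Usage sketch, assuming TensorBase::argmax(dim) / argmin(dim) return this
// expression (this header provides the machinery those calls rely on):
//
//   Eigen::Tensor<float, 2> t(3, 4);
//   t.setRandom();
//   Eigen::Tensor<Eigen::DenseIndex, 1> idx = t.argmax(0);
//   // idx(j) is the coordinate along dimension 0 of the largest value in column j
//
// Evaluation of such an expression goes through the evaluator defined next.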
template<typename ReduceOp, typename Dims, typename ArgType, typename Device>
struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
{
  typedef TensorTupleReducerOp<ReduceOp, Dims, ArgType> XprType;
  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename TensorIndexTupleOp<ArgType>::CoeffReturnType TupleType;
  typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Dimensions Dimensions;
  typedef typename TensorEvaluator<const TensorIndexTupleOp<ArgType> , Device>::Dimensions InputDimensions;
  static const int NumDims = internal::array_size<InputDimensions>::value;
  typedef array<Index, NumDims> StrideDims;

  enum {
    IsAligned = false,
    PacketAccess = false,
    Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout,
    CoordAccess = false,
    RawAccess = false
  };
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_orig_impl(op.expression(), device),
        m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
        m_return_dim(op.return_dim()) {

    gen_strides(m_orig_impl.dimensions(), m_strides);
    if (Layout == static_cast<int>(ColMajor)) {
      const Index total_size = internal::array_prod(m_orig_impl.dimensions());
      m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
    } else {
      const Index total_size = internal::array_prod(m_orig_impl.dimensions());
      m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
    }
    // Guard against indexing m_strides with a negative m_return_dim; the value
    // is never used in that case (see coeff()).
    m_stride_div = (m_return_dim >= 0) ? m_strides[m_return_dim] : 1;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    return m_impl.dimensions();
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    const TupleType v = m_impl.coeff(index);
    return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
  }
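  // Worked example of the conversion in coeff() above, for a ColMajor 3x4 input:
  // gen_strides() yields strides = {1, 3}. With return_dim == 1, m_stride_mod is
  // the total size 12 and m_stride_div is 3, so a winning flat index of 7
  // (row 1, column 2) maps to (7 % 12) / 3 == 2, the coordinate along dimension 1.
  // With return_dim == 0, m_stride_mod == 3 and m_stride_div == 1, so the same
  // flat index maps to (7 % 3) / 1 == 1, the coordinate along dimension 0.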
  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
 private:
  EIGEN_DEVICE_FUNC void gen_strides(const InputDimensions& dims, StrideDims& strides) {
    if (m_return_dim < 0) {
      return;  // The strides are not used when returning the flat index.
    }
    eigen_assert(m_return_dim < NumDims &&
                 "Asking to convert index to a dimension outside of the rank");

    if (Layout == static_cast<int>(ColMajor)) {
      strides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        strides[i] = strides[i-1] * dims[i-1];
      }
    } else {
      strides[NumDims-1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        strides[i] = strides[i+1] * dims[i+1];
      }
    }
  }
 protected:
  TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
  TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
  const int m_return_dim;
  StrideDims m_strides;
  Index m_stride_mod;
  Index m_stride_div;
};
} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H