10 #ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
11 #define EIGEN_GENERAL_MATRIX_VECTOR_H
58 template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
59 struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,
ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
61 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
64 Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
65 && int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size),
66 LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
67 RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
68 ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
71 typedef typename packet_traits<LhsScalar>::type _LhsPacket;
72 typedef typename packet_traits<RhsScalar>::type _RhsPacket;
73 typedef typename packet_traits<ResScalar>::type _ResPacket;
75 typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
76 typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
77 typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
79 EIGEN_DONT_INLINE
static void run(
80 Index rows, Index cols,
83 ResScalar* res, Index resIncr,
87 template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
88 EIGEN_DONT_INLINE
void general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
89 Index rows, Index cols,
92 ResScalar* res, Index resIncr,
95 EIGEN_UNUSED_VARIABLE(resIncr);
96 eigen_internal_assert(resIncr==1);
97 #ifdef _EIGEN_ACCUMULATE_PACKETS
98 #error _EIGEN_ACCUMULATE_PACKETS has already been defined
100 #define _EIGEN_ACCUMULATE_PACKETS(Alignment0,Alignment13,Alignment2) \
102 padd(pload<ResPacket>(&res[j]), \
104 padd(pcj.pmul(lhs0.template load<LhsPacket, Alignment0>(j), ptmp0), \
105 pcj.pmul(lhs1.template load<LhsPacket, Alignment13>(j), ptmp1)), \
106 padd(pcj.pmul(lhs2.template load<LhsPacket, Alignment2>(j), ptmp2), \
107 pcj.pmul(lhs3.template load<LhsPacket, Alignment13>(j), ptmp3)) )))
109 typedef typename LhsMapper::VectorMapper LhsScalars;
111 conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
112 conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
114 alpha = numext::conj(alpha);
116 enum { AllAligned = 0, EvenAligned, FirstAligned, NoneAligned };
117 const Index columnsAtOnce = 4;
118 const Index peels = 2;
119 const Index LhsPacketAlignedMask = LhsPacketSize-1;
120 const Index ResPacketAlignedMask = ResPacketSize-1;
122 const Index size = rows;
124 const Index lhsStride = lhs.stride();
128 Index alignedStart = internal::first_default_aligned(res,size);
129 Index alignedSize = ResPacketSize>1 ? alignedStart + ((size-alignedStart) & ~ResPacketAlignedMask) : 0;
130 const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1;
132 const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
133 Index alignmentPattern = alignmentStep==0 ? AllAligned
134 : alignmentStep==(LhsPacketSize/2) ? EvenAligned
138 const Index lhsAlignmentOffset = lhs.firstAligned(size);
141 Index skipColumns = 0;
143 if( (lhsAlignmentOffset < 0) || (lhsAlignmentOffset == size) || (
size_t(res)%
sizeof(ResScalar)) )
147 alignmentPattern = NoneAligned;
149 else if(LhsPacketSize > 4)
153 alignmentPattern = NoneAligned;
155 else if (LhsPacketSize>1)
159 while (skipColumns<LhsPacketSize &&
160 alignedStart != ((lhsAlignmentOffset + alignmentStep*skipColumns)%LhsPacketSize))
162 if (skipColumns==LhsPacketSize)
165 alignmentPattern = NoneAligned;
170 skipColumns = (std::min)(skipColumns,cols);
179 else if(Vectorizable)
183 alignmentPattern = AllAligned;
186 const Index offset1 = (FirstAligned && alignmentStep==1?3:1);
187 const Index offset3 = (FirstAligned && alignmentStep==1?1:3);
189 Index columnBound = ((cols-skipColumns)/columnsAtOnce)*columnsAtOnce + skipColumns;
190 for (Index i=skipColumns; i<columnBound; i+=columnsAtOnce)
192 RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs(i, 0)),
193 ptmp1 = pset1<RhsPacket>(alpha*rhs(i+offset1, 0)),
194 ptmp2 = pset1<RhsPacket>(alpha*rhs(i+2, 0)),
195 ptmp3 = pset1<RhsPacket>(alpha*rhs(i+offset3, 0));
198 const LhsScalars lhs0 = lhs.getVectorMapper(0, i+0), lhs1 = lhs.getVectorMapper(0, i+offset1),
199 lhs2 = lhs.getVectorMapper(0, i+2), lhs3 = lhs.getVectorMapper(0, i+offset3);
205 for (Index j=0; j<alignedStart; ++j)
207 res[j] = cj.pmadd(lhs0(j), pfirst(ptmp0), res[j]);
208 res[j] = cj.pmadd(lhs1(j), pfirst(ptmp1), res[j]);
209 res[j] = cj.pmadd(lhs2(j), pfirst(ptmp2), res[j]);
210 res[j] = cj.pmadd(lhs3(j), pfirst(ptmp3), res[j]);
213 if (alignedSize>alignedStart)
215 switch(alignmentPattern)
218 for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
222 for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
227 Index j = alignedStart;
230 LhsPacket A00, A01, A02, A03, A10, A11, A12, A13;
233 A01 = lhs1.template load<LhsPacket, Aligned>(alignedStart-1);
234 A02 = lhs2.template load<LhsPacket, Aligned>(alignedStart-2);
235 A03 = lhs3.template load<LhsPacket, Aligned>(alignedStart-3);
237 for (; j<peeledSize; j+=peels*ResPacketSize)
239 A11 = lhs1.template load<LhsPacket, Aligned>(j-1+LhsPacketSize); palign<1>(A01,A11);
240 A12 = lhs2.template load<LhsPacket, Aligned>(j-2+LhsPacketSize); palign<2>(A02,A12);
241 A13 = lhs3.template load<LhsPacket, Aligned>(j-3+LhsPacketSize); palign<3>(A03,A13);
243 A00 = lhs0.template load<LhsPacket, Aligned>(j);
244 A10 = lhs0.template load<LhsPacket, Aligned>(j+LhsPacketSize);
245 T0 = pcj.pmadd(A00, ptmp0, pload<ResPacket>(&res[j]));
246 T1 = pcj.pmadd(A10, ptmp0, pload<ResPacket>(&res[j+ResPacketSize]));
248 T0 = pcj.pmadd(A01, ptmp1, T0);
249 A01 = lhs1.template load<LhsPacket, Aligned>(j-1+2*LhsPacketSize); palign<1>(A11,A01);
250 T0 = pcj.pmadd(A02, ptmp2, T0);
251 A02 = lhs2.template load<LhsPacket, Aligned>(j-2+2*LhsPacketSize); palign<2>(A12,A02);
252 T0 = pcj.pmadd(A03, ptmp3, T0);
254 A03 = lhs3.template load<LhsPacket, Aligned>(j-3+2*LhsPacketSize); palign<3>(A13,A03);
255 T1 = pcj.pmadd(A11, ptmp1, T1);
256 T1 = pcj.pmadd(A12, ptmp2, T1);
257 T1 = pcj.pmadd(A13, ptmp3, T1);
258 pstore(&res[j+ResPacketSize],T1);
261 for (; j<alignedSize; j+=ResPacketSize)
266 for (Index j = alignedStart; j<alignedSize; j+=ResPacketSize)
274 for (Index j=alignedSize; j<size; ++j)
276 res[j] = cj.pmadd(lhs0(j), pfirst(ptmp0), res[j]);
277 res[j] = cj.pmadd(lhs1(j), pfirst(ptmp1), res[j]);
278 res[j] = cj.pmadd(lhs2(j), pfirst(ptmp2), res[j]);
279 res[j] = cj.pmadd(lhs3(j), pfirst(ptmp3), res[j]);
285 Index start = columnBound;
288 for (Index k=start; k<end; ++k)
290 RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs(k, 0));
291 const LhsScalars lhs0 = lhs.getVectorMapper(0, k);
297 for (Index j=0; j<alignedStart; ++j)
298 res[j] += cj.pmul(lhs0(j), pfirst(ptmp0));
300 if (lhs0.template aligned<LhsPacket>(alignedStart))
301 for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize)
302 pstore(&res[i], pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(i), ptmp0, pload<ResPacket>(&res[i])));
304 for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize)
305 pstore(&res[i], pcj.pmadd(lhs0.template load<LhsPacket, Unaligned>(i), ptmp0, pload<ResPacket>(&res[i])));
309 for (Index i=alignedSize; i<size; ++i)
310 res[i] += cj.pmul(lhs0(i), pfirst(ptmp0));
320 }
while(Vectorizable);
321 #undef _EIGEN_ACCUMULATE_PACKETS
334 template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
335 struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,
RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
337 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
340 Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
341 && int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size),
342 LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
343 RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
344 ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
347 typedef typename packet_traits<LhsScalar>::type _LhsPacket;
348 typedef typename packet_traits<RhsScalar>::type _RhsPacket;
349 typedef typename packet_traits<ResScalar>::type _ResPacket;
351 typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
352 typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
353 typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
355 EIGEN_DONT_INLINE
static void run(
356 Index rows, Index cols,
357 const LhsMapper& lhs,
358 const RhsMapper& rhs,
359 ResScalar* res, Index resIncr,
363 template<
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
typename RhsMapper,
bool ConjugateRhs,
int Version>
364 EIGEN_DONT_INLINE
void general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
365 Index rows, Index cols,
366 const LhsMapper& lhs,
367 const RhsMapper& rhs,
368 ResScalar* res, Index resIncr,
371 eigen_internal_assert(rhs.stride()==1);
373 #ifdef _EIGEN_ACCUMULATE_PACKETS
374 #error _EIGEN_ACCUMULATE_PACKETS has already been defined
377 #define _EIGEN_ACCUMULATE_PACKETS(Alignment0,Alignment13,Alignment2) {\
378 RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0); \
379 ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Alignment0>(j), b, ptmp0); \
380 ptmp1 = pcj.pmadd(lhs1.template load<LhsPacket, Alignment13>(j), b, ptmp1); \
381 ptmp2 = pcj.pmadd(lhs2.template load<LhsPacket, Alignment2>(j), b, ptmp2); \
382 ptmp3 = pcj.pmadd(lhs3.template load<LhsPacket, Alignment13>(j), b, ptmp3); }
384 conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
385 conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
387 typedef typename LhsMapper::VectorMapper LhsScalars;
389 enum { AllAligned=0, EvenAligned=1, FirstAligned=2, NoneAligned=3 };
390 const Index rowsAtOnce = 4;
391 const Index peels = 2;
392 const Index RhsPacketAlignedMask = RhsPacketSize-1;
393 const Index LhsPacketAlignedMask = LhsPacketSize-1;
394 const Index depth = cols;
395 const Index lhsStride = lhs.stride();
400 Index alignedStart = rhs.firstAligned(depth);
401 Index alignedSize = RhsPacketSize>1 ? alignedStart + ((depth-alignedStart) & ~RhsPacketAlignedMask) : 0;
402 const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1;
404 const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
405 Index alignmentPattern = alignmentStep==0 ? AllAligned
406 : alignmentStep==(LhsPacketSize/2) ? EvenAligned
410 const Index lhsAlignmentOffset = lhs.firstAligned(depth);
411 const Index rhsAlignmentOffset = rhs.firstAligned(rows);
416 if( (
sizeof(LhsScalar)!=
sizeof(RhsScalar)) ||
417 (lhsAlignmentOffset < 0) || (lhsAlignmentOffset == depth) ||
418 (rhsAlignmentOffset < 0) || (rhsAlignmentOffset == rows) )
422 alignmentPattern = NoneAligned;
424 else if(LhsPacketSize > 4)
427 alignmentPattern = NoneAligned;
429 else if (LhsPacketSize>1)
433 while (skipRows<LhsPacketSize &&
434 alignedStart != ((lhsAlignmentOffset + alignmentStep*skipRows)%LhsPacketSize))
436 if (skipRows==LhsPacketSize)
439 alignmentPattern = NoneAligned;
444 skipRows = (std::min)(skipRows,Index(rows));
453 else if(Vectorizable)
457 alignmentPattern = AllAligned;
460 const Index offset1 = (FirstAligned && alignmentStep==1?3:1);
461 const Index offset3 = (FirstAligned && alignmentStep==1?1:3);
463 Index rowBound = ((rows-skipRows)/rowsAtOnce)*rowsAtOnce + skipRows;
464 for (Index i=skipRows; i<rowBound; i+=rowsAtOnce)
467 EIGEN_ALIGN_MAX ResScalar tmp0 = ResScalar(0);
468 ResScalar tmp1 = ResScalar(0), tmp2 = ResScalar(0), tmp3 = ResScalar(0);
471 const LhsScalars lhs0 = lhs.getVectorMapper(i+0, 0), lhs1 = lhs.getVectorMapper(i+offset1, 0),
472 lhs2 = lhs.getVectorMapper(i+2, 0), lhs3 = lhs.getVectorMapper(i+offset3, 0);
477 ResPacket ptmp0 = pset1<ResPacket>(ResScalar(0)), ptmp1 = pset1<ResPacket>(ResScalar(0)),
478 ptmp2 = pset1<ResPacket>(ResScalar(0)), ptmp3 = pset1<ResPacket>(ResScalar(0));
482 for (Index j=0; j<alignedStart; ++j)
484 RhsScalar b = rhs(j, 0);
485 tmp0 += cj.pmul(lhs0(j),b); tmp1 += cj.pmul(lhs1(j),b);
486 tmp2 += cj.pmul(lhs2(j),b); tmp3 += cj.pmul(lhs3(j),b);
489 if (alignedSize>alignedStart)
491 switch(alignmentPattern)
494 for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
498 for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
503 Index j = alignedStart;
512 LhsPacket A01, A02, A03, A11, A12, A13;
513 A01 = lhs1.template load<LhsPacket, Aligned>(alignedStart-1);
514 A02 = lhs2.template load<LhsPacket, Aligned>(alignedStart-2);
515 A03 = lhs3.template load<LhsPacket, Aligned>(alignedStart-3);
517 for (; j<peeledSize; j+=peels*RhsPacketSize)
519 RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0);
520 A11 = lhs1.template load<LhsPacket, Aligned>(j-1+LhsPacketSize); palign<1>(A01,A11);
521 A12 = lhs2.template load<LhsPacket, Aligned>(j-2+LhsPacketSize); palign<2>(A02,A12);
522 A13 = lhs3.template load<LhsPacket, Aligned>(j-3+LhsPacketSize); palign<3>(A03,A13);
524 ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), b, ptmp0);
525 ptmp1 = pcj.pmadd(A01, b, ptmp1);
526 A01 = lhs1.template load<LhsPacket, Aligned>(j-1+2*LhsPacketSize); palign<1>(A11,A01);
527 ptmp2 = pcj.pmadd(A02, b, ptmp2);
528 A02 = lhs2.template load<LhsPacket, Aligned>(j-2+2*LhsPacketSize); palign<2>(A12,A02);
529 ptmp3 = pcj.pmadd(A03, b, ptmp3);
530 A03 = lhs3.template load<LhsPacket, Aligned>(j-3+2*LhsPacketSize); palign<3>(A13,A03);
532 b = rhs.getVectorMapper(j+RhsPacketSize, 0).template load<RhsPacket, Aligned>(0);
533 ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j+LhsPacketSize), b, ptmp0);
534 ptmp1 = pcj.pmadd(A11, b, ptmp1);
535 ptmp2 = pcj.pmadd(A12, b, ptmp2);
536 ptmp3 = pcj.pmadd(A13, b, ptmp3);
539 for (; j<alignedSize; j+=RhsPacketSize)
544 for (Index j = alignedStart; j<alignedSize; j+=RhsPacketSize)
548 tmp0 += predux(ptmp0);
549 tmp1 += predux(ptmp1);
550 tmp2 += predux(ptmp2);
551 tmp3 += predux(ptmp3);
557 for (Index j=alignedSize; j<depth; ++j)
559 RhsScalar b = rhs(j, 0);
560 tmp0 += cj.pmul(lhs0(j),b); tmp1 += cj.pmul(lhs1(j),b);
561 tmp2 += cj.pmul(lhs2(j),b); tmp3 += cj.pmul(lhs3(j),b);
563 res[i*resIncr] += alpha*tmp0;
564 res[(i+offset1)*resIncr] += alpha*tmp1;
565 res[(i+2)*resIncr] += alpha*tmp2;
566 res[(i+offset3)*resIncr] += alpha*tmp3;
571 Index start = rowBound;
574 for (Index i=start; i<end; ++i)
576 EIGEN_ALIGN_MAX ResScalar tmp0 = ResScalar(0);
577 ResPacket ptmp0 = pset1<ResPacket>(tmp0);
578 const LhsScalars lhs0 = lhs.getVectorMapper(i, 0);
581 for (Index j=0; j<alignedStart; ++j)
582 tmp0 += cj.pmul(lhs0(j), rhs(j, 0));
584 if (alignedSize>alignedStart)
587 if (lhs0.template aligned<LhsPacket>(alignedStart))
588 for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
589 ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0);
591 for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
592 ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Unaligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0);
593 tmp0 += predux(ptmp0);
598 for (Index j=alignedSize; j<depth; ++j)
599 tmp0 += cj.pmul(lhs0(j), rhs(j, 0));
600 res[i*resIncr] += alpha*tmp0;
610 }
while(Vectorizable);
612 #undef _EIGEN_ACCUMULATE_PACKETS
619 #endif // EIGEN_GENERAL_MATRIX_VECTOR_H
Definition: Constants.h:320
Definition: Constants.h:228
Definition: Constants.h:322
Definition: Eigen_Colamd.h:54
Definition: Constants.h:235