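Fix row-mode bug in the matrix-vector ops (unarySubmv, unaryDivmv): write the thrust::transform output through a RowWiseIterator instead of the raw aMatrixOut pointer, so results land in the same row-wise positions the inputs are read from.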

This commit is contained in:
kradchen
2023-12-19 13:09:44 +08:00
parent fd7c71f7e9
commit cd10ebb5e8

View File

@@ -450,28 +450,31 @@ void unarySubmv(float *aMatrixIn1, float *aVectorIn2, float *aMatrixOut,
Aurora::RowWiseIterator<float> rowIter_Begin(aMatrixIn1,aVectorLength,colElementCount,0);
Aurora::RowWiseIterator<float> rowIter_End(aMatrixIn1,aVectorLength,colElementCount,aMatrixLength);
Aurora::LoopVectorIterator<float> rowVectorIter(aVectorIn2,aVectorLength);
Aurora::RowWiseIterator<float> outIter(aMatrixOut,aVectorLength,colElementCount,0);
auto lambda = [=] __device__(const float& x, const float& y){
return direction==0?x-y:y-x;
};
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,aMatrixOut,lambda);
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,outIter,lambda);
}
else if (aValType1 == Aurora::Complex){
Aurora::RowWiseIterator<complexf> rowIter_Begin((complexf*)aMatrixIn1,aVectorLength,colElementCount,0);
Aurora::RowWiseIterator<complexf> rowIter_End((complexf*)aMatrixIn1,aVectorLength,colElementCount,aMatrixLength);
Aurora::LoopVectorIterator<float> rowVectorIter(aVectorIn2,aVectorLength);
Aurora::RowWiseIterator<complexf> outIter((complexf*)aMatrixOut,aVectorLength,colElementCount,0);
auto lambda = [=] __device__(const complexf& x, const float& y){
return direction==0?complexf(x.real()-y,x.imag()):complexf(y-x.real(),-x.imag());
};
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,(complexf*)aMatrixOut,lambda);
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,outIter,lambda);
}
else{
Aurora::RowWiseIterator<float> rowIter_Begin(aMatrixIn1,aVectorLength,colElementCount,0);
Aurora::RowWiseIterator<float> rowIter_End(aMatrixIn1,aVectorLength,colElementCount,aMatrixLength);
Aurora::LoopVectorIterator<complexf> rowVectorIter((complexf*)aVectorIn2,aVectorLength);
Aurora::RowWiseIterator<complexf> outIter((complexf*)aMatrixOut,aVectorLength,colElementCount,0);
auto lambda = [=] __device__(const float& x, const complexf& y){
return direction==0?complexf(x-y.real(),-y.imag()):complexf(y.real()-x,y.imag());
};
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,(complexf*)aMatrixOut,lambda);
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,outIter,lambda);
}
}
}
@@ -587,19 +590,21 @@ void unaryDivmv(float *aMatrixIn1, float *aVectorIn2, float *aMatrixOut,
Aurora::RowWiseIterator<float> rowIter_Begin(aMatrixIn1,aVectorLength,colElementCount,0);
Aurora::RowWiseIterator<float> rowIter_End(aMatrixIn1,aVectorLength,colElementCount,aMatrixLength);
Aurora::LoopVectorIterator<float> rowVectorIter(aVectorIn2,aVectorLength);
Aurora::RowWiseIterator<float> outIter(aMatrixOut,aVectorLength,colElementCount,0);
auto lambda = [=] __device__(const float& x, const float& y){
return direction==0?x/y:y/x;
};
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,aMatrixOut,lambda);
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,outIter,lambda);
}
else{
Aurora::RowWiseIterator<complexf> rowIter_Begin((complexf*)aMatrixIn1,aVectorLength,colElementCount,0);
Aurora::RowWiseIterator<complexf> rowIter_End((complexf*)aMatrixIn1,aVectorLength,colElementCount,aMatrixLength);
Aurora::LoopVectorIterator<complexf> rowVectorIter((complexf*)aVectorIn2,aVectorLength);
Aurora::RowWiseIterator<complexf> outIter((complexf*)aMatrixOut,aVectorLength,colElementCount,0);
auto lambda = [=] __device__(const complexf& x, const complexf& y){
return direction==0?x/y:y/x;
};
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,(complexf*)aMatrixOut,lambda);
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,outIter,lambda);
}
}
@@ -607,21 +612,23 @@ void unaryDivmv(float *aMatrixIn1, float *aVectorIn2, float *aMatrixOut,
Aurora::RowWiseIterator<complexf> rowIter_Begin((complexf*)aMatrixIn1,aVectorLength,colElementCount,0);
Aurora::RowWiseIterator<complexf> rowIter_End((complexf*)aMatrixIn1,aVectorLength,colElementCount,aMatrixLength);
Aurora::LoopVectorIterator<float> rowVectorIter(aVectorIn2,aVectorLength);
Aurora::RowWiseIterator<complexf> outIter((complexf*)aMatrixOut,aVectorLength,colElementCount,0);
auto lambda = [=] __device__(const complexf& x, const float& y){
complexf v (y,0);
return direction==0?x/v:v/x;
};
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,(complexf*)aMatrixOut,lambda);
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,outIter,lambda);
}
else{
Aurora::RowWiseIterator<float> rowIter_Begin(aMatrixIn1,aVectorLength,colElementCount,0);
Aurora::RowWiseIterator<float> rowIter_End(aMatrixIn1,aVectorLength,colElementCount,aMatrixLength);
Aurora::LoopVectorIterator<complexf> rowVectorIter((complexf*)aVectorIn2,aVectorLength);
Aurora::RowWiseIterator<complexf> outIter((complexf*)aMatrixOut,aVectorLength,colElementCount,0);
auto lambda = [=] __device__(const float& x, const complexf& y){
complexf v (x,0);
return direction==0?v/y:y/v;
};
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,(complexf*)aMatrixOut,lambda);
thrust::transform(thrust::device, rowIter_Begin,rowIter_End,rowVectorIter,outIter,lambda);
}
}
}