Fix UnitTest add cudamatrix add and mul

This commit is contained in:
kradchen
2023-11-01 14:31:29 +08:00
parent fe0abf8ee6
commit 029b86013e
13 changed files with 1108 additions and 617 deletions

View File

@@ -5,7 +5,7 @@
#include <thrust/execution_policy.h>
using namespace thrust::placeholders;
struct PowOperator{
struct PowOperator: public thrust::unary_function<float, float>{
float exponent;
PowOperator(float v):exponent(v) {}
void setExponent(float v){
@@ -25,7 +25,7 @@ void unaryAdd(float* in1, float* in2, float* out, unsigned long length)
void unaryAdd(float* in1, const float& in2, float* out, unsigned long length)
{
thrust::transform(thrust::device,in1,in1+length,out,in2*_1);
thrust::transform(thrust::device,in1,in1+length,out,in2 + _1);
}
void unaryMul(float* in1, float* in2, float* out, unsigned long length)
@@ -34,6 +34,11 @@ void unaryMul(float* in1, float* in2, float* out, unsigned long length)
thrust::transform(thrust::device,in1,in1+length,in2,out,op);
}
void unaryMul(float* in1, const float& in2, float* out, unsigned long length)
{
thrust::transform(thrust::device,in1, in1+length, out, in2 * _1);
}
void unaryNeg(float* in1, float* out, unsigned long length){
thrust::negate<float> op;
thrust::transform(thrust::device,in1,in1+length,out,op);
@@ -49,6 +54,23 @@ void unaryDiv(float* in1, float* in2, float* out, unsigned long length){
thrust::transform(thrust::device,in1,in1+length,in2,out,op);
}
void unarySub(const float& in1, float* in2, float* out, unsigned long length){
thrust::transform(thrust::device,in2,in2+length,out,in1-_1);
}
void unaryDiv(const float& in1, float* in2, float* out, unsigned long length){
thrust::transform(thrust::device,in2,in2+length,out,in1/_1);
}
void unarySub(float* in1, const float& in2, float* out, unsigned long length){
thrust::transform(thrust::device,in1,in1+length,out,_1-in2);
}
void unaryDiv(float* in1, const float& in2, float* out, unsigned long length){
thrust::transform(thrust::device,in1,in1+length,out,_1/in2);
}
void unaryPow(float* in1, float N,float* out, unsigned long length){
if (N == 0.0f)
{
@@ -65,7 +87,6 @@ void unaryPow(float* in1, float N,float* out, unsigned long length){
thrust::transform(thrust::device,in1,in1+length,out,op);
return;
}
thrust::transform(thrust::device,in1,in1+length,out,powf(_1,N));
thrust::transform(thrust::device,in1,in1+length,out,PowOperator(N));
}