From 3ea6c840871d38b910ca9a4f272e1df13c6bb4cc Mon Sep 17 00:00:00 2001 From: kradchen Date: Tue, 18 Mar 2025 16:00:10 +0800 Subject: [PATCH] feat: replace CudaMatrix compare function ( deprecated in new cuda version ) with lambda --- src/CudaMatrixPrivate.cu | 218 ++++++++++++++------------------------- test/CudaMatrix_Test.cpp | 50 ++++++++- 2 files changed, 129 insertions(+), 139 deletions(-) diff --git a/src/CudaMatrixPrivate.cu b/src/CudaMatrixPrivate.cu index 26c6813..cc42f5a 100644 --- a/src/CudaMatrixPrivate.cu +++ b/src/CudaMatrixPrivate.cu @@ -5,129 +5,11 @@ #include #include + #include "AuroraDefs.h" #include "AuroraThrustIterator.cuh" using namespace thrust::placeholders; -struct PowOp: public thrust::unary_function{ - float exponent; - PowOp(float v):exponent(v) {} - void setExponent(float v){ - exponent = v; - } - - __host__ __device__ - float operator()(const float& x) { - return powf(x, exponent); - } -}; - -struct CompareGOp: public thrust::unary_function{ - float exponent; - CompareGOp(float v):exponent(v) {} - void setExponent(float v){ - exponent = v; - } - - __host__ __device__ - float operator()(const float& x) { - return (exponent{ - float exponent; - CompareGEOp(float v):exponent(v) {} - void setExponent(float v){ - exponent = v; - } - - __host__ __device__ - float operator()(const float& x) { - return (exponent<=x?1.0:.0); - } -}; - -struct CompareEOp: public thrust::unary_function{ - float exponent; - CompareEOp(float v):exponent(v) {} - void setExponent(float v){ - exponent = v; - } - - __host__ __device__ - float operator()(const float& x) { - return (exponent==x?1.0:.0); - } -}; - -struct CompareNEOp: public thrust::unary_function{ - float exponent; - CompareNEOp(float v):exponent(v) {} - void setExponent(float v){ - exponent = v; - } - - __host__ __device__ - float operator()(const float& x) { - return (exponent!=x?1.0:.0); - } -}; - -struct CompareLOp: public thrust::unary_function{ - float exponent; - CompareLOp(float v):exponent(v) 
{} - void setExponent(float v){ - exponent = v; - } - - __host__ __device__ - float operator()(const float& x) { - return (exponent>x?1.0:.0); - } -}; - -struct CompareLEOp: public thrust::unary_function{ - float exponent; - CompareLEOp(float v):exponent(v) {} - void setExponent(float v){ - exponent = v; - } - - __host__ __device__ - float operator()(const float& x) { - return (exponent>=x?1.0:.0); - } -}; - - -struct CompareAGOp{ - __host__ __device__ - float operator()(const float& x,const float& y) { - return x>y?1:0; - } -}; - -struct CompareAGEOp{ - __host__ __device__ - float operator()(const float& x,const float& y) { - return x>=y?1:0; - } -}; - -struct CompareAEOp{ - __host__ __device__ - float operator()(const float& x,const float& y) { - return x==y?1:0; - } -}; - -struct CompareANEOp{ - __host__ __device__ - float operator()(const float& x,const float& y) { - return x!=y?1:0; - } -}; typedef thrust::complex complexf; @@ -651,29 +533,51 @@ void unaryPow(float* in1, float N,float* out, unsigned long length){ thrust::transform(thrust::device,in1,in1+length,out,op); return; } - thrust::transform(thrust::device,in1,in1+length,out,PowOp(N)); + auto lambdaPow = [N] __host__ __device__(float x) { + return powf(x,N); + }; + thrust::transform(thrust::device,in1,in1+length,out,lambdaPow); } void unaryCompare(float* in1, const float& in2, float* out, unsigned long length, int type){ + switch (type) { case G: - thrust::transform(thrust::device,in1,in1+length,out,CompareGOp(in2)); + thrust::transform(thrust::device,in1,in1+length,out,[in2] __host__ __device__ (const float &x) + { + return in2 < x ? 1.0 : .0; + }); break; case GE: - thrust::transform(thrust::device,in1,in1+length,out,CompareGEOp(in2)); + thrust::transform(thrust::device,in1,in1+length,out,[in2] __host__ __device__ (const float &x) + { + return in2 <= x ? 
1.0 : .0; + }); break; case E: - thrust::transform(thrust::device,in1,in1+length,out,CompareEOp(in2)); + thrust::transform(thrust::device,in1,in1+length,out,[in2] __host__ __device__ (const float &x) + { + return in2 == x ? 1.0 : .0; + }); break; case NE: - thrust::transform(thrust::device,in1,in1+length,out,CompareNEOp(in2)); + thrust::transform(thrust::device,in1,in1+length,out,[in2] __host__ __device__ (const float &x) + { + return in2 != x ? 1.0 : .0; + }); break; case LE: - thrust::transform(thrust::device,in1,in1+length,out,CompareLEOp(in2)); + thrust::transform(thrust::device,in1,in1+length,out,[in2] __host__ __device__ (const float &x) + { + return in2 >= x ? 1.0 : .0; + }); break; case L: - thrust::transform(thrust::device,in1,in1+length,out,CompareLOp(in2)); + thrust::transform(thrust::device,in1,in1+length,out,[in2]__host__ __device__(const float &x) + { + return in2 > x ? 1.0 : .0; + }); break; default: break; @@ -683,51 +587,89 @@ void unaryCompare(const float& in1, float* in2, float* out, unsigned long length switch (type) { case G: - thrust::transform(thrust::device,in2,in2+length,out,CompareLOp(in1)); + thrust::transform(thrust::device,in2,in2+length,out,[in1] __host__ __device__ (const float &x) + { + return in1 > x ? 1.0 : .0; + }); break; case GE: - thrust::transform(thrust::device,in2,in2+length,out,CompareLEOp(in1)); + thrust::transform(thrust::device,in2,in2+length,out,[in1] __host__ __device__ (const float &x) + { + return in1 >= x ? 1.0 : .0; + }); break; case E: - thrust::transform(thrust::device,in2,in2+length,out,CompareEOp(in1)); + thrust::transform(thrust::device,in2,in2+length,out,[in1] __host__ __device__ (const float &x) + { + return in1 == x ? 1.0 : .0; + }); break; case NE: - thrust::transform(thrust::device,in2,in2+length,out,CompareNEOp(in1)); + thrust::transform(thrust::device,in2,in2+length,out,[in1] __host__ __device__ (const float &x) + { + return in1 != x ? 
1.0 : .0; + }); break; case LE: - thrust::transform(thrust::device,in2,in2+length,out,CompareGEOp(in1)); + thrust::transform(thrust::device,in2,in2+length,out, [in1] __host__ __device__ (const float &x) + { + return in1 <= x ? 1.0 : .0; + }); break; case L: - thrust::transform(thrust::device,in2,in2+length,out,CompareGOp(in1)); + thrust::transform(thrust::device,in2,in2+length,out,[in1] __host__ __device__ (const float &x) + { + return in1 < x ? 1.0 : .0; + }); break; default: break; } + } void unaryCompare(float* in1, float* in2, float* out, unsigned long length, int type){ switch (type) { case G: - thrust::transform(thrust::device,in1,in1+length,in2,out,CompareAGOp()); + thrust::transform(thrust::device,in1,in1+length,in2,out, []__host__ __device__(float x, float y) + { + return x > y ? 1. : .0; + }); break; case GE: - thrust::transform(thrust::device,in1,in1+length,in2,out,CompareAGEOp()); + thrust::transform(thrust::device,in1,in1+length,in2,out,[]__host__ __device__(float x, float y) + { + return x >= y ? 1. : .0; + }); break; case E: - thrust::transform(thrust::device,in1,in1+length,in2,out,CompareAEOp()); + thrust::transform(thrust::device,in1,in1+length,in2,out,[]__host__ __device__(float x, float y) + { + return x == y ? 1. : .0; + }); break; case NE: - thrust::transform(thrust::device,in1,in1+length,in2,out,CompareANEOp()); + thrust::transform(thrust::device,in1,in1+length,in2,out, []__host__ __device__(float x, float y) + { + return x != y ? 1. : .0; + }); break; case LE: - thrust::transform(thrust::device,in2,in2+length,in1,out,CompareAGEOp()); + thrust::transform(thrust::device,in1,in1+length,in2,out,[]__host__ __device__ (float x, float y) + { + return x <= y ? 1. : .0; + }); break; case L: - thrust::transform(thrust::device,in2,in2+length,in1,out,CompareAGOp()); + thrust::transform(thrust::device,in1,in1+length,in2,out, [] __host__ __device__ (float x, float y) + { + return x < y ? 1. 
: .0; + }); break; default: break; } + } void thrustFill(float* aBegin, float* aEnd, float aValue) diff --git a/test/CudaMatrix_Test.cpp b/test/CudaMatrix_Test.cpp index 4468bd0..9a01f71 100644 --- a/test/CudaMatrix_Test.cpp +++ b/test/CudaMatrix_Test.cpp @@ -2558,7 +2558,55 @@ TEST_F(CudaMatrix_Test, MatrixCompare){ } { auto R= (9!=B); - auto dhR = (9!=dB).toHostMatrix(); + auto dhR = (dB!=9).toHostMatrix(); + for (size_t i = 0; i < 1000; i++) + { + EXPECT_FLOAT_EQ(R[i],dhR[i]); + } + } + { + auto R= (9<B); + auto dhR = (dB>9).toHostMatrix(); + for (size_t i = 0; i < 1000; i++) + { + EXPECT_FLOAT_EQ(R[i],dhR[i]); + } + } + { + auto R= (9>B); + auto dhR = (dB<9).toHostMatrix(); + for (size_t i = 0; i < 1000; i++) + { + EXPECT_FLOAT_EQ(R[i],dhR[i]); + } + } + { + auto R= (9<=B); + auto dhR = (dB>=9).toHostMatrix(); + for (size_t i = 0; i < 1000; i++) + { + EXPECT_FLOAT_EQ(R[i],dhR[i]); + } + } + { + auto R= (9>=B); + auto dhR = (dB<=9).toHostMatrix(); + for (size_t i = 0; i < 1000; i++) + { + EXPECT_FLOAT_EQ(R[i],dhR[i]); + } + } + { + auto R= (9==B); + auto dhR = (dB == 9).toHostMatrix(); + for (size_t i = 0; i < 1000; i++) + { + EXPECT_FLOAT_EQ(R[i],dhR[i]); + } + } + { + auto R= (9!=B); + auto dhR = (dB!=9).toHostMatrix(); for (size_t i = 0; i < 1000; i++) { EXPECT_FLOAT_EQ(R[i],dhR[i]);