feat: replace CudaMatrix compare functors (deprecated in newer CUDA versions) with lambdas

This commit is contained in:
kradchen
2025-03-18 16:00:10 +08:00
parent 4ba0d23d54
commit 3ea6c84087
2 changed files with 129 additions and 139 deletions

View File

@@ -5,129 +5,11 @@
#include <thrust/functional.h>
#include <thrust/execution_policy.h>
#include "AuroraDefs.h"
#include "AuroraThrustIterator.cuh"
using namespace thrust::placeholders;
/**
 * Unary functor mapping x -> powf(x, exponent).
 *
 * argument_type / result_type are declared directly instead of being
 * inherited from thrust::unary_function, which newer Thrust releases
 * deprecate; the nested type names stay available to callers.
 */
struct PowOp{
    typedef float argument_type;
    typedef float result_type;

    float exponent;  // power applied to every input value

    PowOp(float v):exponent(v) {}

    // Replace the power used by subsequent calls.
    void setExponent(float v){
        exponent = v;
    }

    // const-qualified: evaluating the functor never mutates it, so it can
    // also be invoked through const copies/references.
    __host__ __device__
    float operator()(const float& x) const {
        return powf(x, exponent);
    }
};
/**
 * Unary "greater than" test against a stored scalar:
 * returns 1.0f when x > exponent, otherwise 0.0f.
 *
 * The member keeps its historical name `exponent`; here it is simply the
 * value each input is compared against. argument_type / result_type are
 * declared directly instead of inheriting from the deprecated
 * thrust::unary_function.
 */
struct CompareGOp{
    typedef float argument_type;
    typedef float result_type;

    float exponent;  // comparison value (not an exponent despite the name)

    CompareGOp(float v):exponent(v) {}

    // Replace the comparison value used by subsequent calls.
    void setExponent(float v){
        exponent = v;
    }

    // const-qualified and using float literals so no double arithmetic is
    // introduced in device code.
    __host__ __device__
    float operator()(const float& x) const {
        return (exponent < x) ? 1.0f : 0.0f;
    }
};
/**
 * Unary "greater than or equal" test against a stored scalar:
 * returns 1.0f when x >= exponent, otherwise 0.0f.
 *
 * The member keeps its historical name `exponent`; here it is simply the
 * value each input is compared against. argument_type / result_type are
 * declared directly instead of inheriting from the deprecated
 * thrust::unary_function.
 */
struct CompareGEOp{
    typedef float argument_type;
    typedef float result_type;

    float exponent;  // comparison value (not an exponent despite the name)

    CompareGEOp(float v):exponent(v) {}

    // Replace the comparison value used by subsequent calls.
    void setExponent(float v){
        exponent = v;
    }

    // const-qualified and using float literals so no double arithmetic is
    // introduced in device code.
    __host__ __device__
    float operator()(const float& x) const {
        return (exponent <= x) ? 1.0f : 0.0f;
    }
};
/**
 * Unary equality test against a stored scalar:
 * returns 1.0f when x == exponent (exact float comparison), otherwise 0.0f.
 *
 * The member keeps its historical name `exponent`; here it is simply the
 * value each input is compared against. argument_type / result_type are
 * declared directly instead of inheriting from the deprecated
 * thrust::unary_function.
 */
struct CompareEOp{
    typedef float argument_type;
    typedef float result_type;

    float exponent;  // comparison value (not an exponent despite the name)

    CompareEOp(float v):exponent(v) {}

    // Replace the comparison value used by subsequent calls.
    void setExponent(float v){
        exponent = v;
    }

    // const-qualified and using float literals so no double arithmetic is
    // introduced in device code.
    __host__ __device__
    float operator()(const float& x) const {
        return (exponent == x) ? 1.0f : 0.0f;
    }
};
/**
 * Unary inequality test against a stored scalar:
 * returns 1.0f when x != exponent (exact float comparison), otherwise 0.0f.
 *
 * The member keeps its historical name `exponent`; here it is simply the
 * value each input is compared against. argument_type / result_type are
 * declared directly instead of inheriting from the deprecated
 * thrust::unary_function.
 */
struct CompareNEOp{
    typedef float argument_type;
    typedef float result_type;

    float exponent;  // comparison value (not an exponent despite the name)

    CompareNEOp(float v):exponent(v) {}

    // Replace the comparison value used by subsequent calls.
    void setExponent(float v){
        exponent = v;
    }

    // const-qualified and using float literals so no double arithmetic is
    // introduced in device code.
    __host__ __device__
    float operator()(const float& x) const {
        return (exponent != x) ? 1.0f : 0.0f;
    }
};
/**
 * Unary "less than" test against a stored scalar:
 * returns 1.0f when x < exponent, otherwise 0.0f.
 *
 * The member keeps its historical name `exponent`; here it is simply the
 * value each input is compared against. argument_type / result_type are
 * declared directly instead of inheriting from the deprecated
 * thrust::unary_function.
 */
struct CompareLOp{
    typedef float argument_type;
    typedef float result_type;

    float exponent;  // comparison value (not an exponent despite the name)

    CompareLOp(float v):exponent(v) {}

    // Replace the comparison value used by subsequent calls.
    void setExponent(float v){
        exponent = v;
    }

    // const-qualified and using float literals so no double arithmetic is
    // introduced in device code.
    __host__ __device__
    float operator()(const float& x) const {
        return (exponent > x) ? 1.0f : 0.0f;
    }
};
/**
 * Unary "less than or equal" test against a stored scalar:
 * returns 1.0f when x <= exponent, otherwise 0.0f.
 *
 * The member keeps its historical name `exponent`; here it is simply the
 * value each input is compared against. argument_type / result_type are
 * declared directly instead of inheriting from the deprecated
 * thrust::unary_function.
 */
struct CompareLEOp{
    typedef float argument_type;
    typedef float result_type;

    float exponent;  // comparison value (not an exponent despite the name)

    CompareLEOp(float v):exponent(v) {}

    // Replace the comparison value used by subsequent calls.
    void setExponent(float v){
        exponent = v;
    }

    // const-qualified and using float literals so no double arithmetic is
    // introduced in device code.
    __host__ __device__
    float operator()(const float& x) const {
        return (exponent >= x) ? 1.0f : 0.0f;
    }
};
/**
 * Binary "greater than" test: returns 1.0f when x > y, otherwise 0.0f.
 */
struct CompareAGOp{
    // const-qualified (the functor is stateless) and returning float
    // literals directly instead of converting from int.
    __host__ __device__
    float operator()(const float& x,const float& y) const {
        return (x > y) ? 1.0f : 0.0f;
    }
};
/**
 * Binary "greater than or equal" test: returns 1.0f when x >= y,
 * otherwise 0.0f.
 */
struct CompareAGEOp{
    // const-qualified (the functor is stateless) and returning float
    // literals directly instead of converting from int.
    __host__ __device__
    float operator()(const float& x,const float& y) const {
        return (x >= y) ? 1.0f : 0.0f;
    }
};
/**
 * Binary equality test: returns 1.0f when x == y (exact float comparison),
 * otherwise 0.0f.
 */
struct CompareAEOp{
    // const-qualified (the functor is stateless) and returning float
    // literals directly instead of converting from int.
    __host__ __device__
    float operator()(const float& x,const float& y) const {
        return (x == y) ? 1.0f : 0.0f;
    }
};
/**
 * Binary inequality test: returns 1.0f when x != y (exact float comparison),
 * otherwise 0.0f.
 */
struct CompareANEOp{
    // const-qualified (the functor is stateless) and returning float
    // literals directly instead of converting from int.
    __host__ __device__
    float operator()(const float& x,const float& y) const {
        return (x != y) ? 1.0f : 0.0f;
    }
};
typedef thrust::complex<float> complexf;
@@ -651,29 +533,51 @@ void unaryPow(float* in1, float N,float* out, unsigned long length){
thrust::transform(thrust::device,in1,in1+length,out,op);
return;
}
thrust::transform(thrust::device,in1,in1+length,out,PowOp(N));
auto lambdaPow = [N] __host__ __device__(float x) {
return powf(x,N);
};
thrust::transform(thrust::device,in1,in1+length,out,lambdaPow);
}
void unaryCompare(float* in1, const float& in2, float* out, unsigned long length, int type){
switch (type)
{
case G:
thrust::transform(thrust::device,in1,in1+length,out,CompareGOp(in2));
thrust::transform(thrust::device,in1,in1+length,out,[in2] __host__ __device__ (const float &x)
{
return in2 < x ? 1.0 : .0;
});
break;
case GE:
thrust::transform(thrust::device,in1,in1+length,out,CompareGEOp(in2));
thrust::transform(thrust::device,in1,in1+length,out,[in2] __host__ __device__ (const float &x)
{
return in2 <= x ? 1.0 : .0;
});
break;
case E:
thrust::transform(thrust::device,in1,in1+length,out,CompareEOp(in2));
thrust::transform(thrust::device,in1,in1+length,out,[in2] __host__ __device__ (const float &x)
{
return in2 == x ? 1.0 : .0;
});
break;
case NE:
thrust::transform(thrust::device,in1,in1+length,out,CompareNEOp(in2));
thrust::transform(thrust::device,in1,in1+length,out,[in2] __host__ __device__ (const float &x)
{
return in2 != x ? 1.0 : .0;
});
break;
case LE:
thrust::transform(thrust::device,in1,in1+length,out,CompareLEOp(in2));
thrust::transform(thrust::device,in1,in1+length,out,[in2] __host__ __device__ (const float &x)
{
return in2 >= x ? 1.0 : .0;
});
break;
case L:
thrust::transform(thrust::device,in1,in1+length,out,CompareLOp(in2));
thrust::transform(thrust::device,in1,in1+length,out,[in2]__host__ __device__(const float &x)
{
return in2 > x ? 1.0 : .0;
});
break;
default:
break;
@@ -683,51 +587,89 @@ void unaryCompare(const float& in1, float* in2, float* out, unsigned long length
switch (type)
{
case G:
thrust::transform(thrust::device,in2,in2+length,out,CompareLOp(in1));
thrust::transform(thrust::device,in2,in2+length,out,[in1] __host__ __device__ (const float &x)
{
return in1 > x ? 1.0 : .0;
});
break;
case GE:
thrust::transform(thrust::device,in2,in2+length,out,CompareLEOp(in1));
thrust::transform(thrust::device,in2,in2+length,out,[in1] __host__ __device__ (const float &x)
{
return in1 >= x ? 1.0 : .0;
});
break;
case E:
thrust::transform(thrust::device,in2,in2+length,out,CompareEOp(in1));
thrust::transform(thrust::device,in2,in2+length,out,[in1] __host__ __device__ (const float &x)
{
return in1 == x ? 1.0 : .0;
});
break;
case NE:
thrust::transform(thrust::device,in2,in2+length,out,CompareNEOp(in1));
thrust::transform(thrust::device,in2,in2+length,out,[in1] __host__ __device__ (const float &x)
{
return in1 != x ? 1.0 : .0;
});
break;
case LE:
thrust::transform(thrust::device,in2,in2+length,out,CompareGEOp(in1));
thrust::transform(thrust::device,in2,in2+length,out, [in1] __host__ __device__ (const float &x)
{
return in1 <= x ? 1.0 : .0;
});
break;
case L:
thrust::transform(thrust::device,in2,in2+length,out,CompareGOp(in1));
thrust::transform(thrust::device,in2,in2+length,out,[in1] __host__ __device__ (const float &x)
{
return in1 < x ? 1.0 : .0;
});
break;
default:
break;
}
}
/**
 * Element-wise comparison of two arrays: out[i] = (in1[i] <op> in2[i]) ? 1 : 0.
 *
 * Work is dispatched with the thrust::device execution policy, so the
 * pointers are expected to be device-accessible (NOTE(review): confirm
 * against callers).
 *
 * @param in1    left-hand operands, `length` elements
 * @param in2    right-hand operands, `length` elements
 * @param out    destination, `length` elements
 * @param length number of elements to compare
 * @param type   comparison selector: G (>), GE (>=), E (==), NE (!=),
 *               LE (<=), L (<). Any other value leaves `out` untouched.
 *
 * Each case performs exactly one transform pass with a self-contained
 * lambda, so the function does not depend on the Compare*Op functors and
 * does not write `out` twice. The lambdas use float literals to avoid
 * double arithmetic in device code.
 */
void unaryCompare(float* in1, float* in2, float* out, unsigned long length, int type){
    switch (type)
    {
    case G:
        thrust::transform(thrust::device, in1, in1 + length, in2, out,
            [] __host__ __device__ (float x, float y)
            {
                return x > y ? 1.0f : 0.0f;
            });
        break;
    case GE:
        thrust::transform(thrust::device, in1, in1 + length, in2, out,
            [] __host__ __device__ (float x, float y)
            {
                return x >= y ? 1.0f : 0.0f;
            });
        break;
    case E:
        thrust::transform(thrust::device, in1, in1 + length, in2, out,
            [] __host__ __device__ (float x, float y)
            {
                return x == y ? 1.0f : 0.0f;
            });
        break;
    case NE:
        thrust::transform(thrust::device, in1, in1 + length, in2, out,
            [] __host__ __device__ (float x, float y)
            {
                return x != y ? 1.0f : 0.0f;
            });
        break;
    case LE:
        thrust::transform(thrust::device, in1, in1 + length, in2, out,
            [] __host__ __device__ (float x, float y)
            {
                return x <= y ? 1.0f : 0.0f;
            });
        break;
    case L:
        thrust::transform(thrust::device, in1, in1 + length, in2, out,
            [] __host__ __device__ (float x, float y)
            {
                return x < y ? 1.0f : 0.0f;
            });
        break;
    default:
        // Unknown selector: nothing is written to `out`.
        break;
    }
}
void thrustFill(float* aBegin, float* aEnd, float aValue)

View File

@@ -2558,7 +2558,55 @@ TEST_F(CudaMatrix_Test, MatrixCompare){
}
{
auto R= (9!=B);
auto dhR = (9!=dB).toHostMatrix();
auto dhR = (dB!=9).toHostMatrix();
for (size_t i = 0; i < 1000; i++)
{
EXPECT_FLOAT_EQ(R[i],dhR[i]);
}
}
{
auto R= (9<B);
auto dhR = (dB>9).toHostMatrix();
for (size_t i = 0; i < 1000; i++)
{
EXPECT_FLOAT_EQ(R[i],dhR[i]);
}
}
{
auto R= (9>B);
auto dhR = (dB<9).toHostMatrix();
for (size_t i = 0; i < 1000; i++)
{
EXPECT_FLOAT_EQ(R[i],dhR[i]);
}
}
{
auto R= (9<=B);
auto dhR = (dB>=9).toHostMatrix();
for (size_t i = 0; i < 1000; i++)
{
EXPECT_FLOAT_EQ(R[i],dhR[i]);
}
}
{
auto R= (9>=B);
auto dhR = (dB<=9).toHostMatrix();
for (size_t i = 0; i < 1000; i++)
{
EXPECT_FLOAT_EQ(R[i],dhR[i]);
}
}
{
auto R= (9==B);
auto dhR = (dB == 9).toHostMatrix();
for (size_t i = 0; i < 1000; i++)
{
EXPECT_FLOAT_EQ(R[i],dhR[i]);
}
}
{
auto R= (9!=B);
auto dhR = (dB!=9).toHostMatrix();
for (size_t i = 0; i < 1000; i++)
{
EXPECT_FLOAT_EQ(R[i],dhR[i]);