Add prod and unittest.

This commit is contained in:
sunwen
2023-12-11 15:01:24 +08:00
parent 8b58d05d90
commit bd4a27a17b
3 changed files with 155 additions and 1 deletions

View File

@@ -807,10 +807,55 @@ TEST_F(Function2D_Cuda_Test, dot) {
Aurora::CudaMatrix matrixDevice2 = matrixHost2.toDeviceMatrix();
auto result1 = Aurora::dot(matrixHost1, matrixHost2);
auto result2 = Aurora::dot(matrixDevice1, matrixDevice2).toHostMatrix();
std::cout<< result1.getDataSize();
ASSERT_FLOAT_EQ(result1.getDataSize(), result2.getDataSize());
for (size_t i = 0; i < result1.getDataSize(); i++)
{
EXPECT_FLOAT_AE(result1[i], result2[i]);
}
}
TEST_F(Function2D_Cuda_Test, prod) {
auto matrixHost = Aurora::Matrix::fromRawData(new float[20], 4,5);
for(unsigned int i=0; i<20;++i)
{
matrixHost[i] = i + 1;
}
auto matrixDevice = matrixHost.toDeviceMatrix();
auto result1 = Aurora::prod(matrixHost);
auto result2 = Aurora::prod(matrixDevice).toHostMatrix();
ASSERT_FLOAT_EQ(result1.getDataSize(), result2.getDataSize());
for (size_t i = 0; i < result1.getDataSize(); i++)
{
EXPECT_FLOAT_AE(result1[i], result2[i]);
}
matrixHost = Aurora::Matrix::fromRawData(new float[20], 20,1);
for(unsigned int i=0; i<20;++i)
{
matrixHost[i] = i + 1;
}
matrixDevice = matrixHost.toDeviceMatrix();
result1 = Aurora::prod(matrixHost);
result2 = Aurora::prod(matrixDevice).toHostMatrix();
ASSERT_FLOAT_EQ(result1.getDataSize(), result2.getDataSize());
for (size_t i = 0; i < result1.getDataSize(); i++)
{
EXPECT_FLOAT_AE(result1[i], result2[i]);
}
auto matrixHostComplex = Aurora::Matrix::fromRawData(new float[40], 4,5, 1,Aurora::Complex);
for(unsigned int i=0; i<40;++i)
{
matrixHost[i] = i + 1;
}
matrixDevice = matrixHostComplex.toDeviceMatrix();
result1 = Aurora::prod(matrixHostComplex);
result2 = Aurora::prod(matrixDevice).toHostMatrix();
ASSERT_FLOAT_EQ(result1.getDataSize(), result2.getDataSize());
ASSERT_FLOAT_EQ(result1.getValueType(), result2.getValueType());
for (size_t i = 0; i < result1.getDataSize() * result1.getValueType(); i++)
{
EXPECT_FLOAT_AE(result1[i], result2[i]);
}
}