Add repmat and repmat3d, Fix sqrt with complex.

This commit is contained in:
sunwen
2023-11-27 09:47:59 +08:00
parent 3d68171394
commit fc9b6be9e8
4 changed files with 161 additions and 2 deletions

View File

@@ -163,8 +163,8 @@ TEST_F(Function1D_Cuda_Test, sqrt)
deviceMatrix = hostMatrix.toDeviceMatrix();
result1 = Aurora::sqrt(hostMatrix);
result2 = Aurora::sqrt(deviceMatrix).toHostMatrix();
EXPECT_EQ(result2.getDataSize(), 4);
EXPECT_EQ(result2.getValueType(), Aurora::Complex);
EXPECT_EQ(result2.getDataSize(), result1.getDataSize());
EXPECT_EQ(result2.getValueType(), result1.getValueType());
for(size_t i=0; i<result1.getDataSize() * result1.getValueType(); ++i)
{
EXPECT_EQ(result1[i], result2[i]);
@@ -222,3 +222,51 @@ TEST_F(Function1D_Cuda_Test, sign)
EXPECT_EQ(result1[i], result2[i]);
}
}
TEST_F(Function1D_Cuda_Test, repmat)
{
Aurora::Matrix hostMatrix = Aurora::Matrix::fromRawData(new float[8]{1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8}, 2,4);
Aurora::CudaMatrix deviceMatrix = hostMatrix.toDeviceMatrix();
auto result1 = Aurora::repmat(hostMatrix,3,6);
auto result2 = Aurora::repmat(deviceMatrix,3,6).toHostMatrix();
EXPECT_EQ(result2.getDataSize(), 8 * 3 * 6);
EXPECT_EQ(result2.getValueType(), Aurora::Normal);
for(size_t i=0; i<result1.getDataSize(); ++i)
{
EXPECT_EQ(result1[i], result2[i]);
}
hostMatrix = Aurora::Matrix::fromRawData(new float[8]{1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8}, 2,2,1,Aurora::Complex);
deviceMatrix = hostMatrix.toDeviceMatrix();
result1 = Aurora::repmat(hostMatrix, 4, 8);
result2 = Aurora::repmat(deviceMatrix, 4, 8).toHostMatrix();
EXPECT_EQ(result2.getDataSize(), 4 * 4 * 8);
EXPECT_EQ(result2.getValueType(), Aurora::Complex);
for(size_t i=0; i<result1.getDataSize() * result1.getValueType(); ++i)
{
EXPECT_EQ(result1[i], result2[i]);
}
hostMatrix = Aurora::Matrix::fromRawData(new float[12]{1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10,11,12}, 3, 4, 1,Aurora::Normal);
deviceMatrix = hostMatrix.toDeviceMatrix();
result1 = Aurora::repmat(hostMatrix, 4, 8, 3);
result2 = Aurora::repmat(deviceMatrix, 4, 8, 3).toHostMatrix();
EXPECT_EQ(result2.getDataSize(), 3 * 4 * 4 * 8 * 3);
EXPECT_EQ(result2.getValueType(), Aurora::Normal);
for(size_t i=0; i<result1.getDataSize() * result1.getValueType(); ++i)
{
EXPECT_EQ(result1[i], result2[i]);
}
hostMatrix = Aurora::Matrix::fromRawData(new float[12]{1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10,11,12}, 3, 2, 1,Aurora::Complex);
deviceMatrix = hostMatrix.toDeviceMatrix();
result1 = Aurora::repmat(hostMatrix, 4, 8, 3);
result2 = Aurora::repmat(deviceMatrix, 4, 8, 3).toHostMatrix();
EXPECT_EQ(result2.getDataSize(), 3 * 2 * 4 * 8 * 3);
EXPECT_EQ(result2.getValueType(), Aurora::Complex);
for(size_t i=0; i<result1.getDataSize() * result1.getValueType(); ++i)
{
EXPECT_EQ(result1[i], result2[i]);
}
}