Fix sort and min unit tests

This commit is contained in:
kradchen
2023-12-07 15:46:36 +08:00
parent 57fc2219ce
commit a65ee38196
2 changed files with 17 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
#include "AuroraDefs.h"
#include "CudaMatrix.h"
#include "Function1D.h"
#include "Function1D.cuh"
#include "Matrix.h"
#include <Function2D.cuh>
#include <cfloat>
@@ -272,8 +273,9 @@ __global__ void minColKernel(float* aInputData, float* aOutput, unsigned int aCo
}
// 规约最前面一段
for (int i = blockDim.x/2; i >0; i>>=1) {
if (threadIdx.x < i) {
shared_data[threadIdx.x] += fminf(shared_data[threadIdx.x], shared_data[threadIdx.x+i]);
shared_data[threadIdx.x] = fminf(shared_data[threadIdx.x], shared_data[i + threadIdx.x]);
}
__syncthreads();
}
@@ -302,7 +304,7 @@ __global__ void minRowKernel(float* aInputData, float* aOutput,unsigned int aCol
// 规约最前面一段
for (int i = blockDim.x/2; i >0; i>>=1) {
if (threadIdx.x < i) {
shared_data[threadIdx.x] += fminf(shared_data[threadIdx.x], shared_data[threadIdx.x+i]);
shared_data[threadIdx.x] = fminf(shared_data[threadIdx.x], shared_data[threadIdx.x+i]);
}
__syncthreads();
}
@@ -876,7 +878,12 @@ CudaMatrix Aurora::sort(CudaMatrix &&aMatrix,FunctionDirection direction)
case Column:
{
int rowElementCount = aMatrix.getDimSize(1);
// softKernel<<<rowElementCount,1>>>(data,colElementCount);
for (size_t i = 0; i < rowElementCount; i++)
{
thrust::sort(thrust::device, data+i*colElementCount,
data+(i+1)*colElementCount);
}
return aMatrix;
}
default:

View File

@@ -426,14 +426,14 @@ TEST_F(Function2D_Cuda_Test, sum)
{
//
{
float *dataB = Aurora::random(4096*50000);
// float* dataB = new float[4096*50000];
// for (size_t i = 0; i < 4096*50000; i++)
// {
// dataB[i] = (float)(i/4096);
// }
// float *dataB = Aurora::random(4096*50000);
float* dataB = new float[4096*5000];
for (size_t i = 0; i < 4096*5000; i++)
{
dataB[i] = (i%2==0?1.0f:0.0f);
}
B = Aurora::Matrix::fromRawData(dataB, 4096, 50000);
B = Aurora::Matrix::fromRawData(dataB, 4096, 5000);
// B = Aurora::Matrix::fromRawData(dataB, 200, 200);
auto dD = B.toDeviceMatrix();