feat: Add cuda function in recon art.

This commit is contained in:
sunwen
2024-12-26 15:54:18 +08:00
parent 4daf8e0c0b
commit 577294088b

View File

@@ -10,12 +10,15 @@
#include "Function3D.h" #include "Function3D.h"
#include "config/config.h" #include "config/config.h"
#include "transmissionReconstruction/reconstruction/buildMatrix/buildMatrix.h" #include "transmissionReconstruction/reconstruction/buildMatrix/buildMatrix.cuh"
#include "transmissionReconstruction/reconstruction/solvingEquationSystem/solve.h" #include "transmissionReconstruction/reconstruction/solvingEquationSystem/solve.h"
#include "CudaEnvInit.h" #include "CudaEnvInit.h"
using namespace Aurora; using namespace Aurora;
using solveParameterIteratorFunctionType = std::vector<std::vector<Aurora::Matrix>> (*)(Aurora::Sparse M, Aurora::Matrix &b,
const Aurora::Matrix &dims, bool oneIter, bool nonNeg, int aDevice);
using slownessToSOSFunctionType = Matrix (*)(Aurora::Matrix & aVF1, float aSOS_IN_WATER);
namespace Recon { namespace Recon {
Aurora::Matrix calculateMinimalMaximalTransducerPositions( Aurora::Matrix calculateMinimalMaximalTransducerPositions(
const Aurora::Matrix &aMSenderList, const Aurora::Matrix &aMReceiverList) { const Aurora::Matrix &aMSenderList, const Aurora::Matrix &aMReceiverList) {
@@ -189,7 +192,8 @@ namespace Recon {
BuildMatrixResult buildMatrixR; BuildMatrixResult buildMatrixR;
for(int iter=1; iter<=numIter; ++iter) for(int iter=1; iter<=numIter; ++iter)
{ {
buildMatrixR = buildMatrix(senderList, receiverList, res, dims, bentRecon && (iter!=1), potentialMap); auto resDevice = res.toDeviceMatrix();
buildMatrixR = buildMatrix(senderList.toDeviceMatrix(), receiverList.toDeviceMatrix(), resDevice, dims.toDeviceMatrix(), bentRecon && (iter!=1), potentialMap.toDeviceMatrix());
if(!data.isNull() && bentRecon && iter != numIter) if(!data.isNull() && bentRecon && iter != numIter)
{ {
//与默认配置bentRecon不符暂不实现todo //与默认配置bentRecon不符暂不实现todo
@@ -222,22 +226,15 @@ namespace Recon {
{ {
allHitMaps.push_back(buildMatrixR.hitmap); allHitMaps.push_back(buildMatrixR.hitmap);
} }
#pragma omp parallel for num_threads(2) if(!data.isNull())
for (int i =0; i<2; i++){ {
if (i ==0){ Matrix sosValue = solveParameterIterator(buildMatrixR.M, b, dims, false, transParams::nonNeg)[0][0];
if(!data.isNull()) result.outSOS = slownessToSOS(sosValue, SOS_IN_WATER) ;
{ }
Matrix sosValue = solveParameterIterator(buildMatrixR.M, b, dims, false, transParams::nonNeg)[0][0]; if(!dataAtt.isNull())
result.outSOS = slownessToSOS(sosValue, SOS_IN_WATER) ; {
} Matrix attValue = solveParameterIterator(buildMatrixR.M, bAtt, dims, false, transParams::nonNeg)[0][0];
} result.outATT = attValue/100 ;
else{
if(!dataAtt.isNull())
{
Matrix attValue = solveParameterIterator(buildMatrixR.M, bAtt, dims, false, transParams::nonNeg,1)[0][0];
result.outATT = attValue/100 ;
}
}
} }
} }