From 577294088b50a7abaa44e3861e8e586ecfa32568 Mon Sep 17 00:00:00 2001 From: sunwen Date: Thu, 26 Dec 2024 15:54:18 +0800 Subject: [PATCH] feat: Add cuda function in recon art. --- .../reconstruction/reconstruction.cpp | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/src/transmissionReconstruction/reconstruction/reconstruction.cpp b/src/transmissionReconstruction/reconstruction/reconstruction.cpp index 0a57737..3f7f333 100644 --- a/src/transmissionReconstruction/reconstruction/reconstruction.cpp +++ b/src/transmissionReconstruction/reconstruction/reconstruction.cpp @@ -10,12 +10,15 @@ #include "Function3D.h" #include "config/config.h" -#include "transmissionReconstruction/reconstruction/buildMatrix/buildMatrix.h" +#include "transmissionReconstruction/reconstruction/buildMatrix/buildMatrix.cuh" #include "transmissionReconstruction/reconstruction/solvingEquationSystem/solve.h" #include "CudaEnvInit.h" using namespace Aurora; +using solveParameterIteratorFunctionType = std::vector> (*)(Aurora::Sparse M, Aurora::Matrix &b, + const Aurora::Matrix &dims, bool oneIter, bool nonNeg, int aDevice); +using slownessToSOSFunctionType = Matrix (*)(Aurora::Matrix & aVF1, float aSOS_IN_WATER); namespace Recon { Aurora::Matrix calculateMinimalMaximalTransducerPositions( const Aurora::Matrix &aMSenderList, const Aurora::Matrix &aMReceiverList) { @@ -189,7 +192,8 @@ namespace Recon { BuildMatrixResult buildMatrixR; for(int iter=1; iter<=numIter; ++iter) { - buildMatrixR = buildMatrix(senderList, receiverList, res, dims, bentRecon && (iter!=1), potentialMap); + auto resDevice = res.toDeviceMatrix(); + buildMatrixR = buildMatrix(senderList.toDeviceMatrix(), receiverList.toDeviceMatrix(), resDevice, dims.toDeviceMatrix(), bentRecon && (iter!=1), potentialMap.toDeviceMatrix()); if(!data.isNull() && bentRecon && iter != numIter) { //与默认配置bentRecon不符,暂不实现todo @@ -222,22 +226,15 @@ namespace Recon { { allHitMaps.push_back(buildMatrixR.hitmap); } - #pragma omp parallel for num_threads(2) - for (int i =0; i<2; i++){ - if (i ==0){ - if(!data.isNull()) - { - Matrix sosValue = solveParameterIterator(buildMatrixR.M, b, dims, false, transParams::nonNeg)[0][0]; - result.outSOS = slownessToSOS(sosValue, SOS_IN_WATER) ; - } - } - else{ - if(!dataAtt.isNull()) - { - Matrix attValue = solveParameterIterator(buildMatrixR.M, bAtt, dims, false, transParams::nonNeg,1)[0][0]; - result.outATT = attValue/100 ; - } - } + if(!data.isNull()) + { + Matrix sosValue = solveParameterIterator(buildMatrixR.M, b, dims, false, transParams::nonNeg)[0][0]; + result.outSOS = slownessToSOS(sosValue, SOS_IN_WATER) ; + } + if(!dataAtt.isNull()) + { + Matrix attValue = solveParameterIterator(buildMatrixR.M, bAtt, dims, false, transParams::nonNeg)[0][0]; + result.outATT = attValue/100 ; } }