#include "solve.h"

#include <utility>
#include <vector>

#include "Function3D.h"
#include "Matrix.h"
#include "config/config.h"
#include "transmissionReconstruction/reconstruction/solvingEquationSystem/TVAL/TVAL.h"
#include "tvalstruct.h"

namespace Recon {

/// Per-run parameters passed from solveParameterIterator() to solve().
struct TVAL3SolverOptions {
    /// solve() reads element [0] of this and passes it to callTval3 as the device index.
    Aurora::Matrix gpuSelectionList;
    double TVAL3MU = 0.0;
    double TVAL3MU0 = 0.0;
    double TVAL3Beta = 0.0;
    double TVAL3Beta0 = 0.0;
    /// NOTE(review): set by solveParameterIterator() but not forwarded to TVALOptions
    /// in solve(); confirm whether TVAL3 should enforce non-negativity.
    bool nonNeg = false;
};

/**
 * @brief Runs one solve of the system (M, b) with the solver selected by
 *        Recon::transParams::name.
 *
 * If transParams::name is empty it is set to "TVAL3" (this mutates the global
 * configuration). For "TVAL3" a TVALOptions is filled from @p niter and
 * @p solverOptions and callTval3 is invoked.
 *
 * @param M             system matrix (passed through to callTval3).
 * @param b             right-hand side (passed through to callTval3).
 * @param dims          passed through to callTval3 (presumably the image dimensions — confirm).
 * @param niter         maximum iterations (TVALOptions::maxit).
 * @param solverOptions mu/beta values and the GPU selection list.
 * @return the result of callTval3, or an empty Aurora::Matrix for any other
 *         solver name (e.g. SART — not implemented yet).
 */
Aurora::Matrix solve(Aurora::Sparse& M, const Aurora::Matrix& b, const Aurora::Matrix& dims,
                     int niter, TVAL3SolverOptions solverOptions) {
    if (Recon::transParams::name.empty()) {
        Recon::transParams::name = "TVAL3";
    }

    if (Recon::transParams::name == "TVAL3") {
        TVALOptions opt;
        opt.bent = false;
        opt.tol = 1E-10;
        opt.maxit = niter;
        opt.TVnorm = 2;
        opt.disp = false;
        opt.mu0 = solverOptions.TVAL3MU0;
        opt.mu = solverOptions.TVAL3MU;
        opt.beta = solverOptions.TVAL3Beta;
        opt.beta0 = solverOptions.TVAL3Beta0;
        // NOTE(review): indexes element 0 without checking that the list is
        // non-empty; callers must supply at least one GPU id.
        const int device = static_cast<int>(solverOptions.gpuSelectionList[0]);
        return callTval3(M, b, dims, device, opt);
    }

    // SART (or any other solver): TODO - not implemented yet; only the default
    // TVAL3 path exists for now.
    return Aurora::Matrix();
}

/**
 * @brief Solves the system once for every (mu, beta) pair in the global
 *        parameter lists Recon::transParams::muValues / betaValues.
 *
 * Side effects on the global configuration:
 *  - a null muValues is replaced by the single value 24;
 *  - a null betaValues is replaced by the single value 1;
 *  - when @p oneIter is true both lists are truncated to their first element.
 *
 * @param M       system matrix (taken by value; a reference to this copy is passed to solve()).
 * @param b       right-hand side.
 * @param dims    forwarded to solve().
 * @param oneIter restrict the sweep to the first mu and first beta value.
 * @param nonNeg  stored in the per-run options (see TVAL3SolverOptions::nonNeg).
 * @return result[i][j] is the solution for muValues[i] and betaValues[j]; an
 *         empty vector when transParams::name is not "TVAL3" (e.g. SART, not
 *         implemented yet).
 */
std::vector<std::vector<Aurora::Matrix>> solveParameterIterator(Aurora::Sparse M,
                                                                const Aurora::Matrix& b,
                                                                const Aurora::Matrix& dims,
                                                                bool oneIter, bool nonNeg) {
    if (Recon::transParams::name == "TVAL3") {
        if (Recon::transParams::muValues.isNull()) {
            Recon::transParams::muValues = Aurora::ones(1, 1);
            Recon::transParams::muValues[0] = 24;
        }
        if (Recon::transParams::betaValues.isNull()) {
            Recon::transParams::betaValues = Aurora::ones(1, 1);
            Recon::transParams::betaValues[0] = 1;
        }
        if (oneIter) {
            auto firstMu = Aurora::ones(1, 1);
            firstMu[0] = Recon::transParams::muValues[0];
            Recon::transParams::muValues = firstMu;
            auto firstBeta = Aurora::ones(1, 1);
            firstBeta[0] = Recon::transParams::betaValues[0];
            Recon::transParams::betaValues = firstBeta;
        }

        // Size the result only after muValues has been defaulted/truncated;
        // sizing it earlier left it empty (out-of-bounds writes) for a null
        // muValues and too long when oneIter truncated the list.
        const size_t muCount = Recon::transParams::muValues.getDataSize();
        const size_t betaCount = Recon::transParams::betaValues.getDataSize();
        std::vector<std::vector<Aurora::Matrix>> result(muCount);

        TVAL3SolverOptions options;
        options.gpuSelectionList = Recon::transParams::gpuSelectionList;
        options.nonNeg = nonNeg;  // loop-invariant
        for (size_t i = 0; i < muCount; ++i) {
            options.TVAL3MU = Recon::transParams::muValues[i];
            options.TVAL3MU0 = Recon::transParams::muValues[i];
            std::vector<Aurora::Matrix> solveResult(betaCount);
            for (size_t j = 0; j < betaCount; ++j) {
                // Index beta with j (the beta loop index), not i.
                options.TVAL3Beta = Recon::transParams::betaValues[j];
                options.TVAL3Beta0 = Recon::transParams::betaValues[j];
                solveResult[j] = solve(M, b, dims, Recon::transParams::maxIter, options);
            }
            result[i] = std::move(solveResult);
        }
        return result;
    }

    // SART (or any other solver): TODO - not implemented yet.
    return {};
}

}  // namespace Recon