Files
UR/src/transmissionReconstruction/reconstruction/solvingEquationSystem/solve.cpp
2023-05-30 17:08:19 +08:00

98 lines
3.6 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "solve.h"

#include <cstddef>
#include <utility>
#include <vector>

#include "Function3D.h"
#include "Matrix.h"
#include "config/config.h"
#include "transmissionReconstruction/reconstruction/solvingEquationSystem/TVAL/TVAL.h"
#include "tvalstruct.h"
namespace Recon
{
struct TVAL3SolverOptions{
Aurora::Matrix gpuSelectionList;
double TVAL3MU;
double TVAL3MU0;
double TVAL3Beta;
double TVAL3Beta0;
bool nonNeg = false;
};
Aurora::Matrix solve( Aurora::Sparse& M, const Aurora::Matrix& b, const Aurora::Matrix& dims, int niter, TVAL3SolverOptions solverOptions){
if (Recon::transParams::name.empty()){
Recon::transParams::name = "TVAL3";
}
if (Recon::transParams::name == "TVAL3")
{
//callTval3
TVALOptions opt;
opt.bent = false;
opt.tol = 1E-10;
opt.maxit = niter;
opt.TVnorm = 2;
opt.disp = false;
opt.mu0 = solverOptions.TVAL3MU0;
opt.mu = solverOptions.TVAL3MU;
opt.beta = solverOptions.TVAL3Beta;
opt.beta0 = solverOptions.TVAL3Beta0;
int device = (int)solverOptions.gpuSelectionList[0];
return callTval3(M, b, dims, device, opt);
}
//SART
else{
//TODO:待实现先实现默认的TVAL3
return Aurora::Matrix();
}
}
std::vector<std::vector<Aurora::Matrix>> solveParameterIterator(Aurora::Sparse M, const Aurora::Matrix &b,
const Aurora::Matrix &dims, bool oneIter, bool nonNeg)
{
if (Recon::transParams::name == "TVAL3"){
std::vector<std::vector<Aurora::Matrix>> result(Recon::transParams::muValues.getDataSize());
if (Recon::transParams::muValues.isNull()){
Recon::transParams::muValues = Aurora::ones(1,1);
Recon::transParams::muValues[0] = 24;
}
if (Recon::transParams::betaValues.isNull()){
Recon::transParams::betaValues = Aurora::ones(1,1);
Recon::transParams::betaValues[0] = 1;
}
if (oneIter){
auto temp = Aurora::ones(1,1);
temp[0] = Recon::transParams::muValues[0];
Recon::transParams::muValues = temp;
auto temp2 = Aurora::ones(1,1);
temp2[0] = Recon::transParams::betaValues[0];
Recon::transParams::betaValues = temp2;
}
TVAL3SolverOptions options;
options.gpuSelectionList = Recon::transParams::gpuSelectionList;
for (size_t i = 0; i < Recon::transParams::muValues.getDataSize(); i++)
{
options.TVAL3MU = Recon::transParams::muValues[i];
options.TVAL3MU0 = Recon::transParams::muValues[i];
std::vector<Aurora::Matrix> solveResult(Recon::transParams::betaValues.getDataSize());
for (size_t j = 0; j < Recon::transParams::betaValues.getDataSize(); j++)
{
options.TVAL3Beta = Recon::transParams::betaValues[i];
options.TVAL3Beta0 = Recon::transParams::betaValues[i];
options.nonNeg = nonNeg;
solveResult[j] = solve(M, b, dims, transParams::maxIter, options);
}
result[i] = solveResult;
}
return result;
}
//SART
else{
std::vector<std::vector<Aurora::Matrix>> result;
//TODO:暂时未实现
return result;
}
}
}