From 4270a11ea427b7173e53b3371e0ef247cf9f8efe Mon Sep 17 00:00:00 2001 From: kradchen Date: Wed, 31 May 2023 13:55:39 +0800 Subject: [PATCH] Add check and scale function (unfinished) --- .../solvingEquationSystem/TVAL/TVAL.cpp | 26 ++++++++++++++++++- .../solvingEquationSystem/TVAL/TVAL.h | 4 ++- .../solvingEquationSystem/solve.cpp | 3 ++- .../solvingEquationSystem/solve.h | 2 +- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/transmissionReconstruction/reconstruction/solvingEquationSystem/TVAL/TVAL.cpp b/src/transmissionReconstruction/reconstruction/solvingEquationSystem/TVAL/TVAL.cpp index b7cebc4..abee434 100644 --- a/src/transmissionReconstruction/reconstruction/solvingEquationSystem/TVAL/TVAL.cpp +++ b/src/transmissionReconstruction/reconstruction/solvingEquationSystem/TVAL/TVAL.cpp @@ -1,5 +1,6 @@ #include "TVAL.h" +#include "Function2D.h" #include "Matrix.h" #include "tval3gpu3d.h" #include @@ -7,8 +8,31 @@ namespace Recon { - Aurora::Matrix callTval3(Aurora::Sparse& M, const Aurora::Matrix& b,const Aurora::Matrix& dims,int device, const TVALOptions& opt) + void checkAndScale(Aurora::Sparse& M, Aurora::Matrix& b,size_t n){ + //TODO:暂时啥都没做 + + //无用的代码? + //opts.scale_A = true; + //opts.consist_mu = false; + + //TODO:以下意义不明,暂定为判断是否为复数存储 + //if ~isreal(M * rand(n,1)) + // eopts.isreal = false; + //end + + bool isreal = M.getValVector().getValueType() == Aurora::Normal; + + //TODO:以下操作意义不明 + //fh = @(x) ((M * x)' * M)'; + //s2 = eigs(fh,n,1,'lm',eopts); + //if real(s2) > 1 + 1e-10 + // b = b ./ sqrt(s2); + // M = M ./ sqrt( s2); + //end + } + Aurora::Matrix callTval3(Aurora::Sparse& M, Aurora::Matrix& b,const Aurora::Matrix& dims,int device, TVALOptions& opt) { + checkAndScale(M,b,(size_t)Aurora::prod(dims).getScalar()); int * xIdxs = new int[M.getColVector().getDataSize()]; std::copy(M.getColVector().getData(),M.getColVector().getData()+M.getColVector().getDataSize(),xIdxs); int * yIdxs = new int[M.getRowVector().getDataSize()]; diff --git a/src/transmissionReconstruction/reconstruction/solvingEquationSystem/TVAL/TVAL.h b/src/transmissionReconstruction/reconstruction/solvingEquationSystem/TVAL/TVAL.h index 4800006..e870fe8 100644 --- a/src/transmissionReconstruction/reconstruction/solvingEquationSystem/TVAL/TVAL.h +++ b/src/transmissionReconstruction/reconstruction/solvingEquationSystem/TVAL/TVAL.h @@ -5,7 +5,9 @@ #include "tvalstruct.h" namespace Recon { - Aurora::Matrix callTval3(Aurora::Sparse& M, const Aurora::Matrix& b,const Aurora::Matrix& dims, int device, const struct TVALOptions& options); +Aurora::Matrix callTval3(Aurora::Sparse &M, Aurora::Matrix &b, + const Aurora::Matrix &dims, int device, + struct TVALOptions &options); } #endif // __TVAL_H__ \ No newline at end of file diff --git a/src/transmissionReconstruction/reconstruction/solvingEquationSystem/solve.cpp b/src/transmissionReconstruction/reconstruction/solvingEquationSystem/solve.cpp index 117caa3..89641ab 100644 --- a/src/transmissionReconstruction/reconstruction/solvingEquationSystem/solve.cpp +++ b/src/transmissionReconstruction/reconstruction/solvingEquationSystem/solve.cpp @@ -21,7 +21,7 @@ namespace Recon bool nonNeg = false; }; - Aurora::Matrix solve( Aurora::Sparse& M, const Aurora::Matrix& b, const Aurora::Matrix& dims, int niter, TVAL3SolverOptions solverOptions){ + Aurora::Matrix solve( Aurora::Sparse& M, Aurora::Matrix& b, const Aurora::Matrix& dims, int niter, TVAL3SolverOptions solverOptions){ if (Recon::transParams::name.empty()){ Recon::transParams::name = "TVAL3"; } @@ -82,6 +82,7 @@ namespace Recon options.TVAL3Beta0 = Recon::transParams::betaValues[i]; options.nonNeg = nonNeg; solveResult[j] = solve(M, b, dims, transParams::maxIter, options); + solveResult[j].forceReshape(dims[0], dims[1], dims.getDataSize()<3?1:dims[2]); } result[i] = solveResult; } diff --git a/src/transmissionReconstruction/reconstruction/solvingEquationSystem/solve.h b/src/transmissionReconstruction/reconstruction/solvingEquationSystem/solve.h index 2b097a6..06b0df8 100644 --- a/src/transmissionReconstruction/reconstruction/solvingEquationSystem/solve.h +++ b/src/transmissionReconstruction/reconstruction/solvingEquationSystem/solve.h @@ -5,7 +5,7 @@ #include namespace Recon { std::vector> - solveParameterIterator(Aurora::Sparse M, const Aurora::Matrix &b, + solveParameterIterator(Aurora::Sparse M, Aurora::Matrix &b, const Aurora::Matrix &dims, bool oneIter = true, bool nonNeg = false); } // namespace Recon