Add check and scale function (unfinished)

This commit is contained in:
kradchen
2023-05-31 13:55:39 +08:00
parent a82b89966a
commit 4270a11ea4
4 changed files with 31 additions and 4 deletions

View File

@@ -1,5 +1,6 @@
#include "TVAL.h" #include "TVAL.h"
#include "Function2D.h"
#include "Matrix.h" #include "Matrix.h"
#include "tval3gpu3d.h" #include "tval3gpu3d.h"
#include <algorithm> #include <algorithm>
@@ -7,8 +8,31 @@
namespace Recon namespace Recon
{ {
Aurora::Matrix callTval3(Aurora::Sparse& M, const Aurora::Matrix& b,const Aurora::Matrix& dims,int device, const TVALOptions& opt) void checkAndScale(Aurora::Sparse& M, Aurora::Matrix& b,size_t n){
//TODO:暂时啥都没做
//无用的代码?
//opts.scale_A = true;
//opts.consist_mu = false;
//TODO:以下意义不明,暂定为判断是否为复数存储
//if ~isreal(M * rand(n,1))
// eopts.isreal = false;
//end
bool isreal = M.getValVector().getValueType() == Aurora::Normal;
//TODO以下操作意义不明
//fh = @(x) ((M * x)' * M)';
//s2 = eigs(fh,n,1,'lm',eopts);
//if real(s2) > 1 + 1e-10
// b = b ./ sqrt(s2);
// M = M ./ sqrt( s2);
//end
}
Aurora::Matrix callTval3(Aurora::Sparse& M, Aurora::Matrix& b,const Aurora::Matrix& dims,int device, TVALOptions& opt)
{ {
checkAndScale(M,b,(size_t)Aurora::prod(dims).getScalar());
int * xIdxs = new int[M.getColVector().getDataSize()]; int * xIdxs = new int[M.getColVector().getDataSize()];
std::copy(M.getColVector().getData(),M.getColVector().getData()+M.getColVector().getDataSize(),xIdxs); std::copy(M.getColVector().getData(),M.getColVector().getData()+M.getColVector().getDataSize(),xIdxs);
int * yIdxs = new int[M.getRowVector().getDataSize()]; int * yIdxs = new int[M.getRowVector().getDataSize()];

View File

@@ -5,7 +5,9 @@
#include "tvalstruct.h" #include "tvalstruct.h"
namespace Recon { namespace Recon {
Aurora::Matrix callTval3(Aurora::Sparse& M, const Aurora::Matrix& b,const Aurora::Matrix& dims, int device, const struct TVALOptions& options); Aurora::Matrix callTval3(Aurora::Sparse &M, Aurora::Matrix &b,
const Aurora::Matrix &dims, int device,
struct TVALOptions &options);
} }
#endif // __TVAL_H__ #endif // __TVAL_H__

View File

@@ -21,7 +21,7 @@ namespace Recon
bool nonNeg = false; bool nonNeg = false;
}; };
Aurora::Matrix solve( Aurora::Sparse& M, const Aurora::Matrix& b, const Aurora::Matrix& dims, int niter, TVAL3SolverOptions solverOptions){ Aurora::Matrix solve( Aurora::Sparse& M, Aurora::Matrix& b, const Aurora::Matrix& dims, int niter, TVAL3SolverOptions solverOptions){
if (Recon::transParams::name.empty()){ if (Recon::transParams::name.empty()){
Recon::transParams::name = "TVAL3"; Recon::transParams::name = "TVAL3";
} }
@@ -82,6 +82,7 @@ namespace Recon
options.TVAL3Beta0 = Recon::transParams::betaValues[i]; options.TVAL3Beta0 = Recon::transParams::betaValues[i];
options.nonNeg = nonNeg; options.nonNeg = nonNeg;
solveResult[j] = solve(M, b, dims, transParams::maxIter, options); solveResult[j] = solve(M, b, dims, transParams::maxIter, options);
solveResult[j].forceReshape(dims[0], dims[1], dims.getDataSize()<3?1:dims[2]);
} }
result[i] = solveResult; result[i] = solveResult;
} }

View File

@@ -5,7 +5,7 @@
#include <vector> #include <vector>
namespace Recon { namespace Recon {
std::vector<std::vector<Aurora::Matrix>> std::vector<std::vector<Aurora::Matrix>>
solveParameterIterator(Aurora::Sparse M, const Aurora::Matrix &b, solveParameterIterator(Aurora::Sparse M, Aurora::Matrix &b,
const Aurora::Matrix &dims, bool oneIter = true, bool nonNeg = false); const Aurora::Matrix &dims, bool oneIter = true, bool nonNeg = false);
} // namespace Recon } // namespace Recon