Add check and scale function (unfinished)

This commit is contained in:
kradchen
2023-05-31 13:55:39 +08:00
parent a82b89966a
commit 4270a11ea4
4 changed files with 31 additions and 4 deletions

View File

@@ -1,5 +1,6 @@
#include "TVAL.h"
#include "Function2D.h"
#include "Matrix.h"
#include "tval3gpu3d.h"
#include <algorithm>
@@ -7,8 +8,31 @@
namespace Recon
{
Aurora::Matrix callTval3(Aurora::Sparse& M, const Aurora::Matrix& b,const Aurora::Matrix& dims,int device, const TVALOptions& opt)
void checkAndScale(Aurora::Sparse& M, Aurora::Matrix& b,size_t n){
//TODO:暂时啥都没做
//无用的代码?
//opts.scale_A = true;
//opts.consist_mu = false;
//TODO:以下意义不明,暂定为判断是否为复数存储
//if ~isreal(M * rand(n,1))
// eopts.isreal = false;
//end
bool isreal = M.getValVector().getValueType() == Aurora::Normal;
//TODO以下操作意义不明
//fh = @(x) ((M * x)' * M)';
//s2 = eigs(fh,n,1,'lm',eopts);
//if real(s2) > 1 + 1e-10
// b = b ./ sqrt(s2);
// M = M ./ sqrt( s2);
//end
}
Aurora::Matrix callTval3(Aurora::Sparse& M, Aurora::Matrix& b,const Aurora::Matrix& dims,int device, TVALOptions& opt)
{
checkAndScale(M,b,(size_t)Aurora::prod(dims).getScalar());
int * xIdxs = new int[M.getColVector().getDataSize()];
std::copy(M.getColVector().getData(),M.getColVector().getData()+M.getColVector().getDataSize(),xIdxs);
int * yIdxs = new int[M.getRowVector().getDataSize()];

View File

@@ -5,7 +5,9 @@
#include "tvalstruct.h"
namespace Recon {
Aurora::Matrix callTval3(Aurora::Sparse& M, const Aurora::Matrix& b,const Aurora::Matrix& dims, int device, const struct TVALOptions& options);
Aurora::Matrix callTval3(Aurora::Sparse &M, Aurora::Matrix &b,
const Aurora::Matrix &dims, int device,
struct TVALOptions &options);
}
#endif // __TVAL_H__

View File

@@ -21,7 +21,7 @@ namespace Recon
bool nonNeg = false;
};
Aurora::Matrix solve( Aurora::Sparse& M, const Aurora::Matrix& b, const Aurora::Matrix& dims, int niter, TVAL3SolverOptions solverOptions){
Aurora::Matrix solve( Aurora::Sparse& M, Aurora::Matrix& b, const Aurora::Matrix& dims, int niter, TVAL3SolverOptions solverOptions){
if (Recon::transParams::name.empty()){
Recon::transParams::name = "TVAL3";
}
@@ -82,6 +82,7 @@ namespace Recon
options.TVAL3Beta0 = Recon::transParams::betaValues[i];
options.nonNeg = nonNeg;
solveResult[j] = solve(M, b, dims, transParams::maxIter, options);
solveResult[j].forceReshape(dims[0], dims[1], dims.getDataSize()<3?1:dims[2]);
}
result[i] = solveResult;
}

View File

@@ -5,7 +5,7 @@
#include <vector>
namespace Recon {
std::vector<std::vector<Aurora::Matrix>>
solveParameterIterator(Aurora::Sparse M, const Aurora::Matrix &b,
solveParameterIterator(Aurora::Sparse M, Aurora::Matrix &b,
const Aurora::Matrix &dims, bool oneIter = true, bool nonNeg = false);
} // namespace Recon