Add check and scale function (unfinished)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#include "TVAL.h"
|
||||
|
||||
#include "Function2D.h"
|
||||
#include "Matrix.h"
|
||||
#include "tval3gpu3d.h"
|
||||
#include <algorithm>
|
||||
@@ -7,8 +8,31 @@
|
||||
|
||||
namespace Recon
|
||||
{
|
||||
Aurora::Matrix callTval3(Aurora::Sparse& M, const Aurora::Matrix& b,const Aurora::Matrix& dims,int device, const TVALOptions& opt)
|
||||
void checkAndScale(Aurora::Sparse& M, Aurora::Matrix& b,size_t n){
|
||||
//TODO:暂时啥都没做
|
||||
|
||||
//无用的代码?
|
||||
//opts.scale_A = true;
|
||||
//opts.consist_mu = false;
|
||||
|
||||
//TODO:以下意义不明,暂定为判断是否为复数存储
|
||||
//if ~isreal(M * rand(n,1))
|
||||
// eopts.isreal = false;
|
||||
//end
|
||||
|
||||
bool isreal = M.getValVector().getValueType() == Aurora::Normal;
|
||||
|
||||
//TODO:以下操作意义不明
|
||||
//fh = @(x) ((M * x)' * M)';
|
||||
//s2 = eigs(fh,n,1,'lm',eopts);
|
||||
//if real(s2) > 1 + 1e-10
|
||||
// b = b ./ sqrt(s2);
|
||||
// M = M ./ sqrt( s2);
|
||||
//end
|
||||
}
|
||||
Aurora::Matrix callTval3(Aurora::Sparse& M, Aurora::Matrix& b,const Aurora::Matrix& dims,int device, TVALOptions& opt)
|
||||
{
|
||||
checkAndScale(M,b,(size_t)Aurora::prod(dims).getScalar());
|
||||
int * xIdxs = new int[M.getColVector().getDataSize()];
|
||||
std::copy(M.getColVector().getData(),M.getColVector().getData()+M.getColVector().getDataSize(),xIdxs);
|
||||
int * yIdxs = new int[M.getRowVector().getDataSize()];
|
||||
|
||||
@@ -5,7 +5,9 @@
|
||||
#include "tvalstruct.h"
|
||||
|
||||
namespace Recon {
|
||||
Aurora::Matrix callTval3(Aurora::Sparse& M, const Aurora::Matrix& b,const Aurora::Matrix& dims, int device, const struct TVALOptions& options);
|
||||
Aurora::Matrix callTval3(Aurora::Sparse &M, Aurora::Matrix &b,
|
||||
const Aurora::Matrix &dims, int device,
|
||||
struct TVALOptions &options);
|
||||
}
|
||||
|
||||
#endif // __TVAL_H__
|
||||
@@ -21,7 +21,7 @@ namespace Recon
|
||||
bool nonNeg = false;
|
||||
};
|
||||
|
||||
Aurora::Matrix solve( Aurora::Sparse& M, const Aurora::Matrix& b, const Aurora::Matrix& dims, int niter, TVAL3SolverOptions solverOptions){
|
||||
Aurora::Matrix solve( Aurora::Sparse& M, Aurora::Matrix& b, const Aurora::Matrix& dims, int niter, TVAL3SolverOptions solverOptions){
|
||||
if (Recon::transParams::name.empty()){
|
||||
Recon::transParams::name = "TVAL3";
|
||||
}
|
||||
@@ -82,6 +82,7 @@ namespace Recon
|
||||
options.TVAL3Beta0 = Recon::transParams::betaValues[i];
|
||||
options.nonNeg = nonNeg;
|
||||
solveResult[j] = solve(M, b, dims, transParams::maxIter, options);
|
||||
solveResult[j].forceReshape(dims[0], dims[1], dims.getDataSize()<3?1:dims[2]);
|
||||
}
|
||||
result[i] = solveResult;
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <vector>
|
||||
namespace Recon {
|
||||
std::vector<std::vector<Aurora::Matrix>>
|
||||
solveParameterIterator(Aurora::Sparse M, const Aurora::Matrix &b,
|
||||
solveParameterIterator(Aurora::Sparse M, Aurora::Matrix &b,
|
||||
const Aurora::Matrix &dims, bool oneIter = true, bool nonNeg = false);
|
||||
} // namespace Recon
|
||||
|
||||
|
||||
Reference in New Issue
Block a user