Improve sensitivityCalc performance.

This commit is contained in:
kradchen
2023-05-19 09:52:27 +08:00
parent 7088602b59
commit 40c882e3a9

View File

@@ -131,16 +131,18 @@ Aurora::Matrix precalcSensitivityForTAS(const Aurora::Matrix &aMSensChar,
const Aurora::Matrix &aMStartPositions, const Aurora::Matrix &aMStartPositions,
const Aurora::Matrix &aMStartNormals, const Aurora::Matrix &aMStartNormals,
const Aurora::Matrix &aMEndPositions) { const Aurora::Matrix &aMEndPositions) {
auto sens = Aurora::zeros(aMEndPositions.getDimSize(1), auto sens = Aurora::zeros(aMEndPositions.getDimSize(1),
aMStartPositions.getDimSize(1)); aMStartPositions.getDimSize(1));
for (size_t i = 0; i < aMStartPositions.getDimSize(1); ++i) { //有修改多线程化可提升10倍速
auto dirVector = aMEndPositions - #pragma omp parallel for
Aurora::repmat(aMStartPositions(Aurora::$, i).toMatrix(), for (size_t i = 0; i < aMStartPositions.getDimSize(1); ++i) {
1, aMEndPositions.getDimSize(1)); auto dirVector = aMEndPositions -
sens(Aurora::$, i) = getSensitivity( Aurora::repmat(aMStartPositions(Aurora::$, i).toMatrix(),
aMSensChar, aMStartNormals(Aurora::$, i).toMatrix(), dirVector); 1, aMEndPositions.getDimSize(1));
} sens(Aurora::$, i) = getSensitivity(
return sens; aMSensChar, aMStartNormals(Aurora::$, i).toMatrix(), dirVector);
}
return sens;
} }
std::vector<Aurora::Matrix> std::vector<Aurora::Matrix>
@@ -150,35 +152,35 @@ combineSensitivity(const Aurora::Matrix &aVSenderTASRange,
const Aurora::Matrix &aVReceiverElementRange, const Aurora::Matrix &aVReceiverElementRange,
const Aurora::Matrix &aMSenderSens, const Aurora::Matrix &aMSenderSens,
const Aurora::Matrix &aMReceiverSens) { const Aurora::Matrix &aMReceiverSens) {
double maxReceiverElementRange = double maxReceiverElementRange =
Aurora::max(aVReceiverElementRange).getScalar(); Aurora::max(aVReceiverElementRange).getScalar();
double maxReceiverTASRange = Aurora::max(aVReceiverTASRange).getScalar(); double maxReceiverTASRange = Aurora::max(aVReceiverTASRange).getScalar();
double maxSenderElementRange = Aurora::max(aVSenderElementRange).getScalar(); double maxSenderElementRange = Aurora::max(aVSenderElementRange).getScalar();
double maxSenderTASRange = Aurora::max(aVSenderTASRange).getScalar(); double maxSenderTASRange = Aurora::max(aVSenderTASRange).getScalar();
std::vector<Aurora::Matrix> ret; std::vector<Aurora::Matrix> ret;
for (size_t i = 0; i < maxSenderTASRange; ++i) { for (size_t i = 0; i < maxSenderTASRange; ++i) {
ret.emplace_back(Aurora::zeros(maxReceiverElementRange, maxReceiverTASRange, ret.emplace_back(Aurora::zeros(maxReceiverElementRange, maxReceiverTASRange,
maxSenderElementRange)); maxSenderElementRange));
} }
size_t countSE = 0; //有修改多线程化可提升10倍速
for (size_t i = 0; i < aVSenderTASRange.getDataSize(); i++) { #pragma omp parallel for
auto se = aVReceiverTASRange.getData()[i]; for (size_t i = 0; i < aVSenderTASRange.getDataSize(); i++) {
for (size_t j = 0; j < aVSenderElementRange.getDataSize(); j++) { auto se = aVReceiverTASRange.getData()[i];
auto sn = aVSenderElementRange.getData()[j]; for (size_t j = 0; j < aVSenderElementRange.getDataSize(); j++) {
size_t countRE = 0; auto sn = aVSenderElementRange.getData()[j];
for (size_t k = 0; k < aVReceiverTASRange.getDataSize(); k++) { size_t countSE = i * aVSenderElementRange.getDataSize() + j;
auto re = aVReceiverTASRange.getData()[k]; for (size_t k = 0; k < aVReceiverTASRange.getDataSize(); k++) {
for (size_t n = 0; n < aVReceiverElementRange.getDataSize(); n++) { auto re = aVReceiverTASRange.getData()[k];
auto rn = aVReceiverElementRange.getData()[n]; for (size_t n = 0; n < aVReceiverElementRange.getDataSize(); n++) {
ret[se - 1](rn - 1, re - 1, sn - 1) = size_t countRE = k*aVReceiverElementRange.getDataSize()+n;
aMSenderSens(countRE, countSE).toMatrix().getScalar() * auto rn = aVReceiverElementRange.getData()[n];
aMReceiverSens(countSE, countRE).toMatrix().getScalar(); ret[se - 1](rn - 1, re - 1, sn - 1) =
countRE++; aMSenderSens(countRE, countSE).toMatrix().getScalar() *
} aMReceiverSens(countSE, countRE).toMatrix().getScalar();
} }
countSE++; }
}
} }
}
return ret; return ret;
} }