Fix ceil, round, floor used by complex bug.

This commit is contained in:
sunwen
2023-11-17 17:22:33 +08:00
parent 654dd8e6c5
commit aabe9d1fd6

View File

@@ -200,64 +200,64 @@ Aurora::Matrix Aurora::imag(const Aurora::Matrix &matrix) {
} }
Aurora::Matrix Aurora::ceil(const Aurora::Matrix &matrix) { Aurora::Matrix Aurora::ceil(const Aurora::Matrix &matrix) {
auto output = malloc(matrix.getDataSize()); auto output = malloc(matrix.getDataSize()*matrix.getValueType());
//for real part //for real part
vsCeilI(matrix.getDataSize(), matrix.getData(), SAME_STRIDE, output, SAME_STRIDE); vsCeilI(matrix.getDataSize(), matrix.getData(), SAME_STRIDE*matrix.getValueType(), output, SAME_STRIDE*matrix.getValueType());
if (matrix.getValueType() == Complex) { if (matrix.getValueType() == Complex) {
//for imag part //for imag part
vsCeilI(matrix.getDataSize(), matrix.getData() + 1, SAME_STRIDE, output + 1, SAME_STRIDE); vsCeilI(matrix.getDataSize(), matrix.getData() + 1, SAME_STRIDE*matrix.getValueType(), output + 1, SAME_STRIDE*matrix.getValueType());
} }
return Aurora::Matrix::New(output, matrix); return Aurora::Matrix::New(output, matrix);
} }
Aurora::Matrix Aurora::ceil(const Aurora::Matrix &&matrix) { Aurora::Matrix Aurora::ceil(const Aurora::Matrix &&matrix) {
//for real part //for real part
vsCeilI(matrix.getDataSize(), matrix.getData(), SAME_STRIDE, matrix.getData(), SAME_STRIDE); vsCeilI(matrix.getDataSize(), matrix.getData(), SAME_STRIDE*matrix.getValueType(), matrix.getData(), SAME_STRIDE*matrix.getValueType());
if (matrix.getValueType() == Complex) { if (matrix.getValueType() == Complex) {
//for imag part //for imag part
vsCeilI(matrix.getDataSize(), matrix.getData() + 1, SAME_STRIDE, matrix.getData() + 1, SAME_STRIDE); vsCeilI(matrix.getDataSize(), matrix.getData() + 1, SAME_STRIDE*matrix.getValueType(), matrix.getData() + 1, SAME_STRIDE*matrix.getValueType());
} }
return matrix; return matrix;
} }
Aurora::Matrix Aurora::round(const Aurora::Matrix &matrix) { Aurora::Matrix Aurora::round(const Aurora::Matrix &matrix) {
auto output = malloc(matrix.getDataSize()); auto output = malloc(matrix.getDataSize()*matrix.getValueType());
//for real part //for real part
vsRoundI(matrix.getDataSize(), matrix.getData(), SAME_STRIDE, output, SAME_STRIDE); vsRoundI(matrix.getDataSize(), matrix.getData(), SAME_STRIDE*matrix.getValueType(), output, SAME_STRIDE*matrix.getValueType());
if (matrix.getValueType() == Complex) { if (matrix.getValueType() == Complex) {
//for imag part //for imag part
vsRoundI(matrix.getDataSize(), matrix.getData() + 1, SAME_STRIDE, output + 1, SAME_STRIDE); vsRoundI(matrix.getDataSize(), matrix.getData() + 1, SAME_STRIDE*matrix.getValueType(), output + 1, SAME_STRIDE*matrix.getValueType());
} }
return Aurora::Matrix::New(output, matrix); return Aurora::Matrix::New(output, matrix);
} }
Aurora::Matrix Aurora::round(const Aurora::Matrix &&matrix) { Aurora::Matrix Aurora::round(const Aurora::Matrix &&matrix) {
//for real part //for real part
vsRoundI(matrix.getDataSize(), matrix.getData(), SAME_STRIDE, matrix.getData(), SAME_STRIDE); vsRoundI(matrix.getDataSize(), matrix.getData(), SAME_STRIDE*matrix.getValueType(), matrix.getData(), SAME_STRIDE*matrix.getValueType());
if (matrix.getValueType() == Complex) { if (matrix.getValueType() == Complex) {
//for imag part //for imag part
vsRoundI(matrix.getDataSize(), matrix.getData() + 1, SAME_STRIDE, matrix.getData() + 1, SAME_STRIDE); vsRoundI(matrix.getDataSize(), matrix.getData() + 1, SAME_STRIDE*matrix.getValueType(), matrix.getData() + 1, SAME_STRIDE*matrix.getValueType());
} }
return matrix; return matrix;
} }
Aurora::Matrix Aurora::floor(const Aurora::Matrix &matrix) { Aurora::Matrix Aurora::floor(const Aurora::Matrix &matrix) {
auto output = malloc(matrix.getDataSize()); auto output = malloc(matrix.getDataSize()*matrix.getValueType());
//for real part //for real part
vsFloorI(matrix.getDataSize(), matrix.getData(), SAME_STRIDE, output, SAME_STRIDE); vsFloorI(matrix.getDataSize(), matrix.getData(), SAME_STRIDE*matrix.getValueType(), output, SAME_STRIDE*matrix.getValueType());
if (matrix.getValueType() == Complex) { if (matrix.getValueType() == Complex) {
//for imag part //for imag part
vsFloorI(matrix.getDataSize(), matrix.getData() + 1, SAME_STRIDE, output + 1, SAME_STRIDE); vsFloorI(matrix.getDataSize(), matrix.getData() + 1, SAME_STRIDE*matrix.getValueType(), output + 1, SAME_STRIDE*matrix.getValueType());
} }
return Aurora::Matrix::New(output, matrix); return Aurora::Matrix::New(output, matrix);
} }
Aurora::Matrix Aurora::floor(const Aurora::Matrix &&matrix) { Aurora::Matrix Aurora::floor(const Aurora::Matrix &&matrix) {
//for real part //for real part
vsFloorI(matrix.getDataSize(), matrix.getData(), SAME_STRIDE, matrix.getData(), SAME_STRIDE); vsFloorI(matrix.getDataSize(), matrix.getData(), SAME_STRIDE*matrix.getValueType(), matrix.getData(), SAME_STRIDE*matrix.getValueType());
if (matrix.getValueType() == Complex) { if (matrix.getValueType() == Complex) {
//for imag part //for imag part
vsFloorI(matrix.getDataSize(), matrix.getData() + 1, SAME_STRIDE, matrix.getData() + 1, SAME_STRIDE); vsFloorI(matrix.getDataSize(), matrix.getData() + 1, SAME_STRIDE*matrix.getValueType(), matrix.getData() + 1, SAME_STRIDE*matrix.getValueType());
} }
return matrix; return matrix;
} }