refactor convertfp16tofloat
This commit is contained in:
@@ -4,74 +4,99 @@
|
|||||||
#include <emmintrin.h>
|
#include <emmintrin.h>
|
||||||
#include <immintrin.h>
|
#include <immintrin.h>
|
||||||
#include <sys/types.h>
|
#include <sys/types.h>
|
||||||
|
namespace {

// Packed 16-bit sample layout, as decoded by convert():
//   bit  15      sign
//   bits 14..11  4-bit exponent
//   bits 10..0   11-bit fraction
const ushort CONVERT_AND_VALUE = 15;  // mask for the 4-bit exponent field

// Per-lane AND masks and constants used by the SIMD path.
const __m128i andBlock = _mm_set1_epi16(15);     // exponent mask (4 bits)
const __m128i andBlock2 = _mm_set1_epi16(2047);  // fraction mask (11 bits)
const __m128i zeroBlock = _mm_setzero_si128();
const __m128i oneBlock = _mm_set1_epi16(1);
const __m128i twokBlock = _mm_set1_epi16(2048);
// Equals -4096 as a 32-bit two's-complement value; added to negative samples.
const uint CONVERT_ADD_VALUE = UINT32_MAX - 4095;

/// Decodes packed 16-bit samples into doubles: des[j] = decode(ptr[j]).
///
/// Per sample:
///   mantissa = fraction
///            + 2048 when (sign set and exponent == 0) or
///                        (sign clear and exponent != 0)
///            - 4096 when sign set
///   result   = mantissa << (exponent > 1 ? exponent - 1 : 0)
///
/// @param ptr    source samples; 8 are read, or only 4 when `single` is true
/// @param des    destination; 8 doubles are written, or only 4 when `single`
/// @param single decode just ptr[0..3] and leave des[4..7] untouched
///
/// Uses AVX-512BW/VL when the translation unit is built with those
/// extensions enabled, and an equivalent scalar loop otherwise.
void convert(short * ptr, double* des,bool single = false){
#if defined(__AVX512BW__) && defined(__AVX512VL__) && defined(__AVX2__)
  // Lane j holds ptr[j]. The 8-byte load zero-fills lanes 4..7, which are
  // computed but never stored in the `single` case.
  const __m128i raw =
      single ? _mm_loadl_epi64(reinterpret_cast<const __m128i *>(ptr))
             : _mm_loadu_si128(reinterpret_cast<const __m128i *>(ptr));

  // Logical shift right by 15 leaves only the sign bit (0 or 1) per lane.
  const __m128i sign_bit = _mm_srli_epi16(raw, 15);
  const __m128i exponent = _mm_and_si128(_mm_srli_epi16(raw, 11), andBlock);
  // Fractions are < 2048, so sign-extending to 32 bits keeps them positive.
  const __m256i fraction =
      _mm256_cvtepi16_epi32(_mm_and_si128(raw, andBlock2));

  const __mmask8 negative =
      _mm_cmp_epi16_mask(sign_bit, oneBlock, _MM_CMPINT_EQ);
  const __mmask8 exp_zero =
      _mm_cmp_epi16_mask(exponent, zeroBlock, _MM_CMPINT_EQ);
  // Hidden bit applies when (negative && exp == 0) || (!negative && exp != 0).
  const __mmask8 hidden_bit_mask =
      static_cast<__mmask8>((negative & exp_zero) | (~negative & ~exp_zero));

  __m256i mantissa = _mm256_add_epi32(
      fraction, _mm256_maskz_set1_epi32(hidden_bit_mask, 2048));
  mantissa = _mm256_add_epi32(
      mantissa, _mm256_maskz_set1_epi32(
                    negative, static_cast<int>(CONVERT_ADD_VALUE)));

  // Shift count is exponent - 1 where exponent > 1, otherwise 0.
  const __mmask8 shift_mask =
      _mm_cmp_epi16_mask(oneBlock, exponent, _MM_CMPINT_LT);
  const __m256i shift = _mm256_mask_blend_epi32(
      shift_mask, _mm256_setzero_si256(),
      _mm256_cvtepi16_epi32(_mm_sub_epi16(exponent, oneBlock)));

  const __m256i decoded = _mm256_sllv_epi32(mantissa, shift);

  // int32 -> double is exact; store the low four lanes, then the high four.
  _mm256_storeu_pd(des, _mm256_cvtepi32_pd(_mm256_castsi256_si128(decoded)));
  if (single) return;
  _mm256_storeu_pd(des + 4,
                   _mm256_cvtepi32_pd(_mm256_extracti128_si256(decoded, 1)));
#else
  // Scalar fallback mirroring the SIMD path lane for lane.
  const int sample_count = single ? 4 : 8;
  for (int lane = 0; lane < sample_count; ++lane) {
    const unsigned word = static_cast<ushort>(ptr[lane]);
    const bool negative = (word >> 15) != 0;
    const unsigned exponent = (word >> 11) & CONVERT_AND_VALUE;
    int mantissa = static_cast<int>(word & 2047u);
    if (negative == (exponent == 0)) mantissa += 2048;  // hidden bit
    if (negative) mantissa -= 4096;
    const int shift = exponent > 1 ? static_cast<int>(exponent) - 1 : 0;
    // |mantissa| <= 4096 and shift <= 14, so the product fits in an int.
    des[lane] = static_cast<double>(mantissa * (1 << shift));
  }
#endif
}

}  // namespace
|
||||||
|
|
||||||
/// Expands every 16-bit sample packed in `aMatrix` into one output element.
///
/// Each input element is reinterpreted as four consecutive 16-bit samples
/// (NOTE(review): assumes getData() yields double* - confirm), so the output
/// buffer holds four times as many elements as the input. The result keeps
/// the dimension sizes reported by the input matrix.
Aurora::Matrix Recon::convertfp16tofloat(Aurora::Matrix aMatrix) {
  auto input = aMatrix.getData();
  const size_t total_count = aMatrix.getDataSize();
  // Four decoded values per input element.
  auto output = Aurora::malloc(total_count * 4);

#pragma omp parallel for
  for (size_t i = 0; i < total_count; i += 8) {
    // Each parallel iteration handles up to 8 input elements (four convert()
    // calls) to keep the per-iteration scheduling overhead low.
    for (size_t k = i; k < i + 8 && k < total_count; k += 2) {
      auto ptr = (short *)(input + k);
      double *des = output + k * 4;
      // A full call consumes two input elements (8 samples). When only one
      // element is left, decode just its 4 samples so nothing is read or
      // written past the end of the buffers.
      ::convert(ptr, des, k + 1 >= total_count);
    }
  }
  return Aurora::Matrix::New(output, aMatrix.getDimSize(0),
                             aMatrix.getDimSize(1), aMatrix.getDimSize(2));
}
|
||||||
@@ -51,7 +51,7 @@ TEST_F(Common_Test, convertfp16tofloat) {
|
|||||||
auto resultM = Recon::convertfp16tofloat(ma);
|
auto resultM = Recon::convertfp16tofloat(ma);
|
||||||
auto result = resultM.getData();
|
auto result = resultM.getData();
|
||||||
auto output = m.read("output");
|
auto output = m.read("output");
|
||||||
for (size_t i = 0; i<10; i++) {
|
for (size_t i = 0; i<count; i++) {
|
||||||
EXPECT_EQ(result[i], output.getData()[i])<<"index:"<<i<<",input:"<< ((short*)ma.getData())[i]<<",input2:"<<input.get()[i];
|
EXPECT_EQ(result[i], output.getData()[i])<<"index:"<<i<<",input:"<< ((short*)ma.getData())[i]<<",input2:"<<input.get()[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user