From f2b81eddc3cd0c0b414c6caa89790b23ee22337f Mon Sep 17 00:00:00 2001 From: kradchen Date: Wed, 13 Dec 2023 16:39:09 +0800 Subject: [PATCH] Add thrust iterator inherit from iterator_facade --- src/AuroraThrustIterator.cuh | 132 +++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 src/AuroraThrustIterator.cuh diff --git a/src/AuroraThrustIterator.cuh b/src/AuroraThrustIterator.cuh new file mode 100644 index 0000000..cf163dc --- /dev/null +++ b/src/AuroraThrustIterator.cuh @@ -0,0 +1,132 @@ +#ifndef __AURORATHRUSTITERATOR_H__ +#define __AURORATHRUSTITERATOR_H__ +#include +namespace Aurora { + + /** + * @brief 将一个向量通过循环读取来充作矩阵迭代器 + * + * @tparam ValueType + */ + template + class LoopVectorIterator:public thrust::iterator_facade< + LoopVectorIterator, + ValueType, + thrust::device_system_tag, + thrust::random_access_traversal_tag, + ValueType& >{ + public: + /** + * @brief 循环存取向量迭代器构造函数 + * @attention 特别注意本迭代器不会有结束值! + * + * @param ptr 原始指针 + * @param aVectorSize 向量长度 + */ + __host__ __device__ + LoopVectorIterator(ValueType* ptr, int aVectorSize=1) + : ptr_(ptr),vec_size_(aVectorSize){} + + __host__ __device__ + ValueType& dereference() const{ + return *(ptr_+offset_%vec_size_); + } + + // 实现递增操作符 + __host__ __device__ + void increment() { + offset_++; + } + + // 实现递减操作符 + __host__ __device__ + void decrement() { + offset_--; + } + + // 实现加法操作符 + __host__ __device__ + void advance(typename LoopVectorIterator::difference_type n) { + offset_+=n; + } + + // 实现减法操作符 + __host__ __device__ + typename LoopVectorIterator::difference_type distance_to(const LoopVectorIterator& other) const { + return (other.offset_ - offset_); + } + + // 实现比较操作符 + __host__ __device__ + bool equal(const LoopVectorIterator& other) const { + return offset_ == other.offset_; + } + + private: + ValueType* ptr_; + int vec_size_; + int offset_ = 0; + }; + + /** + * @brief 按照行元素逐个存取的迭代去 + * + * @tparam ValueType + */ + template + class RowWiseIterator:public thrust::iterator_facade< + RowWiseIterator, + ValueType, + thrust::device_system_tag, + thrust::random_access_traversal_tag, + ValueType& >{ + public: + + __host__ __device__ + RowWiseIterator(ValueType* ptr, int aRowElementCount=1,int aColElementCount=1, int aBaseOffset=0) + : ptr_(ptr), row_elements_(aRowElementCount),col_elements_(aColElementCount), offset_(aBaseOffset){} + + __host__ __device__ + ValueType& dereference() const{ + return *(ptr_ + (offset_ / row_elements_) + (offset_ % row_elements_)*col_elements_); + } + + // 实现递增操作符 + __host__ __device__ + void increment() { + offset_++; + } + + // 实现递减操作符 + __host__ __device__ + void decrement() { + offset_--; + } + + // 实现加法操作符 + __host__ __device__ + void advance(typename RowWiseIterator::difference_type n) { + offset_+=n; + } + + // 实现减法操作符 + __host__ __device__ + typename RowWiseIterator::difference_type distance_to(const RowWiseIterator& other) const { + return (other.offset_ - offset_); + } + + // 实现比较操作符 + __host__ __device__ + bool equal(const RowWiseIterator& other) const { + return (offset_ == other.offset_); + } + + private: + ValueType* ptr_; + int row_elements_; + int col_elements_; + int offset_ = 0; + }; + +}; +#endif // __AURORATHRUSTITERATOR_H__ \ No newline at end of file