#ifndef __AURORATHRUSTITERATOR_H__
#define __AURORATHRUSTITERATOR_H__

#include <thrust/iterator/iterator_categories.h>
#include <thrust/iterator/iterator_facade.h>

namespace Aurora {

/**
 * @brief Random-access iterator that reads a vector cyclically, so that a
 *        single vector can be used where a matrix iterator is expected.
 *
 * Logical position k dereferences to ptr[k mod vectorSize]. The remainder is
 * always normalised into [0, vectorSize), so positions before the start
 * (k < 0, e.g. after decrementing from 0) also wrap around instead of reading
 * before the buffer.
 *
 * @attention This iterator has no end value: it never reaches a terminal
 *            position. Always bound it with an explicit element count or with
 *            a second iterator obtained by advancing the first one.
 * @note The reference type is a raw ValueType& while the system tag is
 *       thrust::device_system_tag; dereferencing on the host is therefore only
 *       valid if ptr points to host-accessible memory.
 *
 * @tparam ValueType Element type of the underlying buffer.
 */
template <typename ValueType>
class LoopVectorIterator
    : public thrust::iterator_facade<LoopVectorIterator<ValueType>,
                                     ValueType,
                                     thrust::device_system_tag,
                                     thrust::random_access_traversal_tag,
                                     ValueType&> {
    using super_t = thrust::iterator_facade<LoopVectorIterator<ValueType>,
                                            ValueType,
                                            thrust::device_system_tag,
                                            thrust::random_access_traversal_tag,
                                            ValueType&>;

public:
    using difference_type = typename super_t::difference_type;

    /**
     * @brief Constructs a cyclic (looping) vector iterator.
     * @attention The iterator has no end value; see the class note.
     *
     * @param ptr         Raw pointer to the first element of the vector.
     * @param aVectorSize Vector length. Must be > 0: it is used as the divisor
     *                    when wrapping positions.
     * @param aBaseOffset Logical starting position (defaults to 0).
     */
    __host__ __device__
    LoopVectorIterator(ValueType* ptr, int aVectorSize = 1, int aBaseOffset = 0)
        : ptr_(ptr), vec_size_(aVectorSize), offset_(aBaseOffset) {}

    /// Returns a reference to ptr[offset mod vectorSize].
    __host__ __device__ ValueType& dereference() const {
        difference_type index = offset_ % vec_size_;
        // '%' keeps the sign of the dividend; fold negative positions back
        // into [0, vec_size_) so they wrap instead of indexing before ptr_.
        if (index < 0) {
            index += vec_size_;
        }
        return *(ptr_ + index);
    }

    /// Moves to the next logical position.
    __host__ __device__ void increment() { ++offset_; }

    /// Moves to the previous logical position (may become negative; it wraps).
    __host__ __device__ void decrement() { --offset_; }

    /// Moves by n logical positions (n may be negative).
    __host__ __device__ void advance(difference_type n) { offset_ += n; }

    /// Number of logical positions from *this to other.
    __host__ __device__ difference_type distance_to(const LoopVectorIterator& other) const {
        return other.offset_ - offset_;
    }

    /// Compares logical positions only; both iterators are assumed to refer
    /// to the same vector.
    __host__ __device__ bool equal(const LoopVectorIterator& other) const {
        return offset_ == other.offset_;
    }

private:
    ValueType* ptr_;
    // Stored as difference_type (not int) so that advance(n) with a large n
    // is neither truncated nor overflowed.
    difference_type vec_size_;
    difference_type offset_ = 0;
};

/**
 * @brief Random-access iterator that visits the elements of a matrix row by
 *        row.
 *
 * Logical position k is mapped to (row = k / aRowElementCount,
 * col = k % aRowElementCount), which is read from
 * ptr[row + col * aColElementCount]. In other words, the buffer is addressed
 * column-major with leading dimension aColElementCount, so iterating
 * k = 0, 1, 2, ... yields the matrix in row-major order.
 *
 * @note Like LoopVectorIterator, the reference type is a raw ValueType& under
 *       thrust::device_system_tag; host dereferencing requires host-accessible
 *       memory.
 *
 * @tparam ValueType Element type of the underlying buffer.
 */
template <typename ValueType>
class RowWiseIterator
    : public thrust::iterator_facade<RowWiseIterator<ValueType>,
                                     ValueType,
                                     thrust::device_system_tag,
                                     thrust::random_access_traversal_tag,
                                     ValueType&> {
    using super_t = thrust::iterator_facade<RowWiseIterator<ValueType>,
                                            ValueType,
                                            thrust::device_system_tag,
                                            thrust::random_access_traversal_tag,
                                            ValueType&>;

public:
    using difference_type = typename super_t::difference_type;

    /**
     * @brief Constructs a row-wise matrix iterator.
     *
     * @param ptr              Raw pointer to the first matrix element.
     * @param aRowElementCount Number of elements in one row (the column count).
     *                         Must be > 0: it is used as a divisor.
     * @param aColElementCount Number of elements in one column (the stride
     *                         between consecutive columns in the buffer).
     * @param aBaseOffset      Logical starting position (defaults to 0).
     */
    __host__ __device__
    RowWiseIterator(ValueType* ptr, int aRowElementCount = 1, int aColElementCount = 1,
                    int aBaseOffset = 0)
        : ptr_(ptr),
          row_elements_(aRowElementCount),
          col_elements_(aColElementCount),
          offset_(aBaseOffset) {}

    /// Returns a reference to the element at the current row-major position.
    __host__ __device__ ValueType& dereference() const {
        const difference_type row = offset_ / row_elements_;
        const difference_type col = offset_ % row_elements_;
        // Computed in difference_type so col * col_elements_ cannot overflow
        // int for large matrices.
        return *(ptr_ + row + col * col_elements_);
    }

    /// Moves to the next logical position.
    __host__ __device__ void increment() { ++offset_; }

    /// Moves to the previous logical position.
    __host__ __device__ void decrement() { --offset_; }

    /// Moves by n logical positions (n may be negative).
    __host__ __device__ void advance(difference_type n) { offset_ += n; }

    /// Number of logical positions from *this to other.
    __host__ __device__ difference_type distance_to(const RowWiseIterator& other) const {
        return other.offset_ - offset_;
    }

    /// Compares logical positions only; both iterators are assumed to refer
    /// to the same matrix.
    __host__ __device__ bool equal(const RowWiseIterator& other) const {
        return offset_ == other.offset_;
    }

private:
    ValueType* ptr_;
    difference_type row_elements_;
    difference_type col_elements_;
    difference_type offset_ = 0;
};

}  // namespace Aurora

#endif  // __AURORATHRUSTITERATOR_H__