Files
Aurora/src/AuroraThrustIterator.cuh
2023-12-13 16:40:29 +08:00

132 lines
3.7 KiB
Plaintext

#ifndef __AURORATHRUSTITERATOR_H__
#define __AURORATHRUSTITERATOR_H__
#include <thrust/iterator/iterator_facade.h>
namespace Aurora {
/**
 * @brief Iterator that reads a vector cyclically so that it can serve as a
 *        matrix iterator.
 *
 * Dereferencing at logical offset k yields ptr[k mod aVectorSize] (wrapped
 * into [0, aVectorSize)). Advancing past the last element therefore wraps
 * around to the first one. Moving before the starting position wraps around
 * to the last one.
 *
 * @attention This iterator has no natural end value! Build ranges by count,
 *            e.g. [it, it + n). Two iterators compare equal when their
 *            logical offsets are equal; the pointer and vector size are NOT
 *            compared.
 * @note NOTE(review): the facade is tagged thrust::device_system_tag and
 *       dereference() reads the raw pointer directly. Presumably ptr must
 *       point to device memory when used with thrust algorithms; confirm
 *       against callers.
 * @note The offset is kept in thrust's difference_type (wider than int) so
 *       large advances do not truncate. The '%' in dereference() is therefore
 *       a wide-integer modulo, which is slower on the GPU than a 32-bit one.
 *
 * @tparam ValueType element type of the wrapped vector
 */
template <typename ValueType>
class LoopVectorIterator : public thrust::iterator_facade<
                               LoopVectorIterator<ValueType>,
                               ValueType,
                               thrust::device_system_tag,
                               thrust::random_access_traversal_tag,
                               ValueType&> {
public:
    /**
     * @brief Constructs a cyclic iterator positioned at logical offset 0.
     * @attention This iterator has no end value!
     *
     * @param ptr         raw pointer to the first element of the vector
     * @param aVectorSize number of elements in the vector; must be > 0
     *                    because it is used as a modulus
     */
    __host__ __device__
    LoopVectorIterator(ValueType* ptr, int aVectorSize = 1)
        : ptr_(ptr), vec_size_(aVectorSize), offset_(0) {}

    // Element at the current logical offset, wrapped into [0, vec_size_).
    __host__ __device__
    ValueType& dereference() const {
        // C++ '%' keeps the sign of the dividend: a negative offset (after
        // decrementing/advancing before the start) would otherwise index
        // before ptr_. Shift it back into range so the iterator truly loops.
        typename LoopVectorIterator::difference_type idx = offset_ % vec_size_;
        if (idx < 0) {
            idx += vec_size_;
        }
        return ptr_[idx];
    }

    // Move one position forward (operator++).
    __host__ __device__
    void increment() {
        ++offset_;
    }

    // Move one position backward (operator--).
    __host__ __device__
    void decrement() {
        --offset_;
    }

    // Move n positions (operator+= / operator-=); n may be negative.
    __host__ __device__
    void advance(typename LoopVectorIterator::difference_type n) {
        offset_ += n;
    }

    // Signed number of steps from *this to other (operator-).
    __host__ __device__
    typename LoopVectorIterator::difference_type distance_to(const LoopVectorIterator& other) const {
        return other.offset_ - offset_;
    }

    // Equality (operator== / operator!=): compares logical offsets only.
    __host__ __device__
    bool equal(const LoopVectorIterator& other) const {
        return offset_ == other.offset_;
    }

private:
    ValueType* ptr_;
    int vec_size_;
    // Logical, unwrapped position. It is stored in thrust's difference_type
    // (the same type advance()/distance_to() use) so it is never truncated
    // to int.
    typename LoopVectorIterator::difference_type offset_;
};
/**
 * @brief Iterator that visits the elements of a matrix one by one, row by row.
 *
 * The logical position k maps to row k / aRowElementCount and column
 * k % aRowElementCount. The element is read at
 * ptr[row + column * aColElementCount]. In other words, the matrix is assumed
 * to be stored column by column (each column holds aColElementCount
 * contiguous elements), and this iterator walks it in row-major order.
 *
 * @note NOTE(review): the facade is tagged thrust::device_system_tag and
 *       dereference() reads the raw pointer directly. Presumably ptr must
 *       point to device memory when used with thrust algorithms; confirm
 *       against callers.
 * @note Index math is done in thrust's difference_type (wider than int) so
 *       matrices with more than INT_MAX elements do not overflow. The '/' and
 *       '%' are therefore wide-integer operations, which are slower on the
 *       GPU than 32-bit ones.
 *
 * @tparam ValueType element type of the matrix
 */
template <typename ValueType>
class RowWiseIterator : public thrust::iterator_facade<
                            RowWiseIterator<ValueType>,
                            ValueType,
                            thrust::device_system_tag,
                            thrust::random_access_traversal_tag,
                            ValueType&> {
public:
    /**
     * @brief Constructs a row-wise iterator at logical position aBaseOffset.
     *
     * @param ptr              raw pointer to the first matrix element
     * @param aRowElementCount number of elements in one row (the column count);
     *                         must be > 0 because it is used as a divisor
     * @param aColElementCount number of elements in one column (the row count),
     *                         used as the stride between consecutive columns
     * @param aBaseOffset      initial logical (row-major) position; e.g. the
     *                         total element count yields a past-the-end iterator
     */
    __host__ __device__
    RowWiseIterator(ValueType* ptr, int aRowElementCount = 1, int aColElementCount = 1,
                    typename RowWiseIterator::difference_type aBaseOffset = 0)
        : ptr_(ptr),
          row_elements_(aRowElementCount),
          col_elements_(aColElementCount),
          offset_(aBaseOffset) {}

    // Element at the current row-major position.
    __host__ __device__
    ValueType& dereference() const {
        // Compute in difference_type: with plain int, column * col_elements_
        // overflows once the matrix has more than INT_MAX elements.
        const typename RowWiseIterator::difference_type row = offset_ / row_elements_;
        const typename RowWiseIterator::difference_type column = offset_ % row_elements_;
        return ptr_[row + column * col_elements_];
    }

    // Move one position forward (operator++).
    __host__ __device__
    void increment() {
        ++offset_;
    }

    // Move one position backward (operator--).
    __host__ __device__
    void decrement() {
        --offset_;
    }

    // Move n positions (operator+= / operator-=); n may be negative.
    __host__ __device__
    void advance(typename RowWiseIterator::difference_type n) {
        offset_ += n;
    }

    // Signed number of steps from *this to other (operator-).
    __host__ __device__
    typename RowWiseIterator::difference_type distance_to(const RowWiseIterator& other) const {
        return other.offset_ - offset_;
    }

    // Equality (operator== / operator!=): compares logical positions only.
    __host__ __device__
    bool equal(const RowWiseIterator& other) const {
        return offset_ == other.offset_;
    }

private:
    ValueType* ptr_;
    int row_elements_;  // elements per row (column count)
    int col_elements_;  // elements per column (row count) = column stride
    // Logical row-major position. It is stored in thrust's difference_type
    // (the same type advance()/distance_to() use) so it is never truncated
    // to int.
    typename RowWiseIterator::difference_type offset_;
};
};
#endif // __AURORATHRUSTITERATOR_H__