Add thrust iterator inherit from iterator_facade
This commit is contained in:
132
src/AuroraThrustIterator.cuh
Normal file
132
src/AuroraThrustIterator.cuh
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
#ifndef __AURORATHRUSTITERATOR_H__
|
||||||
|
#define __AURORATHRUSTITERATOR_H__
|
||||||
|
#include <thrust/iterator/iterator_facade.h>
|
||||||
|
namespace Aurora {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief 将一个向量通过循环读取来充作矩阵迭代器
|
||||||
|
*
|
||||||
|
* @tparam ValueType
|
||||||
|
*/
|
||||||
|
template <typename ValueType>
|
||||||
|
class LoopVectorIterator:public thrust::iterator_facade<
|
||||||
|
LoopVectorIterator<ValueType>,
|
||||||
|
ValueType,
|
||||||
|
thrust::device_system_tag,
|
||||||
|
thrust::random_access_traversal_tag,
|
||||||
|
ValueType& >{
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief 循环存取向量迭代器构造函数
|
||||||
|
* @attention 特别注意本迭代器不会有结束值!
|
||||||
|
*
|
||||||
|
* @param ptr 原始指针
|
||||||
|
* @param aVectorSize 向量长度
|
||||||
|
*/
|
||||||
|
__host__ __device__
|
||||||
|
LoopVectorIterator(ValueType* ptr, int aVectorSize=1)
|
||||||
|
: ptr_(ptr),vec_size_(aVectorSize){}
|
||||||
|
|
||||||
|
__host__ __device__
|
||||||
|
ValueType& dereference() const{
|
||||||
|
return *(ptr_+offset_%vec_size_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现递增操作符
|
||||||
|
__host__ __device__
|
||||||
|
void increment() {
|
||||||
|
offset_++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现递减操作符
|
||||||
|
__host__ __device__
|
||||||
|
void decrement() {
|
||||||
|
offset_--;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现加法操作符
|
||||||
|
__host__ __device__
|
||||||
|
void advance(typename LoopVectorIterator::difference_type n) {
|
||||||
|
offset_+=n;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现减法操作符
|
||||||
|
__host__ __device__
|
||||||
|
typename LoopVectorIterator::difference_type distance_to(const LoopVectorIterator& other) const {
|
||||||
|
return (other.offset_ - offset_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现比较操作符
|
||||||
|
__host__ __device__
|
||||||
|
bool equal(const LoopVectorIterator& other) const {
|
||||||
|
return offset_ == other.offset_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
ValueType* ptr_;
|
||||||
|
int vec_size_;
|
||||||
|
int offset_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief 按照行元素逐个存取的迭代去
|
||||||
|
*
|
||||||
|
* @tparam ValueType
|
||||||
|
*/
|
||||||
|
template <typename ValueType>
|
||||||
|
class RowWiseIterator:public thrust::iterator_facade<
|
||||||
|
RowWiseIterator<ValueType>,
|
||||||
|
ValueType,
|
||||||
|
thrust::device_system_tag,
|
||||||
|
thrust::random_access_traversal_tag,
|
||||||
|
ValueType& >{
|
||||||
|
public:
|
||||||
|
|
||||||
|
__host__ __device__
|
||||||
|
RowWiseIterator(ValueType* ptr, int aRowElementCount=1,int aColElementCount=1, int aBaseOffset=0)
|
||||||
|
: ptr_(ptr), row_elements_(aRowElementCount),col_elements_(aColElementCount), offset_(aBaseOffset){}
|
||||||
|
|
||||||
|
__host__ __device__
|
||||||
|
ValueType& dereference() const{
|
||||||
|
return *(ptr_ + (offset_ / row_elements_) + (offset_ % row_elements_)*col_elements_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现递增操作符
|
||||||
|
__host__ __device__
|
||||||
|
void increment() {
|
||||||
|
offset_++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现递减操作符
|
||||||
|
__host__ __device__
|
||||||
|
void decrement() {
|
||||||
|
offset_--;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现加法操作符
|
||||||
|
__host__ __device__
|
||||||
|
void advance(typename RowWiseIterator::difference_type n) {
|
||||||
|
offset_+=n;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现减法操作符
|
||||||
|
__host__ __device__
|
||||||
|
typename RowWiseIterator::difference_type distance_to(const RowWiseIterator& other) const {
|
||||||
|
return (other.offset_ - offset_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现比较操作符
|
||||||
|
__host__ __device__
|
||||||
|
bool equal(const RowWiseIterator& other) const {
|
||||||
|
return (offset_ == other.offset_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
ValueType* ptr_;
|
||||||
|
int row_elements_;
|
||||||
|
int col_elements_;
|
||||||
|
int offset_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
};
|
||||||
|
#endif // __AURORATHRUSTITERATOR_H__
|
||||||
Reference in New Issue
Block a user