Add thrust iterator inherit from iterator_facade
This commit is contained in:
132
src/AuroraThrustIterator.cuh
Normal file
132
src/AuroraThrustIterator.cuh
Normal file
@@ -0,0 +1,132 @@
|
||||
#ifndef __AURORATHRUSTITERATOR_H__
|
||||
#define __AURORATHRUSTITERATOR_H__
|
||||
#include <thrust/iterator/iterator_facade.h>
|
||||
namespace Aurora {
|
||||
|
||||
/**
|
||||
* @brief 将一个向量通过循环读取来充作矩阵迭代器
|
||||
*
|
||||
* @tparam ValueType
|
||||
*/
|
||||
template <typename ValueType>
|
||||
class LoopVectorIterator:public thrust::iterator_facade<
|
||||
LoopVectorIterator<ValueType>,
|
||||
ValueType,
|
||||
thrust::device_system_tag,
|
||||
thrust::random_access_traversal_tag,
|
||||
ValueType& >{
|
||||
public:
|
||||
/**
|
||||
* @brief 循环存取向量迭代器构造函数
|
||||
* @attention 特别注意本迭代器不会有结束值!
|
||||
*
|
||||
* @param ptr 原始指针
|
||||
* @param aVectorSize 向量长度
|
||||
*/
|
||||
__host__ __device__
|
||||
LoopVectorIterator(ValueType* ptr, int aVectorSize=1)
|
||||
: ptr_(ptr),vec_size_(aVectorSize){}
|
||||
|
||||
__host__ __device__
|
||||
ValueType& dereference() const{
|
||||
return *(ptr_+offset_%vec_size_);
|
||||
}
|
||||
|
||||
// 实现递增操作符
|
||||
__host__ __device__
|
||||
void increment() {
|
||||
offset_++;
|
||||
}
|
||||
|
||||
// 实现递减操作符
|
||||
__host__ __device__
|
||||
void decrement() {
|
||||
offset_--;
|
||||
}
|
||||
|
||||
// 实现加法操作符
|
||||
__host__ __device__
|
||||
void advance(typename LoopVectorIterator::difference_type n) {
|
||||
offset_+=n;
|
||||
}
|
||||
|
||||
// 实现减法操作符
|
||||
__host__ __device__
|
||||
typename LoopVectorIterator::difference_type distance_to(const LoopVectorIterator& other) const {
|
||||
return (other.offset_ - offset_);
|
||||
}
|
||||
|
||||
// 实现比较操作符
|
||||
__host__ __device__
|
||||
bool equal(const LoopVectorIterator& other) const {
|
||||
return offset_ == other.offset_;
|
||||
}
|
||||
|
||||
private:
|
||||
ValueType* ptr_;
|
||||
int vec_size_;
|
||||
int offset_ = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief 按照行元素逐个存取的迭代去
|
||||
*
|
||||
* @tparam ValueType
|
||||
*/
|
||||
template <typename ValueType>
|
||||
class RowWiseIterator:public thrust::iterator_facade<
|
||||
RowWiseIterator<ValueType>,
|
||||
ValueType,
|
||||
thrust::device_system_tag,
|
||||
thrust::random_access_traversal_tag,
|
||||
ValueType& >{
|
||||
public:
|
||||
|
||||
__host__ __device__
|
||||
RowWiseIterator(ValueType* ptr, int aRowElementCount=1,int aColElementCount=1, int aBaseOffset=0)
|
||||
: ptr_(ptr), row_elements_(aRowElementCount),col_elements_(aColElementCount), offset_(aBaseOffset){}
|
||||
|
||||
__host__ __device__
|
||||
ValueType& dereference() const{
|
||||
return *(ptr_ + (offset_ / row_elements_) + (offset_ % row_elements_)*col_elements_);
|
||||
}
|
||||
|
||||
// 实现递增操作符
|
||||
__host__ __device__
|
||||
void increment() {
|
||||
offset_++;
|
||||
}
|
||||
|
||||
// 实现递减操作符
|
||||
__host__ __device__
|
||||
void decrement() {
|
||||
offset_--;
|
||||
}
|
||||
|
||||
// 实现加法操作符
|
||||
__host__ __device__
|
||||
void advance(typename RowWiseIterator::difference_type n) {
|
||||
offset_+=n;
|
||||
}
|
||||
|
||||
// 实现减法操作符
|
||||
__host__ __device__
|
||||
typename RowWiseIterator::difference_type distance_to(const RowWiseIterator& other) const {
|
||||
return (other.offset_ - offset_);
|
||||
}
|
||||
|
||||
// 实现比较操作符
|
||||
__host__ __device__
|
||||
bool equal(const RowWiseIterator& other) const {
|
||||
return (offset_ == other.offset_);
|
||||
}
|
||||
|
||||
private:
|
||||
ValueType* ptr_;
|
||||
int row_elements_;
|
||||
int col_elements_;
|
||||
int offset_ = 0;
|
||||
};
|
||||
|
||||
};
|
||||
#endif // __AURORATHRUSTITERATOR_H__
|
||||
Reference in New Issue
Block a user