132 lines
3.7 KiB
Plaintext
132 lines
3.7 KiB
Plaintext
#ifndef __AURORATHRUSTITERATOR_H__
|
|
#define __AURORATHRUSTITERATOR_H__
|
|
#include <thrust/iterator/iterator_facade.h>
|
|
namespace Aurora {
|
|
|
|
/**
|
|
* @brief 将一个向量通过循环读取来充作矩阵迭代器
|
|
*
|
|
* @tparam ValueType
|
|
*/
|
|
template <typename ValueType>
|
|
class LoopVectorIterator:public thrust::iterator_facade<
|
|
LoopVectorIterator<ValueType>,
|
|
ValueType,
|
|
thrust::device_system_tag,
|
|
thrust::random_access_traversal_tag,
|
|
ValueType& >{
|
|
public:
|
|
/**
|
|
* @brief 循环存取向量迭代器构造函数
|
|
* @attention 特别注意本迭代器不会有结束值!
|
|
*
|
|
* @param ptr 原始指针
|
|
* @param aVectorSize 向量长度
|
|
*/
|
|
__host__ __device__
|
|
LoopVectorIterator(ValueType* ptr, int aVectorSize=1)
|
|
: ptr_(ptr),vec_size_(aVectorSize){}
|
|
|
|
__host__ __device__
|
|
ValueType& dereference() const{
|
|
return *(ptr_+offset_%vec_size_);
|
|
}
|
|
|
|
// 实现递增操作符
|
|
__host__ __device__
|
|
void increment() {
|
|
offset_++;
|
|
}
|
|
|
|
// 实现递减操作符
|
|
__host__ __device__
|
|
void decrement() {
|
|
offset_--;
|
|
}
|
|
|
|
// 实现加法操作符
|
|
__host__ __device__
|
|
void advance(typename LoopVectorIterator::difference_type n) {
|
|
offset_+=n;
|
|
}
|
|
|
|
// 实现减法操作符
|
|
__host__ __device__
|
|
typename LoopVectorIterator::difference_type distance_to(const LoopVectorIterator& other) const {
|
|
return (other.offset_ - offset_);
|
|
}
|
|
|
|
// 实现比较操作符
|
|
__host__ __device__
|
|
bool equal(const LoopVectorIterator& other) const {
|
|
return offset_ == other.offset_;
|
|
}
|
|
|
|
private:
|
|
ValueType* ptr_;
|
|
int vec_size_;
|
|
int offset_ = 0;
|
|
};
|
|
|
|
/**
|
|
* @brief 按照行元素逐个存取的迭代去
|
|
*
|
|
* @tparam ValueType
|
|
*/
|
|
template <typename ValueType>
|
|
class RowWiseIterator:public thrust::iterator_facade<
|
|
RowWiseIterator<ValueType>,
|
|
ValueType,
|
|
thrust::device_system_tag,
|
|
thrust::random_access_traversal_tag,
|
|
ValueType& >{
|
|
public:
|
|
|
|
__host__ __device__
|
|
RowWiseIterator(ValueType* ptr, int aRowElementCount=1,int aColElementCount=1, int aBaseOffset=0)
|
|
: ptr_(ptr), row_elements_(aRowElementCount),col_elements_(aColElementCount), offset_(aBaseOffset){}
|
|
|
|
__host__ __device__
|
|
ValueType& dereference() const{
|
|
return *(ptr_ + (offset_ / row_elements_) + (offset_ % row_elements_)*col_elements_);
|
|
}
|
|
|
|
// 实现递增操作符
|
|
__host__ __device__
|
|
void increment() {
|
|
offset_++;
|
|
}
|
|
|
|
// 实现递减操作符
|
|
__host__ __device__
|
|
void decrement() {
|
|
offset_--;
|
|
}
|
|
|
|
// 实现加法操作符
|
|
__host__ __device__
|
|
void advance(typename RowWiseIterator::difference_type n) {
|
|
offset_+=n;
|
|
}
|
|
|
|
// 实现减法操作符
|
|
__host__ __device__
|
|
typename RowWiseIterator::difference_type distance_to(const RowWiseIterator& other) const {
|
|
return (other.offset_ - offset_);
|
|
}
|
|
|
|
// 实现比较操作符
|
|
__host__ __device__
|
|
bool equal(const RowWiseIterator& other) const {
|
|
return (offset_ == other.offset_);
|
|
}
|
|
|
|
private:
|
|
ValueType* ptr_;
|
|
int row_elements_;
|
|
int col_elements_;
|
|
int offset_ = 0;
|
|
};
|
|
|
|
};
|
|
#endif // __AURORATHRUSTITERATOR_H__ |