Reference: http://www.codeproject.com/Articles/543451/Parallel-Radix-Sort-on-the-GPU-using-Cplusplus-AMP
On an ordinary PC, multithreading outperforms GPU acceleration when the data set is small; once the data set grows large, GPU acceleration outperforms multithreading.
main.cpp
#include <amp.h>
#include <chrono>
#include <algorithm>
#include <conio.h>
#include "radix_sort.h"
#include <ppl.h>


int main()
{
    using namespace concurrency;
    accelerator default_device;
    wprintf(L"Using device : %s\n\n", default_device.get_description().c_str());
    if (default_device == accelerator(accelerator::direct3d_ref))
        printf("WARNING!! Running on very slow emulator! Only use this accelerator for debugging.\n\n");

    for (uint i = 0; i < 10; i++)
    {
        uint num = (1 << (i + 10));
        printf("Testing for %u elements: \n", num);

        std::vector<uint> data(num);
        for (uint j = 0; j < num; j++)
        {
            data[j] = j;
        }
        std::random_shuffle(data.begin(), data.end());
        std::vector<uint> dataclone(data.begin(), data.end());

        auto start_fill = std::chrono::high_resolution_clock::now();
        array<uint> av(num, data.begin(), data.end());
        auto end_fill = std::chrono::high_resolution_clock::now();

        printf("Allocating %u random unsigned integers complete! Start GPU sort.\n", num);

        auto start_comp = std::chrono::high_resolution_clock::now();
        pal::radix_sort(av);
        av.accelerator_view.wait(); // Wait for the computation to finish
        auto end_comp = std::chrono::high_resolution_clock::now();

        auto start_collect = std::chrono::high_resolution_clock::now();
        data = av; // synchronise back to the host
        auto end_collect = std::chrono::high_resolution_clock::now();

        printf("GPU sort completed in %llu microseconds.\nData transfer: %llu microseconds, computation: %llu microseconds\n",
               std::chrono::duration_cast<std::chrono::microseconds>(end_collect - start_fill).count(),
               std::chrono::duration_cast<std::chrono::microseconds>(end_fill - start_fill + end_collect - start_collect).count(),
               std::chrono::duration_cast<std::chrono::microseconds>(end_comp - start_comp).count());

        printf("Testing for correctness. Results are.. ");

        uint success = 1;
        for (uint j = 0; j < num; j++)
        {
            if (data[j] != j) { success = 0; break; }
        }
        printf("%s\n", (success ? "correct!" : "incorrect!"));

        data = dataclone;
        printf("Beginning CPU sorts for comparison.\n");
        start_comp = std::chrono::high_resolution_clock::now();
        std::sort(data.data(), data.data() + num);
        end_comp = std::chrono::high_resolution_clock::now();
        printf("CPU std::sort completed in %llu microseconds. \n",
               std::chrono::duration_cast<std::chrono::microseconds>(end_comp - start_comp).count());

        data = dataclone;
        start_comp = std::chrono::high_resolution_clock::now();
        // Note: the concurrency::parallel sorts are horribly slow if you give them vectors
        // (i.e. parallel_radixsort(data.begin(), data.end())).
        concurrency::parallel_radixsort(data.data(), data.data() + num);
        end_comp = std::chrono::high_resolution_clock::now();
        printf("CPU concurrency::parallel_radixsort completed in %llu microseconds. \n\n\n",
               std::chrono::duration_cast<std::chrono::microseconds>(end_comp - start_comp).count());
    }

    printf("Press any key to exit! \n");
    _getch();
}
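One portability note on the harness above: std::random_shuffle was deprecated in C++14 and removed in C++17, so on a newer toolset the shuffle line can be replaced with std::shuffle and an explicit engine. A minimal sketch of such a replacement (the helper name shuffle_data and the engine choice are just for illustration):

#include <algorithm>
#include <random>
#include <vector>

// Drop-in alternative to the std::random_shuffle call in main.cpp:
// shuffles the test data with an explicitly seeded Mersenne Twister engine.
void shuffle_data(std::vector<unsigned int>& data)
{
    std::mt19937 rng(std::random_device{}());
    std::shuffle(data.begin(), data.end(), rng);
}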
radix_sort.h
#pragma once
typedef unsigned int uint;
#include <amp.h>

namespace pal
{
    void radix_sort(uint* start, uint num);
    void radix_sort(concurrency::array<uint>& arr);
}
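The header exposes two entry points: one that sorts a concurrency::array<uint> already resident on the accelerator (the path exercised by main.cpp above), and one that takes a raw host pointer plus an element count. A minimal sketch of the pointer overload, with placeholder keys chosen purely for illustration, could look like this:

#include <vector>
#include "radix_sort.h"

int main()
{
    // Placeholder keys; any buffer of unsigned ints works.
    std::vector<uint> keys = { 42u, 7u, 19u, 0u, 255u, 7u };

    // Sorts the host buffer via the raw-pointer overload.
    pal::radix_sort(keys.data(), static_cast<uint>(keys.size()));
    return 0;
}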
radix_sort.cpp
1 #include <amp.h> 2 #include "radix_sort.h" 3 4 5 void arr_fill(concurrency::array_view<uint> &dest, concurrency::array_view<uint>& src, uint val) 6 { 7 parallel_for_each(dest.extent,[dest ,val, src](concurrency::index<1> idx)restrict(amp) 8 { 9 dest[idx] = ( (uint)idx[0] <src.get_extent().size())? src[idx]: val; 10 }); 11 } 12 13 uint get_bits(uint x, uint numbits, uint bitoffset) restrict(amp) 14 { 15 return (x>>bitoffset) & ~(~0 <<numbits); 16 } 17 18 uint pow2(uint x) restrict(amp,cpu) 19 { 20 return ( ((uint)1) << x); 21 } 22 23 uint tile_sum(uint x, concurrency::tiled_index<256> tidx) restrict(amp) 24 { 25 using namespace concurrency; 26 uint l_id = tidx.local[0]; 27 tile_static uint l_sums[256][2]; 28 29 l_sums[l_id][0] = x; 30 tidx.barrier.wait(); 31 32 for(uint i = 0; i < 8; i ++) 33 { 34 if(l_id< pow2(7-i)) 35 { 36 uint w = (i+1)%2; 37 uint r = i%2; 38 39 l_sums[l_id][w] = l_sums[l_id*2][r] + l_sums[l_id*2 +1][r]; 40 } 41 tidx.barrier.wait(); 42 } 43 return l_sums[0][0]; 44 45 } 46 47 uint tile_prefix_sum(uint x, concurrency::tiled_index<256> tidx, uint& last_val ) restrict(amp) 48 { 49 using namespace concurrency; 50 uint l_id = tidx.local[0]; 51 tile_static uint l_prefix_sums[256][2]; 52 53 l_prefix_sums[l_id][0] = x; 54 tidx.barrier.wait(); 55 56 for(uint i = 0; i < 8; i ++) 57 { 58 uint pow2i = pow2(i); 59 60 uint w = (i+1)%2; 61 uint r = i%2; 62 63 l_prefix_sums[l_id][w] = (l_id >= pow2i)? ( l_prefix_sums[l_id][r] + l_prefix_sums[l_id - pow2i][r]) : l_prefix_sums[l_id][r] ; 64 65 tidx.barrier.wait(); 66 } 67 last_val = l_prefix_sums[255][0]; 68 69 uint retval = (l_id ==0)? 0: l_prefix_sums[l_id -1][0]; 70 return retval; 71 } 72 73 uint tile_prefix_sum(uint x, concurrency::tiled_index<256> tidx) restrict(amp) 74 { 75 uint ll=0; 76 return tile_prefix_sum(x, tidx, ll); 77 } 78 79 80 void calc_interm_sums(uint bitoffset, concurrency::array<uint> & interm_arr, 81 concurrency::array<uint> & interm_sums, concurrency::array<uint> & interm_prefix_sums, uint num_tiles) 82 { 83 using namespace concurrency; 84 auto ext = extent<1>(num_tiles*256).tile<256>(); 85 86 parallel_for_each(ext , [=, &interm_sums, &interm_arr](tiled_index<256> tidx) restrict(amp) 87 { 88 uint inbound = ((uint)tidx.global[0]<interm_arr.get_extent().size()); 89 uint num = (inbound)? get_bits(interm_arr[tidx.global[0]], 2, bitoffset): get_bits(0xffffffff, 2, bitoffset); 90 for(uint i = 0; i < 4; i ++) 91 { 92 uint to_sum = (num == i); 93 uint sum = tile_sum(to_sum, tidx); 94 95 if(tidx.local[0] == 0) 96 { 97 interm_sums[i*num_tiles + tidx.tile[0]] = sum; 98 } 99 } 100 101 }); 102 103 uint numiter = (num_tiles/64) + ((num_tiles%64 == 0)? 0:1); 104 ext = extent<1>(256).tile<256>(); 105 parallel_for_each(ext , [=, &interm_prefix_sums, &interm_sums](tiled_index<256> tidx) restrict(amp) 106 { 107 uint last_val0 = 0; 108 uint last_val1 = 0; 109 110 for(uint i = 0; i < numiter; i ++) 111 { 112 uint g_id = tidx.local[0] + i*256; 113 uint num = (g_id<(num_tiles*4))? 
interm_sums[g_id]: 0; 114 uint scan = tile_prefix_sum(num, tidx, last_val0); 115 if(g_id<(num_tiles*4)) interm_prefix_sums[g_id] = scan + last_val1; 116 117 last_val1 += last_val0; 118 } 119 120 }); 121 } 122 123 void sort_step(uint bitoffset, concurrency::array<uint> & src, concurrency::array<uint> & dest, 124 concurrency::array<uint> & interm_prefix_sums, uint num_tiles) 125 { 126 using namespace concurrency; 127 auto ext = extent<1>(num_tiles*256).tile<256>(); 128 129 parallel_for_each(ext , [=, &interm_prefix_sums, &src, &dest](tiled_index<256> tidx) restrict(amp) 130 { 131 uint inbounds = ((uint)tidx.global[0]<src.get_extent().size()); 132 uint element = (inbounds)? src[tidx.global[0]] : 0xffffffff; 133 uint num = get_bits(element, 2,bitoffset); 134 for(uint i = 0; i < 4; i ++) 135 { 136 uint scan = tile_prefix_sum((num == i), tidx) + interm_prefix_sums[i*num_tiles + tidx.tile[0]]; 137 if(num==i && inbounds) dest[scan] = element; 138 } 139 140 }); 141 } 142 143 namespace pal 144 { 145 void radix_sort(concurrency::array<uint>& arr) 146 { 147 using namespace concurrency; 148 uint size = arr.get_extent().size(); 149 150 const uint num_tiles = (size/256) + ((size%256 == 0)? 0:1); 151 152 array<uint> interm_arr(size); 153 array<uint> interm_sums(num_tiles*4); 154 array<uint> interm_prefix_sums(num_tiles*4); 155 156 for(uint i = 0; i < 16; i ++) 157 { 158 array<uint>& src = (i%2==0)? arr: interm_arr; 159 array<uint>& dest = (i%2==0)? interm_arr: arr; 160 161 uint bitoffset = i*2; 162 calc_interm_sums(bitoffset, src, interm_sums, interm_prefix_sums, num_tiles); 163 sort_step(bitoffset, src, dest, interm_prefix_sums, num_tiles); 164 } 165 } 166 167 void radix_sort(uint* arr, uint size) 168 { 169 radix_sort(concurrency::array<uint>(size, arr)); 170 } 171 }
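The GPU code above is a least-significant-digit radix sort that consumes 2 bits per pass: calc_interm_sums counts the four possible digit values per 256-element tile and scans those counts, sort_step scatters each element to its bucket offset plus its rank inside the tile, and pal::radix_sort repeats this 16 times while ping-ponging between arr and interm_arr. The same control flow is easier to see in a plain CPU sketch (illustrative only; radix_sort_cpu is not part of the library):

#include <array>
#include <vector>
typedef unsigned int uint;

// CPU reference for the same 2-bits-per-pass LSD radix sort (illustration only).
void radix_sort_cpu(std::vector<uint>& data)
{
    std::vector<uint> temp(data.size());
    for (uint pass = 0; pass < 16; pass++)                  // 16 passes x 2 bits = 32 bits
    {
        uint bitoffset = pass * 2;
        std::vector<uint>& src = (pass % 2 == 0) ? data : temp;
        std::vector<uint>& dest = (pass % 2 == 0) ? temp : data;

        // Count how many keys fall into each of the 4 digit buckets.
        std::array<uint, 4> count = { 0, 0, 0, 0 };
        for (uint x : src) count[(x >> bitoffset) & 3]++;

        // Exclusive prefix sum gives each bucket's starting offset.
        std::array<uint, 4> offset = { 0, 0, 0, 0 };
        for (uint d = 1; d < 4; d++) offset[d] = offset[d - 1] + count[d - 1];

        // Stable scatter into the destination buffer.
        for (uint x : src) dest[offset[(x >> bitoffset) & 3]++] = x;
    }
    // After an even number of passes the sorted result is back in data.
}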