/********************************************************************
created:  2014/04/29 11:35
filename: nth_element.cpp
author:   Justme0 (http://blog.csdn.net/justme0)

purpose:  simplified re-implementation of the STL's nth_element and its
          internal helpers, with a small test driver in main
*********************************************************************/

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iterator>

// Element type of the demo data in main (an int array). Helpers that do not
// deduce their element type from the iterator can use this typedef.
typedef int Type;

/*
** Copies [first, last) into the range that ends at result, i.e. into
** [result - (last - first), result), and returns the start of that range.
**
** Overlapping source and destination ranges are handled in both directions,
** like memmove: a destination shifted to the right (how linear_insert opens a
** one-slot gap) is filled back to front, and a destination shifted to the left
** is filled front to back. Element-wise assignment is used instead of memmove,
** so the helper is valid for any copy-assignable T, not only for trivially
** copyable types.
*/
template <class T>
inline T * copy_backward(const T *first, const T *last, T *result) {
    T *const dest = result - (last - first);
    if (dest < first) {
        // Destination begins before the source: copy front to back so that no
        // source element is overwritten before it has been read. (For
        // unrelated arrays the comparison is unspecified, but then either
        // direction is correct.)
        T *out = dest;
        for (const T *in = first; in != last; ++in, ++out) {
            *out = *in;
        }
    } else {
        // Destination begins at or after the source: copy back to front.
        while (last != first) {
            *--result = *--last;
        }
    }
    return dest;
}

/*
** Inserts value into the sorted run that ends just before `last`.
** The slot at `last` is treated as free: larger elements are shifted one
** position to the right until a slot for value is found, and value is stored
** there.
**
** There is no lower-bound check. The caller must guarantee that some element
** before `last` is not greater than value (linear_insert checks *first before
** calling this), which is why the name starts with unguarded_.
*/
template <class RandomAccessIterator, class T>
void unguarded_linear_insert(RandomAccessIterator last, T value) {
    RandomAccessIterator hole = last;
    RandomAccessIterator prev = hole;
    for (--prev; value < *prev; --prev) {
        *hole = *prev;
        hole = prev;
    }
    *hole = value;
}

/*
** Inserts the element at `last` into the already sorted range [first, last),
** so that [first, last] is sorted afterwards.
**
** If the new element is smaller than *first, the whole range is shifted one
** slot to the right with a single copy_backward and the element is placed at
** the front. Otherwise *first is not greater than it, which bounds the
** backward scan done by unguarded_linear_insert.
*/
template <class RandomAccessIterator>
void linear_insert(RandomAccessIterator first, RandomAccessIterator last) {
    // Take the element type from the iterator instead of the fixed Type
    // typedef, so that e.g. doubles are not truncated to int.
    typedef typename std::iterator_traits<RandomAccessIterator>::value_type value_type;
    value_type value = *last;
    if (value < *first) {
        copy_backward(first, last, last + 1);
        *first = value;
    } else {
        unguarded_linear_insert(last, value);
    }
}

/*
** Sorts [first, last) in ascending order (by operator<) using insertion
** sort: each element after the first is inserted into the sorted prefix
** that precedes it. An empty range is left untouched.
*/
template <class RandomAccessIterator>
void insertion_sort(RandomAccessIterator first, RandomAccessIterator last) {
    if (first != last) {
        RandomAccessIterator cur = first;
        while (++cur != last) {
            linear_insert(first, cur);
        }
    }
}

/*
** Returns a reference to the median of a, b and c (the argument that is
** neither strictly the smallest nor strictly the largest), using only
** operator<.
*/
template <class T>
inline const T & median(const T &a, const T &b, const T &c) {
    if (a < b) {
        if (b < c) {
            return b;               // a < b < c
        }
        return (a < c) ? c : a;     // c <= b: the larger of a and c
    }
    if (a < c) {
        return a;                   // b <= a < c
    }
    return (b < c) ? c : b;         // b <= a and c <= a: the larger of b and c
}

/*
** Swaps the values referred to by the iterators a and b.
** The temporary has the value type of ForwardIterator1, obtained through
** std::iterator_traits as in the STL source this mirrors, rather than the
** fixed Type typedef, so non-int elements are swapped without conversion.
*/
template <class ForwardIterator1, class ForwardIterator2>
inline void iter_swap(ForwardIterator1 a, ForwardIterator2 b) {
    typename std::iterator_traits<ForwardIterator1>::value_type tmp = *a;
    *a = *b;
    *b = tmp;
}

/*
** Hoare-style partition of [first, last) around the value pivot.
** Returns cut such that every element in [first, cut) is <= pivot and every
** element in [cut, last) is >= pivot. This is the internal STL algorithm
** used by nth_element and sort.
**
** Both inner scans stop on elements equal to the pivot, and every swap leaves
** behind an element that stops the next scan from the other side. That is why
** the scans carry no bounds checks ("unguarded"). The range must therefore
** contain an element >= pivot and an element <= pivot. nth_element ensures
** this by passing the median of three elements taken from the range.
*/
template <class RandomAccessIterator, class T>
RandomAccessIterator unguarded_partition(RandomAccessIterator first, RandomAccessIterator last, T pivot) {
    for (;;) {
        // Left cursor: skip elements strictly smaller than the pivot.
        while (*first < pivot) {
            ++first;
        }
        // Right cursor: step onto the previous element, then keep skipping
        // elements strictly greater than the pivot.
        do {
            --last;
        } while (pivot < *last);
        // Iterator operator< requires random access iterators.
        if (first < last) {
            iter_swap(first, last);
            ++first;
        } else {
            // Cursors met or crossed: `first` is the split point.
            return first;
        }
    }
}

/*
** For nth in [first, last), rearranges the range so that *nth holds the value
** it would have if [first, last) were sorted ascending. Every element before
** nth is <= *nth and every element after it is >= *nth (the contract of
** std::nth_element).
**
** While the current window holds more than 3 elements, it is split by
** unguarded_partition around the median of its first, middle and last
** elements, and only the part containing nth is kept. The final small window
** is finished with insertion_sort.
*/
template <class RandomAccessIterator>
void nth_element(RandomAccessIterator first, RandomAccessIterator nth, RandomAccessIterator last) {
    // Use the iterator's element type for the pivot instead of the fixed
    // Type typedef, so non-int elements are not converted (e.g. truncated).
    typedef typename std::iterator_traits<RandomAccessIterator>::value_type value_type;
    while (last - first > 3) {
        // The pivot is copied: partitioning moves elements, so a reference
        // into the range could change while the partition runs.
        const value_type pivot = median(*first, *(first + (last - first) / 2), *(last - 1));
        RandomAccessIterator cut = unguarded_partition(first, last, pivot);
        if (cut <= nth) {
            first = cut;    // nth lies in the right part [cut, last)
        } else {
            last = cut;     // nth lies in the left part [first, cut)
        }
    }
    insertion_sort(first, last);
}


130 int main(int argc, char **argv) {
131 int arr[] = {22, 30, 30, 17, 33, 40, 17, 23, 22, 12, 20};
132 int size = sizeof arr / sizeof *arr;
133
134 nth_element(arr, arr + 5, arr + size);
135
136 for (int i = 0; i < size; ++i) {
137 printf("%d ", arr[i]); // 20 12 22 17 17 22 23 30 30 33 40
138 }
139 printf("\n");
140
141 system("PAUSE");
142 return 0;
143 }
// Source page: "nth_element 测试程序" (nth_element test program), 码迷, mamicode.com
// Time (时间): 2024-11-07 23:50:24