前言
上一节介绍了 STL 中算法的分类以及泛化过程,这一节将介绍数值算法以及集合相关的算法,这部分算法包含在 STL 的 stl_numeric.h
以及 algo.h
中,在 TinySTL
中被放在 numeric.h
以及 set_algo.h
中。
numeric 算法
数值算法中主要包括四个算法:accumulate
、adjacent_difference
、inner_product
、partial_sum
,这些算法都不会改变处理区间的值,即都是非质变算法。
accumulate
accumulate(first, last, init)
的作用为对 [first, last)
内的元素进行以 init
为初值的累加,并返回求和的结果。
重载版本可以接收一个二元仿函数,以初值 init
对每个元素进行二元操作。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
|
template <class InputIterator, class T> T accumulate(InputIterator first, InputIterator last, T init) { for (; first != last; ++first) { init += *first; } return init; }
template <class InputIterator, class T, class BinaryOperation> T accumulate(InputIterator first, InputIterator last, T init, BinaryOperation binary_op) { for (; first != last; ++first) { init = binary_op(init, *first); } return init; }
|
关于仿函数的内容将会放在下一章分析,这里先来看一下如何利用仿函数来改变 accumulate
的行为:
1 2
| vector<int> v = {1, 2, 3, 4, 5}; cout << accumulate(v.begin(), v.end(), 1, multiplies<int>()) << endl;
|
这里传给 accumulate
一个 multiplies<int>()
的仿函数,该仿函数为一个二元仿函数,作用为返回两个参数的乘积,因此以 1
为初值进行 accumulate
实际上就是求区间累乘的结果了。与此类似地,通过传入不同仿函数可以使得同一个函数展现出不同的行为,非常灵活。
adjacent_difference
adjacent_difference
用于计算相邻元素的差值,结果保存到以 result
为起始的区间上,同样也有一个接受二元仿函数的重载版本。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
|
template <class InputIterator, class OutputIterator> OutputIterator adjacent_difference(InputIterator first, InputIterator last, OutputIterator result) { if (first == last) return result; *result = *first; auto value = *first; while (++first != last) { auto tmp = *first; *++result = tmp - value; value = tmp; } return ++result; }
|
测试案例
1 2 3 4 5 6 7
| vector<int> v = {1, 2, 4, 8, 16}; vector<int> res(v.size()); adjacent_difference(v.begin(), v.end(), res.begin());
for (auto i : res) cout << i << " "; cout << endl;
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| template <class InputIterator, class OutputIterator, class BinaryOperation> OutputIterator adjacent_difference(InputIterator first, InputIterator last, OutputIterator result, BinaryOperation binary_op) { if (first == last) return result; *result = *first; auto value = *first; while (++first != last) { auto tmp = *first; *++result = binary_op(tmp, value); value = tmp; } return ++result; }
|
同样地,传入不同的仿函数也可以使 adjacent_difference
展现出不同的行为:
1 2 3 4 5 6
| vector<int> v = {1, 2, 4, 8, 16}; vector<int> res(v.size()); adjacent_difference(v.begin(), v.end(), res.begin(), plus<int>()); for (auto i : res) cout << i << " "; cout << endl;
|
这样就变成了求相邻元素之和。
inner_product
inner_product
处理两个区间的数据,返回两个区间的内积结果,同样也有一个重载版本。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
|
template <class InputIterator1, class InputIterator2, class T> T inner_product(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, T init) { for (; first1 != last1; ++first1, ++first2) { init = init + (*first1 * *first2); } return init; }
template <class InputIterator1, class InputIterator2, class T, class BinaryOperation1, class BinaryOperation2> T inner_product(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, T init, BinaryOperation1 binary_op1, BinaryOperation2 binary_op2) { for (; first1 != last1; ++first1, ++first2) { init = binary_op1(init, binary_op2(*first1, *first2)); } return init; }
|
需要注意的是这里接受的第二个区间就只有一个 first2
指针,并不接受 last2
,只以第一个区间的长度作为循环的次数。
partial_sum
partial_sum
的作用为计算的区间累计求和,结果保存到以 result
为起始的区间上,就是刷题里经常遇到的前缀和操作。同样提供了一个重载。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
|
template <class InputIterator, class OutputIterator> OutputIterator partial_sum(InputIterator first, InputIterator last, OutputIterator result) { if (first == last) return result; *result = *first; auto value = *first; while (++first != last) { value = value + *first; *++result = value; } return ++result; }
template <class InputIterator, class OutputIterator, class BinaryOperation> OutputIterator partial_sum(InputIterator first, InputIterator last, OutputIterator result, BinaryOperation binary_op) { if (first == last) return result; *result = *first; auto value = *first; while (++first != last) { value = binary_op(value, *first); *++result = value; } return ++result; }
|
set_algo
STL一共提供了四种与set(集合)相关的算法,分别是并集(union)、交集(intersection)、差集(difference)、对称差集( symmetric difference)。
值得注意的是,STL 中的 set 算法所接受的 set 必须是有序区间,因此可以接受基于红黑树的 set/multiset
等,但基于哈希表的 unordered_set/unordered_multiset
则不可以应用于这里的算法。
set_union
计算 S1∪S2 的结果并保存到 result
中,返回一个迭代器指向输出结果的尾部。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
|
template <class InputIterator1, class InputIterator2, class OutputIterator> OutputIterator set_union(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2, OutputIterator result) { while (first1 != last1 && first2 != last2) { if (*first1 < *first2) { *result = *first1; ++first1; } else if (*first2 < *first1) { *result = *first2; ++first2; } else { *result = *first1; ++first1; ++first2; } ++result; } return tinystl::copy(first2, last2, tinystl::copy(first1, last1, result)); }
template <class InputIterator1, class InputIterator2, class OutputIterator, class Compare> OutputIterator set_union(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2, OutputIterator result, Compare comp) { while (first1 != last1 && first2 != last2) { if (comp(*first1, *first2)) { *result = *first1; ++first1; } else if (comp(*first2, *first1)) { *result = *first2; ++first2; } else { *result = *first1; ++first1; ++first2; } ++result; } return tinystl::copy(first2, last2, tinystl::copy(first1, last1, result)); }
|
可以看出,set_union
的具体执行逻辑就是类似于有序链表合并的逻辑,因此输入的两个区间一定都要是有序的。
set_intersection
计算 S1∩S2 的结果并保存到 result
中,返回一个迭代器指向输出结果的尾部。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
|
template <class InputIterator1, class InputIterator2, class OutputIterator> OutputIterator set_intersection(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2, OutputIterator result) { while (first1 != last1 && first2 != last2) { if (*first1 < *first2) { ++first1; } else if (*first2 < *first1) { ++first2; } else { *result = *first1; ++first1; ++first2; ++result; } } return result; }
template <class InputIterator1, class InputIterator2, class OutputIterator, class Compare> OutputIterator set_intersection(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2, OutputIterator result, Compare comp) { while (first1 != last1 && first2 != last2) { if (comp(*first1, *first2)) { ++first1; } else if (comp(*first2, *first1)) { ++first2; } else { *result = *first1; ++first1; ++first2; ++result; } } return result; }
|
set_difference
计算 S1−S2 的结果并保存到 result
中,返回一个迭代器指向输出结果的尾部。所谓 S1−S2 指的就是计算 S1 中具有而 S2 中不具有的元素。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
|
template <class InputIterator1, class InputIterator2, class OutputIterator> OutputIterator set_difference(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2, OutputIterator result) { while (first1 != last1 && first2 != last2) { if (*first1 < *first2) { *result = *first1; ++first1; ++result; } else if (*first2 < *first1) { ++first2; } else { ++first1; ++first2; } } return tinystl::copy(first1, last1, result); }
template <class InputIterator1, class InputIterator2, class OutputIterator, class Compare> OutputIterator set_difference(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2, OutputIterator result, Compare comp) { while (first1 != last1 && first2 != last2) { if (comp(*first1, *first2)) { *result = *first1; ++first1; ++result; } else if (comp(*first2, *first1)) { ++first2; } else { ++first1; ++first2; } } return tinystl::copy(first1, last1, result); }
|
以下是一个例子,关于 insert_iterator
会在适配器那一章说明。
1 2 3 4 5 6 7
| set<int> s1 = {1, 2, 3, 4, 5}; set<int> s2 = {3, 4, 5, 6, 7}; set<int> res; set_difference(s1.begin(), s1.end(), s2.begin(), s2.end(), inserter(res, res.begin())); for (auto i : res) cout << i << " "; cout << endl;
|
set_symmetric_difference
计算 (S1−S2)∪(S2−S1) 的结果并保存到 result
中,返回一个迭代器指向输出结果的尾部。这个函数计算的是对称差集,效果就相当于两集合取并集再减去交集。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
|
template <class InputIterator1, class InputIterator2, class OutputIterator> OutputIterator set_symmetric_difference(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2, OutputIterator result) { while (first1 != last1 && first2 != last2) { if (*first1 < *first2) { *result = *first1; ++first1; ++result; } else if (*first2 < *first1) { *result = *first2; ++first2; ++result; } else { ++first1; ++first2; } } return tinystl::copy(first2, last2, tinystl::copy(first1, last1, result)); }
template <class InputIterator1, class InputIterator2, class OutputIterator, class Compare> OutputIterator set_symmetric_difference(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2, OutputIterator result, Compare comp) { while (first1 != last1 && first2 != last2) { if (comp(*first1, *first2)) { *result = *first1; ++first1; ++result; } else if (comp(*first2, *first1)) { *result = *first2; ++first2; ++result; } else { ++first1; ++first2; } } return tinystl::copy(first2, last2, tinystl::copy(first1, last1, result)); }
|
以下是一个例子,关于 insert_iterator
会在适配器那一章说明。
1 2 3 4 5 6 7
| set<int> s1 = {1, 2, 3, 4, 5}; set<int> s2 = {3, 4, 5, 6, 7}; set<int> res; set_symmetric_difference(s1.begin(), s1.end(), s2.begin(), s2.end(), inserter(res, res.begin())); for (auto i : res) cout << i << " "; cout << endl;
|
总结
本节简要介绍了 STL 中的数值算法以及集合的相关算法,由参数的型别可以看出,这些算法接受的都是最低的 InputIterator
以及 OutputIterator
,并未对迭代器做更高的限制,并且需要特别注意的是,集合算法只能处理有序的区间。