前言

上一节介绍了 STL 中算法的分类以及泛化过程,这一节将介绍数值算法以及集合相关的算法,这部分算法包含在 STL 的 stl_numeric.h 以及 algo.h 中,在 TinySTL 中被放在 numeric.h 以及 set_algo.h 中。

numeric 算法

数值算法中主要包括四个算法:accumulateadjacent_differenceinner_productpartial_sum,这些算法都不会改变处理区间的值,即都是非质变算法。

accumulate

accumulate(first, last, init) 的作用为对 [first, last) 内的元素进行以 init 为初值的累加,并返回求和的结果。
重载版本可以接收一个二元仿函数,以初值 init 对每个元素进行二元操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*****************************************************************************************/
// accumulate
// 版本1:以初值 init 对每个元素进行累加
// 版本2:以初值 init 对每个元素进行二元操作
/*****************************************************************************************/

/// @brief 版本1:以初值 init 对每个元素进行累加
template <class InputIterator, class T>
T accumulate(InputIterator first, InputIterator last, T init) {
for (; first != last; ++first) {
init += *first;
}
return init;
}

/// @brief 版本2:以初值 init 对每个元素进行二元操作
template <class InputIterator, class T, class BinaryOperation>
T accumulate(InputIterator first, InputIterator last, T init, BinaryOperation binary_op) {
for (; first != last; ++first) {
init = binary_op(init, *first);
}
return init;
}

关于仿函数的内容将会放在下一章分析,这里先来看一下如何利用仿函数来改变 accumulate 的行为:

1
2
vector<int> v = {1, 2, 3, 4, 5};
cout << accumulate(v.begin(), v.end(), 1, multiplies<int>()) << endl; // 120

这里传给 accumulate 一个 multiplies<int>() 的仿函数,该仿函数为一个二元仿函数,作用为返回两个参数的乘积,因此以 1 为初值进行 accumulate 实际上就是求区间累乘的结果了。与此类似地,通过传入不同仿函数可以使得同一个函数展现出不同的行为,非常灵活。

adjacent_difference

adjacent_difference 用于计算相邻元素的差值,结果保存到以 result 为起始的区间上,同样也有一个接受二元仿函数的重载版本。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/*****************************************************************************************/
// adjacent_difference
// 版本1:计算相邻元素的差值,结果保存到以 result 为起始的区间上
// 版本2:自定义相邻元素的二元操作
/*****************************************************************************************/

/// @brief 版本1:计算相邻元素的差值,结果保存到以 result 为起始的区间上
template <class InputIterator, class OutputIterator>
OutputIterator adjacent_difference(InputIterator first,
InputIterator last, OutputIterator result) {
if (first == last) return result;
*result = *first; // 记录第一个元素
auto value = *first;
while (++first != last) {
auto tmp = *first;
*++result = tmp - value;
value = tmp;
}
// 符合前闭后开的规则
return ++result;
}

测试案例

1
2
3
4
5
6
7
vector<int> v = {1, 2, 4, 8, 16};
vector<int> res(v.size());
adjacent_difference(v.begin(), v.end(), res.begin());

for (auto i : res)
cout << i << " ";
cout << endl; // 1 1 2 4 8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/// @brief 版本2:自定义相邻元素的二元操作
template <class InputIterator, class OutputIterator, class BinaryOperation>
OutputIterator adjacent_difference(InputIterator first,
InputIterator last, OutputIterator result, BinaryOperation binary_op) {
if (first == last) return result;
*result = *first; // 记录第一个元素
auto value = *first;
while (++first != last) {
auto tmp = *first;
*++result = binary_op(tmp, value);
value = tmp;
}
// 符合前闭后开的规则
return ++result;
}

同样地,传入不同的仿函数也可以使 adjacent_difference 展现出不同的行为:

1
2
3
4
5
6
vector<int> v = {1, 2, 4, 8, 16};
vector<int> res(v.size());
adjacent_difference(v.begin(), v.end(), res.begin(), plus<int>());
for (auto i : res)
cout << i << " ";
cout << endl; // 1 3 6 12 24

这样就变成了求相邻元素之和。

inner_product

inner_product 处理两个区间的数据,返回两个区间的内积结果,同样也有一个重载版本。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/*****************************************************************************************/
// inner_product
// 版本1:以 init 为初值,计算两个区间的内积
// 版本2:自定义 operator+ 和 operator*
/*****************************************************************************************/

/// @brief 版本1:以 init 为初值,计算两个区间的内积
template <class InputIterator1, class InputIterator2, class T>
T inner_product(InputIterator1 first1, InputIterator1 last1,
InputIterator2 first2, T init) {
for (; first1 != last1; ++first1, ++first2) {
init = init + (*first1 * *first2);
}
return init;
}

/// @brief 版本2:自定义 operator+ 和 operator*
template <class InputIterator1, class InputIterator2, class T, class BinaryOperation1, class BinaryOperation2>
T inner_product(InputIterator1 first1, InputIterator1 last1,
InputIterator2 first2, T init, BinaryOperation1 binary_op1, BinaryOperation2 binary_op2) {
for (; first1 != last1; ++first1, ++first2) {
init = binary_op1(init, binary_op2(*first1, *first2));
}
return init;
}

需要注意的是这里接受的第二个区间就只有一个 first2 指针,并不接受 last2,只以第一个区间的长度作为循环的次数。

partial_sum

partial_sum 的作用为计算的区间累计求和,结果保存到以 result 为起始的区间上,就是刷题里经常遇到的前缀和操作。同样提供了一个重载。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
/*****************************************************************************************/
// partial_sum
// 即前缀和
// 版本1:计算局部累计求和,结果保存到以 result 为起始的区间上
// 版本2:进行局部进行自定义二元操作
/*****************************************************************************************/

/// @brief 版本1:计算局部累计求和,结果保存到以 result 为起始的区间上
template <class InputIterator, class OutputIterator>
OutputIterator partial_sum(InputIterator first, InputIterator last, OutputIterator result) {
if (first == last) return result;
*result = *first; // 记录第一个元素
auto value = *first;
while (++first != last) {
value = value + *first;
*++result = value;
}
return ++result;
}

/// @brief 版本2:进行局部进行自定义二元操作
template <class InputIterator, class OutputIterator, class BinaryOperation>
OutputIterator partial_sum(InputIterator first, InputIterator last,
OutputIterator result, BinaryOperation binary_op) {
if (first == last) return result;
*result = *first; // 记录第一个元素
auto value = *first;
while (++first != last) {
value = binary_op(value, *first);
*++result = value;
}
return ++result;
}

set_algo

STL一共提供了四种与set(集合)相关的算法,分别是并集(union)、交集(intersection)、差集(difference)、对称差集( symmetric difference)。

值得注意的是,STL 中的 set 算法所接受的 set 必须是有序区间,因此可以接受基于红黑树的 set/multiset 等,但基于哈希表的 unordered_set/unordered_multiset 则不可以应用于这里的算法。

set_union

计算 S1S2S_1 \cup S_2 的结果并保存到 result 中,返回一个迭代器指向输出结果的尾部。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
/*****************************************************************************************/
// set_union
// 计算 S1 ∪ S2 的结果并保存到 result 中,返回一个迭代器指向输出结果的尾部
/*****************************************************************************************/

/// @brief 计算 S1 ∪ S2 的结果并保存到 result 中,返回一个迭代器指向输出结果的尾部
template <class InputIterator1, class InputIterator2, class OutputIterator>
OutputIterator set_union(InputIterator1 first1, InputIterator1 last1,
InputIterator2 first2, InputIterator2 last2,
OutputIterator result) {
// 过程类似于合并两个有序链表的做法
// 两个序列必须有序
while (first1 != last1 && first2 != last2) {
if (*first1 < *first2) {
*result = *first1;
++first1;
}
else if (*first2 < *first1) {
*result = *first2;
++first2;
}
else {
*result = *first1;
++first1;
++first2;
}
++result;
}
// 将剩余元素拷贝到 result 中
return tinystl::copy(first2, last2, tinystl::copy(first1, last1, result));
}

/// @brief 重载版本使用函数对象 comp 代替比较操作
template <class InputIterator1, class InputIterator2, class OutputIterator, class Compare>
OutputIterator set_union(InputIterator1 first1, InputIterator1 last1,
InputIterator2 first2, InputIterator2 last2,
OutputIterator result, Compare comp) {
while (first1 != last1 && first2 != last2) {
if (comp(*first1, *first2)) {
*result = *first1;
++first1;
}
else if (comp(*first2, *first1)) {
*result = *first2;
++first2;
}
else {
*result = *first1;
++first1;
++first2;
}
++result;
}
// 将剩余元素拷贝到 result 中
return tinystl::copy(first2, last2, tinystl::copy(first1, last1, result));
}

可以看出,set_union 的具体执行逻辑就是类似于有序链表合并的逻辑,因此输入的两个区间一定都要是有序的。

set_intersection

计算 S1S2S_1 \cap S_2 的结果并保存到 result 中,返回一个迭代器指向输出结果的尾部。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/*****************************************************************************************/
// set_intersection
// 计算 S1 ∩ S2 的结果并保存到 result 中,返回一个迭代器指向输出结果的尾部
/*****************************************************************************************/

/// @brief 计算 S1 ∩ S2 的结果并保存到 result 中,返回一个迭代器指向输出结果的尾部
template <class InputIterator1, class InputIterator2, class OutputIterator>
OutputIterator set_intersection(InputIterator1 first1, InputIterator1 last1,
InputIterator2 first2, InputIterator2 last2,
OutputIterator result) {
while (first1 != last1 && first2 != last2) {
if (*first1 < *first2) {
++first1;
}
else if (*first2 < *first1) {
++first2;
}
else {
*result = *first1;
++first1;
++first2;
++result;
}
}
return result;
}

/// @brief 重载版本使用函数对象 comp 代替比较操作
template <class InputIterator1, class InputIterator2, class OutputIterator, class Compare>
OutputIterator set_intersection(InputIterator1 first1, InputIterator1 last1,
InputIterator2 first2, InputIterator2 last2,
OutputIterator result, Compare comp) {
while (first1 != last1 && first2 != last2) {
if (comp(*first1, *first2)) {
++first1;
}
else if (comp(*first2, *first1)) {
++first2;
}
else {
*result = *first1;
++first1;
++first2;
++result;
}
}
return result;
}

set_difference

计算 S1S2S_1-S_2 的结果并保存到 result 中,返回一个迭代器指向输出结果的尾部。所谓 S1S2S_1-S_2 指的就是计算 S1S_1 中具有而 S2S_2 中不具有的元素。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
/*****************************************************************************************/
// set_difference
// 计算 S1-S2 的结果并保存到 result 中,返回一个迭代器指向输出结果的尾部
/*****************************************************************************************/

/// @brief 计算 S1-S2 的结果并保存到 result 中,返回一个迭代器指向输出结果的尾部
template <class InputIterator1, class InputIterator2, class OutputIterator>
OutputIterator set_difference(InputIterator1 first1, InputIterator1 last1,
InputIterator2 first2, InputIterator2 last2,
OutputIterator result) {
while (first1 != last1 && first2 != last2) {
// 只有 *first1 < *first2 才能确保这个元素在 S1 中有而 S2 中没有
if (*first1 < *first2) {
*result = *first1;
++first1;
++result;
}
// 否则与 S2 的下一个元素继续比较
else if (*first2 < *first1) {
++first2;
}
// 相等则说明这个元素在 S1 和 S2 中都有,不是 S1-S2 的元素
else {
++first1;
++first2;
}
}
// 剩下的部分都是 S1 中有而 S2 中没有的元素
return tinystl::copy(first1, last1, result);
}

/// @brief 重载版本使用函数对象 comp 代替比较操作
template <class InputIterator1, class InputIterator2, class OutputIterator, class Compare>
OutputIterator set_difference(InputIterator1 first1, InputIterator1 last1,
InputIterator2 first2, InputIterator2 last2,
OutputIterator result, Compare comp) {
while (first1 != last1 && first2 != last2) {
if (comp(*first1, *first2)) {
*result = *first1;
++first1;
++result;
}
else if (comp(*first2, *first1)) {
++first2;
}
else {
++first1;
++first2;
}
}
return tinystl::copy(first1, last1, result);
}

以下是一个例子,关于 insert_iterator 会在适配器那一章说明。

1
2
3
4
5
6
7
set<int> s1 = {1, 2, 3, 4, 5};
set<int> s2 = {3, 4, 5, 6, 7};
set<int> res;
set_difference(s1.begin(), s1.end(), s2.begin(), s2.end(), inserter(res, res.begin()));
for (auto i : res)
cout << i << " ";
cout << endl; // 1 2

set_symmetric_difference

计算 (S1S2)(S2S1)(S_1-S_2) \cup (S_2-S_1) 的结果并保存到 result 中,返回一个迭代器指向输出结果的尾部。这个函数计算的是对称差集,效果就相当于两集合取并集再减去交集。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
/*****************************************************************************************/
// set_symmetric_difference
// 等同于并集减去交集
// 计算 (S1-S2)∪(S2-S1) 的结果并保存到 result 中,返回一个迭代器指向输出结果的尾部
/*****************************************************************************************/

/// @brief 计算 (S1-S2)∪(S2-S1) 的结果并保存到 result 中,返回一个迭代器指向输出结果的尾部
template <class InputIterator1, class InputIterator2, class OutputIterator>
OutputIterator set_symmetric_difference(InputIterator1 first1, InputIterator1 last1,
InputIterator2 first2, InputIterator2 last2,
OutputIterator result) {
while (first1 != last1 && first2 != last2) {
if (*first1 < *first2) {
*result = *first1;
++first1;
++result;
}
else if (*first2 < *first1) {
*result = *first2;
++first2;
++result;
}
else {
++first1;
++first2;
}
}
// 剩下的部分是集合中独特的元素
return tinystl::copy(first2, last2, tinystl::copy(first1, last1, result));
}

/// @brief 重载版本使用函数对象 comp 代替比较操作
template <class InputIterator1, class InputIterator2, class OutputIterator, class Compare>
OutputIterator set_symmetric_difference(InputIterator1 first1, InputIterator1 last1,
InputIterator2 first2, InputIterator2 last2,
OutputIterator result, Compare comp) {
while (first1 != last1 && first2 != last2) {
if (comp(*first1, *first2)) {
*result = *first1;
++first1;
++result;
}
else if (comp(*first2, *first1)) {
*result = *first2;
++first2;
++result;
}
else {
++first1;
++first2;
}
}
return tinystl::copy(first2, last2, tinystl::copy(first1, last1, result));
}

以下是一个例子,关于 insert_iterator 会在适配器那一章说明。

1
2
3
4
5
6
7
set<int> s1 = {1, 2, 3, 4, 5};
set<int> s2 = {3, 4, 5, 6, 7};
set<int> res;
set_symmetric_difference(s1.begin(), s1.end(), s2.begin(), s2.end(), inserter(res, res.begin()));
for (auto i : res)
cout << i << " ";
cout << endl; // 1 2 6 7

总结

本节简要介绍了 STL 中的数值算法以及集合的相关算法,由参数的型别可以看出,这些算法接受的都是最低的 InputIterator 以及 OutputIterator,并未对迭代器做更高的限制,并且需要特别注意的是,集合算法只能处理有序的区间。