二分查找
考虑以下查找问题:
输入: 个数的数组 和一个值
输出: 下标使得, 或者当不在出现时, 返回特殊值
上面这个问题是我们在插入排序那一节就已经介绍的.
注意到, 如果序列A已排好序, 就可以将该序列的中点与v进行比较. 根据比较的结果, 原序列中有一半就可以不用再做进一步的考虑了, 这种在 有序数组中高效查找特定元素的算法 被称为 二分查找(Binary Search).
BINARY-SEARCH
#![allow(unused)]
fn main() {
use std::cmp::Ordering;
pub fn binary_search<T: Ord>(arr: &[T], v: &T) -> Option<usize> {
let mut low = 0;
let mut high = arr.len().checked_sub(1)?; // 处理空数组
while low <= high {
let mid = low + (high - low) / 2;
match arr[mid].cmp(v) {
Ordering::Equal => return Some(mid),
Ordering::Less => low = mid + 1,
Ordering::Greater => high = mid - 1,
}
}
None
}
}
Note
在这里
checked_sub用于检查整数减法, 计算self - rhs, 如果发生溢出则返回None. 并使用Ordering方便用match进行匹配.
首先, 我们要确定区间[low, high], low为数组首索引(在Rust中为0), high为数组尾索引(在Rust中为arr.len() - 1).
然后, 在区间内找到中间值mid = (low + high) / 2, 这可能会溢出, 利用简单数学证明可得:
该式可以避免溢出. (事实上, 标准库的实现也是这么做的)
利用条件low <= high可以确保区间[low, high]成立.
接着, 判定arr[mid]与v的关系. 当二者相等时(Ordering::Equal), 直接返回当前索引;如果较大说明右半部分不存在, 通过收紧(high = mid - 1), 将右边部分排除, 同理较小时用low = mid + 1将左半部分排除.
二分查找显然使用了分治法1, 不过比较特殊的是它没有合并的步骤(并且采用迭代法). 我们利用循环不变式来证明二分查找的正确性:
在每次循环迭代开始时, 如果目标值
v存在于数组中, 则v一定位于子数组arr[low..high-1]中, 其中arr是已排序的数组,low和high是当前搜索区间的边界索引.
这个循环不变式比较简单, 留给读者自证. 从上面的过程, 我们不难证得.
我们增高难度, 在现有的基础上, 我们不再保证有且仅有一个i, 即A数组中可能存在多个v, 要求给出最小的i. 这是一个常见的求左边界的二分算法, 我们用Rust实现:
#![allow(unused)]
fn main() {
pub fn find_left_boundary<T: Ord>(arr: &[T], v: &T) -> Option<usize> {
let mut low = 0;
let mut high = arr.len().checked_sub(1)?;
while low < high { // 取消等于
let mid = low + (high - low) / 2;
if arr[mid] < *v { // 注意解引用
low = mid + 1; // 与之前相同
} else {
high = mid; // 目标值在左侧或当前位置
}
}
if arr.get(low)? == v { // 使用get更安全
Some(low)
} else {
None
}
}
}
这里主要逻辑改写在比较部分: 当arr[mid] < v时, 目标值必然在右侧, 所以移动low;当arr[mid] >= v目标值可能出现在左侧或当前位置(arr[mid] == v), 所以移动high到mid(不是mid - 1). 重点在low >= high说明所有的v都已出现(这里就是上面low < high不使用等号的原因), 那么low必然在最小v的位置上. 当然这个算法有个问题, 不存在时会误报, 所以要二次判断.
同理, 不难写出寻找最大i的二分算法:
#![allow(unused)]
fn main() {
pub fn find_right_boundary<T: Ord>(arr: &[T], v: &T) -> Option<usize> {
let mut low = 0;
let mut high = arr.len().checked_sub(1)?;
while low < high {
let mid = low + (high - low + 1) / 2; // 向上取整
if arr[mid] <= *v {
low = mid; // 保留当前位置
} else {
high = mid - 1;
}
}
if arr.get(low)? == v {
Some(low)
} else {
None
}
}
}
可以发现, 比较逻辑的改写比较复杂, 在比较微小的地方出错就会导致算法进入死循环或错估, 所以循环不变式在判断二分算法的正确性上非常重要.
在向我们刚刚的数组切片或有序数组中, std标准库提供了一系列方法:
#![allow(unused)]
fn main() {
let v = vec![1, 3, 5, 7, 9];
assert_eq!(v.binary_search(&5), Ok(2));
assert_eq!(v.binary_search(&4), Err(2)); // 插入后为 [1, 3, 4, 5, 7, 9]
}
上面这个例子中binary_search返回Result<usize, usize>, Ok(index) 中 index 为元素所在位置, Err(index) 中则为未找到元素时, 如果将元素插入到数组, 保持有序的位置. binary_search_by允许通过函数来设置查找规则, binary_search_by_key允许通过键(如结构体字段)查找.
Note
上面的这些方法和实现都要确保数组已经排序, 否则返回的结果无意义.
binary_search_by之类的, 通常来说与上面的手写性能相差不大, 但更具有扩展性.
在较新的版本(Rust 1.52+)中, partition_point 可以用来返回满足条件的第一个元素的位置:
#![allow(unused)]
fn main() {
let v = vec![1, 2, 2, 3, 3, 4, 5];
println!("{}", v.partition_point(|&x| x < 4)); // 第一个不小于4的元素位置
}
Note
该函数底层是
self.binary_search_by(|x| if pred(x) { Less } else { Greater }).unwrap_or_else(|i| i)这种写法使得binary_search永远找不到等于的位置, 所以就会返回插入之后仍然有序的位置, 也就是第一个不满足pred函数的位置. (其中pred是调用者的输入)
我们重新来看插入排序:
pub fn insert_sort<T: Ord>(arr: &mut [T]) {
for i in 1..arr.len() {
let mut j = i;
while j > 0 && arr[j] < arr[j - 1] {
arr.swap(j, j - 1);
j -= 1;
}
}
}
其中
while j > 0 && arr[j] < arr[j - 1] {
arr.swap(j, j - 1);
j -= 1;
}
这个部分是将需要排序的元素在有序数组中移动找到适合的位置, 在有序数组中查找, 完全可以利用二分:
#![allow(unused)]
fn main() {
pub fn insert_sort_by_binary_search<T: Ord>(arr: &mut [T]) {
for i in 1..arr.len() {
// 此处直接使用标准库, 下面代码完全可以用`partition_point`进行修改, 取决于您
let pos = arr[..i].binary_search(&arr[i]).unwrap_or_else(|pos| pos);
// 将整个区间右移1个单位(直接使用`swap`理论上也完全可以)
arr[pos..=i].rotate_right(1);
}
}
}
上面这个优化仅仅改变了比较次数, 不影响总体的时间复杂度. 对于大规模数据和操作代价较高的会有一定优化, 其他情况下还是使用线性更优.
两数之和
考虑下面这道题:
输入: 个整数的数组 和一个整数
输出: 下标 和 使得 , 或者当数组中无两数之和为 时, 返回特殊值
Tip
保证有且仅有一个满足条件的解
我们将用增量法, 分治法和一特殊方法解决此问题.
增量法最简单, 我们可以用两层循环:
#![allow(unused)]
fn main() {
pub fn incremental(arr: &[i32], x: i32) -> Option<(usize, usize)> {
for i in 0..arr.len() {
for j in (i + 1)..arr.len() {
if arr[i] + arr[j] == x {
return Some((i, j));
}
}
}
None
}
}
但复杂度将达到.
分治法则不难想到二分查找:
#![allow(unused)]
fn main() {
pub fn divide_and_conquer(arr: &[i32], x: i32) -> Option<(usize, usize)> {
// 先对数组进行排序 (保留原索引)
let mut sorted: Vec<(usize, &i32)> = arr.iter().enumerate().collect();
sorted.sort_by(|a, b| a.1.cmp(b.1));
let mut left = 0;
let mut right = sorted.len() - 1;
while left < right {
let sum = sorted[left].1 + sorted[right].1;
if sum == x {
// 确保返回的索引顺序与原数组一致
let (i1, i2) = (sorted[left].0, sorted[right].0);
return Some((i1.min(i2), i1.max(i2)));
} else if sum < x {
left += 1;
} else {
right -= 1;
}
}
None
}
}
该方案虽然使用了排序, 但按照归并排序的时间复杂度(标准库的排序实现更为复杂, 这里简单的以归并排序为例), 该实现仍然是.
最后一种方法非常特殊, 在后面的课程中我们会具体讲到, 这里简单介绍一下:
哈希表(Hash Map 或 Hash Table) 是一种通过 哈希函数(Hash Function) 将 键(key) 映射到存储位置的数据结构. 具体来说, 哈希函数将任意类型的键转换为一个固定范围内的整数(称为哈希值), 该值作为索引用于在数组中存储对应的值. 借助数组的随机访问特性, 哈希表在理想情况下的查找、插入和删除操作的时间复杂度均为 .
#![allow(unused)]
fn main() {
// 使用标准库实现的哈希表
use std::collections::HashMap;
pub fn search_by_hash_map(arr: &[i32], x: i32) -> Option<(usize, usize)> {
let mut map = HashMap::new();
for (i, &num) in arr.iter().enumerate() {
// 计算与当前项相加等于x的值
let complement = x - num;
// 在哈希表中寻找是否有complement
if let Some(&j) = map.get(&complement) {
// 如果有, 直接返回当前索引和complement所在索引
return Some((j, i));
}
// 否则, 保存当前项对应的索引
map.insert(num, i);
}
None
}
}
该算法是单循环的, 又因为哈希表的读取运算是常数时间, 所以这个实现时间复杂度为.
练习与回答
- 我们作以下考虑: 虽然归并排序的最坏情况运行时间为, 而插入排序的最坏情况运行时间为, 但是插入排序中的常量因子可能使得它在较小时, 在许多机器上实际运行得更快. 因此, 在归并排序中当子问题变得足够小时, 采用插入排序来使递归的叶变粗是有意义的. 考虑对归并排序的一种修改, 使用插入排序来排序长度为的个子表, 然后使用标准的合并机制来合并这些子表, 这里是一个待定的值.
#![allow(unused)]
fn main() {
pub fn merge_sort_by_insert(arr: &mut [i32], k: usize) {
let n = arr.len();
if n <= k {
insertion_sort(arr);
return;
}
let mid = n / 2;
merge_sort_by_insert(&mut arr[..mid], k);
merge_sort_by_insert(&mut arr[mid..], k);
merge(arr, mid);
}
fn insertion_sort(arr: &mut [i32]) {
for i in 1..arr.len() {
let mut j = i;
while j > 0 && arr[j - 1] > arr[j] {
arr.swap(j - 1, j);
j -= 1;
}
}
}
fn merge(arr: &mut [i32], mid: usize) {
let mut temp = arr.to_vec();
let (mut i, mut j, mut k) = (0, mid, 0);
while i < mid && j < arr.len() {
if temp[i] <= temp[j] {
arr[k] = temp[i];
i += 1;
} else {
arr[k] = temp[j];
j += 1;
}
k += 1;
}
while i < mid {
arr[k] = temp[i];
i += 1;
k += 1;
}
while j < arr.len() {
arr[k] = temp[j];
j += 1;
k += 1;
}
}
}
这种实现与 Timsort2 的核心思想一致, 结合了归并与插入的优点, 接下来我们深入讨论:
在上述算法中, 插入排序可以在时间内排序每个长度为的个子表(每个子表排序的时间复杂度为, 故). 同样不难看出, 合并子表的时间复杂度是, 综上就有该算法最坏情况下的时间复杂度为. 为了确保任何的取值不能使修改后算法时间复杂度高于原算法, 3.
在实际使用中, 不能随着的增长增长过快, 所以通常使用常数或对数级, 一般选用20至100的常数, 或者利用的对数情况动态调整. 更现实化的时候, 可以通过cpu缓存规则或者混合策略(分段使用常数和对数)来解决.