树状数组

树状数组

一、什么是树状数组

树状数组,其英文是 Binary Indexed Tree(简称 BIT),也称为 二叉索引树/二叉下标树。

  • 树状数组虽然名称后缀是数组,但实际上是一棵由数组实现的树

最基本的树状数组支持 2 种操作:

  • 单点修改:更新数组 nums 中任意单个元素值
  • 区间查询:求数组 nums 中任意区间的元素和

并且在复杂度上能够满足:

  • 树状数组对于单点修改和区间查询的时间复杂度都是 O(logn)

比如说,数组 [3, 2, 1, 3],它对应的树状数组结构类似这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
 ___        ___        ___        ___
| 3 | | 5 | | 1 | | 9 | -- 树状数组

___
| 9 |
/ |
/ /|
/ / |
/ / |
/ / |
___ / |
| 5 | / |
/| / |
/ | / |
/ | / |
/ | / |
___ ___ ___ ___
| 3 | | 2 | | 1 | | 3 | -- 树状结构

树状数组是一种专门用于高效计算前缀和、区间和数据结构。

二、为什么用树状数组?

2.1 背景引入

事项说明:下面的所有数组区间设定不是 [0, n - 1],而是 [1, n]

一开始的问题是需要实现 2 种操作:

  • 单点修改:更新数组 nums 中任意单个元素值
  • 区间查询:求数组 nums 中任意区间的元素和

要同时实现这 2 种操作,有 2 种办法:

  • 普通数组:nums[i] 单点修改;sum(nums[i]) 区间查询
  • 前缀和数组:sum(ps[i]) 单点修改;ps[i] 区间查询

比如说,普通数组和前缀和数组的实现分别是:

(1) 普通数组实现

1
2
3
4
5
6
7
8
9
10
11
12
// 单点修改
void update(int index, int val) {
nums[index] += val;
}
// 区间查询
int query(int l, int r) {
int sum = 0;
for (int i = l; i <= r; i++) {
sum += nums[i];
}
return sum;
}

(2) 前缀和数组实现

1
2
3
4
5
6
7
8
9
10
11
12
13
// 单点修改
void update(int index, int val) {
for (int i = index; i < n; i++) {
ps[i] += val;
}
}
// 区间查询
int query(int l, int r) {
if (l == 0) {
return ps[r];
}
return ps[r] - ps[l - 1];
}

两种方式都能实现,然后它们的时间复杂度分别是:

时间复杂度 单点修改 区间查询
普通数组 O(1) O(n)
前缀和数组 O(n) O(1)

不管哪种实现方式,总有一种操作的时间复杂度为 O(n)

O(n) 还是比较慢的,还有其他办法可以优化吗?

2.2 问题分析

分析前,先看一下普通数组和前缀和数组的区间情况。

普通数组的每个元素都可以看成一个区间:

1
2
3
4
5
6
7
  ___        ___        ___        ___
|1,1| |2,2| |3,3| |4,4| -- 区间
------------------------------------------------------------------
___ ___ ___ ___
| 3 | | 2 | | 1 | | 3 | -- 原始数组

1 2 3 4 -- 区间索引

这种情况下执行区间查询,需要将多个区间合并起来。

所以普通数组区间查询时间复杂度为 O(n) 的原因是:

  • 区间范围太小,区间较多,合并区间比较慢

前缀和数组的区间是连续元素区间:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
  ____________________________________
|1, 4| -- 区间
_________________________
|1, 3| -- 区间
______________
|1, 2| -- 区间
___
|1,1| -- 区间
------------------------------------------------------------------
___ ___ ___ ___
| 3 | | 5 | | 6 | | 9 | -- 前缀和数组
___ ___ ___ ___
| 3 | | 2 | | 1 | | 3 | -- 原始数组

1 2 3 4 -- 区间索引

这种情况下执行单点修改,多个区间会受到影响,需要同时修改。

所以,前缀和数组单点修改时间复杂度为 O(n) 的原因是:

  • 单个元素影响了太多区间,导致区间都要修改

要是能能解决这 2 个问题,理论上就能优化掉 O(n) 的复杂度。

2.3 区间划分

区间太离散,或者区间太连续,都会出现 O(n) 的时间复杂度。

那么是否可以局部建立区间?不至于太散,也不会很连续。

比如说相邻区间两两合并,层层叠加起来,最终形成一棵二叉树:

1
2
3
4
5
6
7
8
9
10
11
  _______________________________________
| 9 | -- 区间(0,4]
_______________ _______________
| 5 | | 4 | -- 区间(0,2]和(2,4]
___ ___ ___ ___
| 3 | | 2 | | 1 | | 3 | -- 区间(i-1,i]
------------------------------------------------------------------
___ ___ ___ ___
| 3 | | 2 | | 1 | | 3 | -- 原始数组

1 2 3 4 -- 区间索引

这样的话,单点修改和区间查询的时间复杂度如何了?

(1) 单点修改

比如修改元素 2,改成 1,那么更新的区间(二叉树从下往上)包括:

  1. (1,2]2 -> 1
  2. (0,2]5 -> 4
  3. (0,4]9 -> 8

更新的区间数量就是树的高度 h,所以时间复杂度是 O(logn)

(2) 区间查询

比如查询区间 [2,3],那么它查询过程(二叉树从上往下)包括:

  1. (0,4]
  2. (0,2](2,4]
  3. (1,2](2,3]

查询的区间数量也是树的高度,时间复杂度也是 O(logn)

利用这种二叉树的区间划分结构,确实可以做到:

  • 单点修改和区间查询的时间复杂度都是 O(logn)

不过这实际上是利用空间换时间得到的,空间复杂度变成了 O(n)

2.4 结构优化

为了达到 O(logn) 的时间复杂度,需要耗费 O(n) 空间建一棵二叉树,有点麻烦。

还有没有其他办法可以优化?尝试一下二叉树剪枝?

  • 可以采用前缀和差的思路,来剪掉二叉树中一些不必要的节点

比如,区间 [2,3] 实际上是区间 (0,3] - (0,1] 得到的。

再看上面的二叉树结构:

1
2
3
4
5
6
7
8
9
10
11
  _______________________________________
| 9 | -- 区间(0,4]
_______________ _______________
| 5 | | 4 | -- 区间(0,2]和(2,4]
___ ___ ___ ___
| 3 | | 2 | | 1 | | 3 | -- 区间(i-1,i]
------------------------------------------------------------------
___ ___ ___ ___
| 3 | | 2 | | 1 | | 3 | -- 原始数组

1 2 3 4 -- 区间索引

考虑利用前缀和差的推断方式,可以有:

  • 区间 (1,2] 可以由区间 (0,2] - (0,1] 得到
  • 区间 (3,4] 可以由区间 (2,4] - (2,3] 得到

那么区间 (1,2] 和 区间 (3,4] 这 2 个节点其实可以去掉:

1
2
3
4
5
6
7
8
9
10
11
  _______________________________________
| 9 | -- 区间(0,4]
_______________ _______________
| 5 | | 4 | -- 区间(0,2]和(2,4]
___ ___
| 3 | | 1 | -- 区间(i-1,i]
------------------------------------------------------------------
___ ___ ___ ___
| 3 | | 2 | | 1 | | 3 | -- 原始数组

1 2 3 4 -- 区间索引

同理,(2,4] 可以由区间 (0,4] - (0,2] 得到:

1
2
3
4
5
6
7
8
9
10
11
  _______________________________________
| 9 | -- 区间(0,4]
_______________
| 5 | -- 区间(0,2]
___ ___
| 3 | | 1 | -- 区间(i-1,i]
------------------------------------------------------------------
___ ___ ___ ___
| 3 | | 2 | | 1 | | 3 | -- 原始数组

1 2 3 4 -- 区间索引

删除掉一些不必要的区间节点后,二叉树的节点数量就少了很多。

此时再调整一下图形,改成这样子(5 和 9 的位置移动了一下):

1
2
3
4
5
6
7
8
9
10
11
  _______________________________________
| 9 | -- 区间(0,4]
_______________
| 5 | -- 区间(0,2]
___ ___
| 3 | | 1 | -- 区间(i-1,i]
------------------------------------------------------------------
___ ___ ___ ___
| 3 | | 2 | | 1 | | 3 | -- 原始数组

1 2 3 4 -- 区间索引

再把区间节点值都收集起来,发现剩余节点,刚好能凑成一个大小为 n 的数组:

1
2
3
4
5
6
7
8
9
10
11
12
13
  _______________________________________
| 9 | -- 区间(0,4]
_______________
| 5 | -- 区间(0,2]
___ ___
| 3 | | 1 | -- 区间(i-1,i]
------------------------------------------------------------------
___ ___ ___ ___
| 3 | | 5 | | 1 | | 9 | -- 树状数组
___ ___ ___ ___
| 3 | | 2 | | 1 | | 3 | -- 原始数组

1 2 3 4 -- 区间索引

这个得到的数组 [3, 5, 1, 9] 就是树状数组。

  • 树状数组虽然是一个数组形式,但实际上表示的是一棵经过剪枝的二叉树

通过前缀和差剪枝的方式,将二叉树结构变成了数组,结构简单多了。

三、如何实现树状数组?

3.1 不依赖于数值

在上面的构建二叉树中,并没有用到数组数据,因此:

  • 树状结构不依赖于具体的数据,只和数据量有关

也就是说,不管数组内的数据如何,大小一样的数组建立起来的树状结构一定是一样的。

比如说,数组大小为 9 的树状结构肯定是这样的:

1
2
3
4
5
6
7
8
9
10
11
12
13
 _______________________________
| 8 |
_______________
| 4 |
_______ _______
| 2 | | 6 |
___ ___ ___ ___ ___
| 1 | | 3 | | 5 | | 7 | | 9 | -- 树状结构

___ ___ ___ ___ ___ ___ ___ ___ ___
| | | | | | | | | | -- 原始数组

1 2 3 4 5 6 7 8 9 -- 索引下标

注意,这里面树状结构中的数字分别是树状数组中的下标,并不是值。

所以,现在只要知道树状数组中,每个下标的对应的是二叉树中的什么节点即可。

3.2 从左往右两两向上合并

仔细看树状结构,就会发现:

  • 树状结构的节点是从左往右,相邻节点两两向上合并

比如数组大小为 9,节点从左到右两两向上合并的效果就是:

1
2
3
4
              8       -- 3层
4 8 -- 2层
2 4 6 8 -- 1层
1 2 3 4 5 6 7 8 9 -- 0层

其中,

  • 节点 1 和 2 向上合并,得到父节点 2
  • 节点 3 和 4 向上合并,得到父节点 4
  • 节点 2 和 4 向上合并,得到父节点 4

树状结构就是这样从左往右,节点两两向上合并建起来的。

3.3 单点修改

树状结构的单点修改,需要从下往上更新树节点,所以要知道子节点对应的父节点。

为了方便分析父节点的位置,首先给出一个大小为 18 的树状数组:

1
2
3
4
5
                                    16           -- 4层
8 16 -- 3层
4 8 12 16 -- 2层
2 4 6 8 10 12 14 16 18 -- 1层
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 -- 0层

其中,里面的数字表示的是下标(注意不是值):

  • 2 是 1 父节点
  • 4 是 2、3 的父节点
  • 8 是 4、6、7 的父节点
  • 16 是 8、12、14、15 的父节点

观察每一层的节点,可以看到:

  • 上一层都是从下一层节点中,按 2 取 1 的方式获取

因此,可以知道每层的节点规律:

  • 0层:起始下标 2^0,间隔 2^0
  • 1层:起始下标 2^1,间隔 2^1
  • 2层:起始下标 2^2,间隔 2^2
  • 3层:起始下标 2^3,间隔 2^3
  • k层:起始下标 2^k,间隔 2^k

因为是每隔 1 个节点作为上一层节点,所以可以总结出:

  • 父节点下标 = 等于当前节点下标 + 节点所在层级的间隔

所以找父节点的下标可以这样做:

  • 找到当前节点所在层级 k
  • 父节点下标 = 当前节点的下标 + 2^k

比如说:

  • 5 所在层级是 0,父节点下标 = 5 + 2^0 = 6
  • 12 所在层级是 2,父节点下标 = 12 + 2^2 = 16

不过这里面的当前节点所在层级如何计算?

  • 节点层级,等于当前索引值的因数中 2 的数量

比如说:

  • 5 = 5,2 的数量有 0 个,所以层级是 0
  • 12 = 2*2*3,2 的数量有 2 个,所以是层级 2

这样就能算出父节点的下标了,就可以不断往上更新父节点了。

所以单点修改的代码最终是这样的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// 单点修改
void update(int index, int val) {
if (index < 0 || index >= n) {
return;
}

// 更新当前节点
tree[index] += val;
// 父节点下标
int p = nextP(index);
// 往上更新父节点
update(p, val);
}

// 计算父节点的下标
int nextP(int index) {
// 当前层级
int k = level(index);
// 层级间隔 2^k
int g = gap(k);
// 父节点下标
int p = index + g;
return p;
}

// 计算当前层级
int level(int index) {
int k = 0;
while (index % 2 == 0) {
index >>= 1;
k++;
}
return l;
}

// 计算层级间隔
int gap(int k) {
int g = 1;
while (k > 0) {
gap <<= 1;
k--;
}
return g;
}

树状结构的节点取的很有规律,所以很容易算出父节点的位置。

3.4 区间查询

首先说明,区间 [l, r] 实际上可以通过前缀和差来计算:

  • [l, r] = (0, r] - (0, l - 1]

所以只要计算出 (0, r](0, l - 1] 的值,就能知道任意区间和了。

还是大小 18 的树状数组为例:

1
2
3
4
5
                                    16           -- 4层
8 16 -- 3层
4 8 12 16 -- 2层
2 4 6 8 10 12 14 16 18 -- 1层
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 -- 0层

在树状结构定义中,可以知道:

  • 高层级节点 = 它之前的所有低层级节点之和

比如说:

  • 8 表示的是 (0, 8] 的区间和
  • 12 表示的是 (8, 12] 的区间和,因为 8 层级比 12 大

想要计算某个节点的区间和,可以通过前面的高层节点快速得到结果。

比如说,要寻找 (0, 14] 的区间和,按照层级可以分为几部分:

  • (0, 8]
  • (8, 12]
  • (12, 14]

这几个区间加起来,实际上就等于 (0,14] 的区间和了。

  • 只要不断往前找高层级节点,累加起来,就能得到区间和

怎么往前找高层级节点呢?借鉴单点修改的过程:

  • 找到当前节点所在层级 k
  • 前面的高层节点下标 = 当前节点的下标 - 2^k

所以区间查询的代码最终是这样的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// 区间查询
int query(int l, int r) {
return query(r) - query(l - 1);
}

// 区间和 (0, index]
void query(int index) {
if (index < 0 || index >= n) {
return 0;
}

int sum = 0;
// 当前节点值
sum += tree[index];
tree[index] += val;
// 前一个高层节点
int p = prevP(index);
sum += query(p);
return sum;
}

// 计算前一个高层节点的下标
int prevP(int index) {
// 当前层级
int k = level(index);
// 层级间隔 2^k
int g = gap(k);
// 前一个高层节点下标
int p = index - g;
return p;
}

3.5 lowbit 的妙用

在前面的处理中,单点更新和查询区间的过程差不多:

  • 单点更新,从下往上,从左至右,寻找高层级节点
  • 区间查询,从下往上,从右至左,寻找高层级节点

单点修改和区间查询都是通过当前层级和层级间隔来找的下一个节点。

然而层级和相邻的高层节点计算有点麻烦,有没有办法优化?

(1) 层级计算

1
2
3
4
5
6
7
8
9
// 计算当前层级
int level(int index) {
int k = 0;
while (index % 2 == 0) {
index >>= 1;
k++;
}
return l;
}

计算过程就是不断去判断是否能对 2 整除,实际上:

  • 计算层级的过程,等于在找 index 的二进制表示中末尾的 0 的数量

比如说,12 的二进制表示是 1100,对它不断整除 2:

1
2
3
1100
110
11

所以,不需要整除 2,而是:

  • 只要知道二进制值末尾的 0 数量,就能直接知道层级了

怎么计算二进制值末尾的 0 数量?

可以通过位运算来处理:

1
2
3
4
5
6
7
8
9
10
// 计算当前层级
int level(int index) {
int l = 0;
int m = index & (-index);
while (m > 1) {
m >>= 1;
l++;
}
return l;
}

其中,index & (-index) 可以保留二进制值得最低位 1。

比如 12 的就是 1100 & 0100 = 0100

(2) 层级间隔

1
2
3
4
5
6
7
8
9
// 计算层级间隔
int gap(int k) {
int g = 1;
while (k > 0) {
gap <<= 1;
k--;
}
return g;
}

实际上,计算层级间隔的过程就是计算层级的反过程:

  • 计算层级,一直整除 2
  • 计算间隔,一直乘以 2

如果把 level(index)gap(k) 合并起来 gap(index)

就会发现,两者的位移动会互相抵消掉,最终:

1
2
3
4
// 计算层级间隔
int gap(int index) {
return index & (-index);
}

层级间隔,就等于下标的 index & (-index) 位运算结果。

所以最终获取相邻高层级节点的代码就变成这样了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// 计算后一个高层节点的下标
int nextP(int index) {
// 层级间隔 2^k
int g = gap(index);
// 父节点下标
return index + g;
}


// 计算前一个高层节点的下标
int prevP(int index) {
// 层级间隔 2^k
int g = gap(index);
// 前一个高层节点下标
return index - g;
}

// 计算层级间隔
int gap(int index) {
return index & (-index);
}

在树状数组中,gap(index) 方法一般用 lowBit(index) 表示。

因为 lowBit(index) 刚好取的是 index 的最低位 1 的结果。

总结

  • 树状数组的存储形式是数组,但是实际表示的结构是一棵树
  • 树状数组是为了优化 O(n) 复杂度,提出的一种利用空间换时间的树结构
  • 树状数组主要用于 2 方面操作:
    • 单点修改:时间复杂度 O(logn)
    • 区间查询:时间复杂度 O(logn)
  • 树状数组是利用前缀和差的思想来对树进行剪枝,剪枝后才能保存到数组中
  • 查找前后高层节点时,采用的是一种 lowBit 方法

参考

https://zhuanlan.zhihu.com/p/93795692

https://blog.csdn.net/qq_40941722/article/details/104406126

https://www.jianshu.com/p/7cd5ad2f449a

https://leetcode.cn/circle/discuss/qGREiN/

https://oi-wiki.org/ds/fenwick/

附录

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/**
* 树状数组(二叉索引树)
*
* @author weijiaduo
* @since 2022/10/1
*/
public class BinaryIndexTree {

/**
* 树状数组
*/
int[] tree;
/**
* 数组大小
*/
int n;

public BinaryIndexTree(int n) {
this.n = n;
tree = new int[this.n];
}

/**
* 单点修改
*
* @param index 指定下标
* @param val 增量修改值
*/
public void update(int index, int val) {
while (index < n) {
tree[index] += val;
index = nextP(index);
}
}

/**
* 区间查询
*
* @param l [l, r]
* @param r [l, r]
* @return [l, r] 的区间和
*/
public int query(int l, int r) {
return query(r) - query(l - 1);
}

/**
* 区间查询
*
* @param index (0, index]
* @return (0, 1] 的区间和
*/
private int query(int index) {
int sum = 0;
while (index > 0) {
sum += tree[index];
index = prevP(index);
}
return sum;
}

/**
* 计算后一个高层节点的下标
*
* @param index 当前下标
* @return 后一个高层节点的下标
*/
int nextP(int index) {
// 层级间隔 2^k
int g = lowBit(index);
// 父节点下标
return index + g;
}

/**
* 计算前一个高层节点的下标
*
* @param index 当前下标
* @return 前一个高层节点的下标
*/
int prevP(int index) {
// 层级间隔 2^k
int g = lowBit(index);
// 前一个高层节点下标
return index - g;
}

/**
* index 只剩余最低位 1 的值
*
* @param index 值
* @return 只剩余最低位 1 的值
*/
private int lowBit(int index) {
return index & (-index);
}

}
作者

jiaduo

发布于

2022-10-01

更新于

2023-04-03

许可协议