树状数组和线段树

树状数组

解决什么问题?

  • 给某个位置上的数,加上一个数。(单点修改)
  • 求某一个前缀和。(区间查询)

时间复杂度

  • O(long n)

步骤及实现

给定一个数组{1,5,7,11,4,3,2,6,11,21,15,17,14,13,20,35}。下面将通过树状数组的方式进行查询与修改。

  1. 把普通数组进行存储。

    存储的位置,从数组的下标”1”开始存,下标”0”不存储元素。

  2. 把普通数组转换成树状数组(重点)。

    看不懂没关系,先给大家讲解一下,单点修改和区间查询是如何方便的:

    • 树状数组的单点修改,对整体数组修改影响较小
    “前缀和”的单点修改对整体数组修改的影响

    假如我们对一个数组的下标3的元素进行加3操作,那么它对应的前缀和数组,下标3之后的所有元素都要进行加3操作,这样操作的时间开销是极大的。

    如果使用树状数组,对下标3的元素进行加3操作,树状数组的开销就小得多。

    仅对下标为“3”,”4”,”8 “,“16”的元素进行加3操作。

    1. 下面是最终转换的树状数组:

    1. 构建树状数组的代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
/****************************************************
先给出一个n,表示数组元素的个数,然后给出n个元素并构建出树状数组,并输出打印树状数组
输入:
16
1 5 7 11 4 3 2 6 11 21 15 17 14 13 20 35
输出:
1 6 7 24 4 7 2 39 11 32 15 64 14 27 20 185
*****************************************************/
#include<iostream>
using namespace std;
int n;
const int N=100005;
int a[N],tr[N];
int lowbit(int x)
{
return x & -x;
}
void add(int x,int v)
{
for(int i=x;i<=n;i+=lowbit(i))
{
tr[i]+=v;
}
}
int query(int x)
{
int res=0;
for(int i=x;i;i-=lowbit(i))
{
res+=tr[i];
}
return res;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)//输入
{
scanf("%d",&a[i]);
}

for(int i=1;i<=n;i++) //构建过程
{
add(i,a[i]); //给第i个元素,加上a[i]
}


for(int i=1;i<=n;i++)//输出树状数组
{
printf("%d ",tr[i]);
}

return 0;
}

  1. 在某个位置加上一个数,并且查询某个某个区间的前缀和。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
/****************************************************
先给出一个n,表示数组元素的个数,下面给出n个元素并构建出树状数组。
然后给出一个m,表示要操作的次数,下面每一行都表示一次操作,共m行
每次操作都包含三个数k,x,y。当k=1时,表示在某个位置x加上y,当k=2时,输出区间[x,y]的和
输入:
16
1 5 7 11 4 3 2 6 11 21 15 17 14 13 20 35
5
1 3 3
1 4 3
1 8 3
1 16 3
2 4 10
输出:
64

解释:
从下标4到下标10的和,其中下标4和下标10已经加3,故下标4到下标10的和为: (11+3) + 4 + 3 +2 + 6 + 11 + (21+3) = 64
*****************************************************/
#include<iostream>
using namespace std;
int n,m;
const int N=100005;
int a[N],tr[N];
int lowbit(int x)
{
return x & -x;
}
void add(int x,int v)
{
for(int i=x;i<=n;i+=lowbit(i))
{
tr[i]+=v;
}
}
int query(int x)
{
int res=0;
for(int i=x;i;i-=lowbit(i))
{
res+=tr[i];
}
return res;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)//输入
{
scanf("%d",&a[i]);
}

for(int i=1;i<=n;i++) //构建过程
{
add(i,a[i]); //给第i个元素,加上a[i]
}

scanf("%d",&m);
while(m--)
{
int k,x,y;
scanf("%d%d%d",&k,&x,&y);
if(k==1) //添加
{
add(x,y);
}else //查询
{
printf("%d\n",query(y)-query(x-1));
}
}

return 0;
}

线段树

解决什么问题?

  • 单点修改
  • 区间查询

时间复杂度

  • 修改:O(long n)
  • 查询:O(4long n)

步骤及实现

给定一个数组{1,2,3,4,5,6,7}。下面将通过线段树的方式进行查询与修改。

  1. 把普通数组进行存储。

存储的位置,从数组的下标”1”开始存,下标”0”不存储元素。

  1. 把普通数组转换成线段树(重点)。

  1. 线段树代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
/****************************************
先给出一个n,表示数组元素的个数,下面给出n个元素并构建出线段树。
然后给出一个m,表示要操作的次数,下面每一行都表示一次操作,共m行
每次操作都包含三个数k,x,y。当k=1时,表示在某个位置x加上y,当k=2时,输出区间[x,y]的和
输入:
7
1 2 3 4 5 6 7
3
1 3 3
1 4 3
2 4 7
输出:
25
解释: (4+3)+5+6+7=25
****************************************/
#include<iostream>
using namespace std;
int n,m;
const int N=100005;
int w[N];
struct Node //每个结点的数据类型
{
int l,r;
int sum;
}tr[N*4]; //线段树
void pushup(int u) //通过两个子结点,来计算父结点的值
{
tr[u].sum = tr[u<<1].sum+tr[u<<1|1].sum; // u<<1相当于2*u,u<<1|1相当于2*u+1
}

void build(int u,int l,int r) //构建线段树函数
{
if(l==r) // 叶结点
{
tr[u]={l,r,w[r]}; //初始化叶结点
}else
{
tr[u]={l,r};
int mid = l+r>>1;
build(u<<1,l,mid); //构建每个结点
build(u<<1|1,mid+1,r); //构建每个结点
pushup(u); //计算每个结点
}
}
int query(int u,int l,int r)
{
if(tr[u].l>=l &&tr[u].r<=r) //如果当前节点所表示的区间完全在查询区间 [l, r] 内,直接返回该节点的区间和 tr[u].sum。
{
return tr[u].sum;
}
int mid = tr[u].l+tr[u].r>>1;
int sum=0;
if(l<=mid)
{
sum=query(u<<1,l,r);//递归查询左子区间
}
if(r>mid)
{
sum+=query(u<<1|1,l,r);//递归查询右子区间
}
return sum;

}
void modify(int u,int x,int v)
{
if(tr[u].l==tr[u].r)
{
tr[u].sum+=v;// 叶子节点
}else
{
int mid = tr[u].l+tr[u].r>>1;
if(x<=mid)
{
modify(u << 1, x, v); // 修改左子树
}else
{
modify(u << 1 | 1, x, v);// 修改右子树
}
pushup(u);// 更新父结点
}

}

int main()
{
cin>>n;
for(int i=1;i<=n;i++)
{
cin>>w[i];
}
build(1,1,n);
cin>>m;
while(m--)
{
int k,x,y;
cin>>k>>x>>y;
if(k == 1)
{
modify(1,x,y);
}else
{
cout<<query(1,x,y)<<endl;
}

}
}