树状数组基础、二维树状数组
前言
树状数组真的是一种非常非常巧妙的数据结构,代码短小精炼,而且非常灵活,有树状数组套主席树、树上莫队等非常神奇的操作!
- 树状数组:
http://www.hawstein.com/posts/binary-indexed-trees.html
https://www.cnblogs.com/hsd-/p/6139376.html
http://www.cppblog.com/menjitianya/archive/2015/11/02/212171.html
Stars - POJ - 2352
题意
在二维平面上我们认为一颗星的等级为:不比它高且不在它右边的星的数量,给定坐标(按y坐标升序给出),求每个等级的星的数量
思路
树状数组模板题
因为y是按升序输入的,所以我们只需要考虑有多少点的x坐标小于等于当前坐标即可
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
const int N = 4e4 + 15;
const int inf = 0x3f3f3f3f;
int tree[N], cnt[N];
inline void init() {
memset(tree, 0, sizeof(tree));
memset(cnt, 0, sizeof(cnt));
}
//lowbit: 获取最末位的1,表示为0000...1...00..的形式
inline int lowbit(int idx){ return (idx & -idx); }
//根据性质,不断减去末尾的1求和
int getSum(int idx){
int sum;
for(sum = 0; idx > 0; idx -= lowbit(idx)){
sum += tree[idx];
}
return sum;
}
//根据定义,不断加上末尾的1来进位以更新
void update(int idx, int val){
while(idx < N){
tree[idx] += val;
idx += (idx & -idx);
}
}
int main(){
int n;
while(~scanf("%d", &n)){
init();
for(int i = 1; i <= n; i++){
int a, b;
scanf("%d%d", &a, &b);
a++;
cnt[getSum(a) + 1]++;
update(a, 1);
}
for(int i = 1; i <= n; i++){
printf("%d\n", cnt[i] );
}
}
}
Cows - POJ - 2481
题意
对于每一个区间,求有多少个区间能完全包含该区间(完全相同不算)
思路
排个序就是上一题了,注意去重
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
const int N = 3e5 + 15;
const int inf = 0x3f3f3f3f;
struct node{
int s, e, id;
bool operator < (const node& b) const{
if(e != b.e) return e > b.e;
else return s < b.s;
}
};
int ans[N];
int tree[N];
node a[N];
inline void init() { memset(tree, 0, sizeof(tree)); }
inline int lowbit(int idx){ return (idx & -idx); }
inline int getSum(int idx){
int sum = 0;
for(int x = idx; x > 0; x -= lowbit(x)){
sum += tree[x];
}
return sum;
}
inline void update(int idx, int val){
for(int x = idx; x < N; x += lowbit(x)){
tree[x] += val;
}
}
int main(){
int n;
while(scanf("%d", &n) && n){
init();
for(int i = 1; i <= n; i++){
a[i].id = i;
scanf("%d%d", &a[i].s, &a[i].e);
a[i].s++, a[i].e++;
}
sort(a + 1, a + n + 1);
int pre_l = -1, pre_r = -1, cnt = 0;
for(int i = 1; i <= n; i++){
if(pre_l == a[i].s && pre_r == a[i].e){
cnt++;
}else{
cnt = 0;
pre_l = a[i].s, pre_r = a[i].e;
}
ans[a[i].id] = getSum(a[i].s) - cnt;
update(a[i].s, 1);
}
for(int i = 1; i <= n; i++){
printf("%d", ans[i]);
if(i < n) putchar(' ');
}
puts("");
}
}
Japan - POJ - 3067
题意
两条线上各有N、M个点,现在给定q个操作,将编号为a和b的分别位于两条线上的点相连,问最终线与线间会产生多少个交点
思路
全部l、r降序排序,再用树状数组查询r前面有多少个点,即可知道插入这条线后会产生多少个交点
注意开long long,本题坑的地方
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long ll;
const int N = 1e6 + 15;
const int inf = 0x3f3f3f3f;
struct node{
int s, e;
bool operator < (const node& b) const{
if(s != b.s) return s > b.s;
else return e > b.e;
}
};
ll tree[N];
node a[N];
inline void init() { memset(tree, 0, sizeof(tree)); }
inline int lowbit(int idx){ return (idx & -idx); }
inline ll getSum(int idx){
ll sum = 0;
for(int x = idx; x > 0; x -= lowbit(x)){
sum += tree[x];
}
return sum;
}
inline void update(int idx, ll val){
for(int x = idx; x < N; x += lowbit(x)){
tree[x] += val;
}
}
int main(){
int t, csn = 1;
scanf("%d", &t);
while(t--){
init();
int n, m, k;
scanf("%d%d%d", &n, &m, &k);
for(int i = 1; i <= k; i++){
scanf("%d%d", &a[i].s, &a[i].e);
}
sort(a + 1, a + k + 1);
ll ans = 0;
for(int i = 1; i <= k; i++){
ans += getSum(a[i].e - 1);
update(a[i].e, 1);
}
printf("Test case %d: %lld\n", csn++, ans);
}
}
Mobile phones - POJ - 1195
题意
懒得翻译 = =
思路
二维树状数组模板题
(二维树状数组不就是树状数组套树状数组)
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
const int N = 1024 + 15;
const int inf = 0x3f3f3f3f;
int tree[N][N];
inline void init() { memset(tree[0], 0, sizeof(tree));}
inline int lowbit(int idx){ return (idx & -idx); }
inline int getSum(int idx, int idy){
int sum = 0;
for(int x = idx; x > 0; x -= lowbit(x)){
for(int y = idy; y > 0; y -= lowbit(y)){
sum += tree[x][y];
}
}
return sum;
}
inline void update(int idx, int idy, int val){
for(int x = idx; x < N; x += lowbit(x)){
for(int y = idy; y < N; y += lowbit(y)){
tree[x][y] += val;
}
}
}
int main(){
init();
int op, n;
while(~scanf("%d", &op)){
if(op == 0){
scanf("%d", &n);
}else if(op == 1){
int x, y, val;
scanf("%d%d%d", &x, &y, &val);
x++, y++;
update(x, y, val);
}else if(op == 2){
int l, b, r, t;
scanf("%d%d%d%d", &l, &b, &r, &t);
l++, b++, r++, t++;
printf("%d\n", getSum(r, t) - getSum(r, b - 1) - getSum(l - 1, t) + getSum(l - 1, b - 1));
}else if(op == 3){
init();
}
}
}