代码 / OI

线段树

24 分钟阅读
代码OI算法线段树

线段树是一种树状数据结构,它可以区间加减,区间乘除等一系列操作,用于处理那种可以合并状态的数据,在使用其3倍左右的空间的代价下使得其修改、查询、求区间和等等操作变得更加快捷。但与此同时,我们无法利用它处理类似于区间最长01序列此类问题,而且线段树代码冗长,其实很容易写错(也可能是因为我太菜了)。
我们将一组数据进行如下处理,每相邻的两个数据有一个父亲节点来记录其总的状态,然后再记录其相邻父节点的总的状态,以此类推,最终得到一个树状结构,我们从上到下依次编号1-n,这棵树满足父节点*2=左节点,父节点*2+1=右节点,设每个父节点代表l-r区间的状态,则左区间为l,(r+l)/2 , 右区间为(r+l)/2+1,r。根据此性质我们可以对他们进行维护。
每当我们访问一个节点,我们保证此节点的值一定正确,并尽可能少的改变其子孙节点的值,让时间消耗尽可能的小,同时把lazy标记也就是本来应该加的数传递到下一节点。
第一颗树实现了区间加与查询,第二颗树实现了区间乘法,加法,判断其中的先后顺序,其实也大同小异。
第三颗树用于实现历史最大值这种操作,然而由于本人电脑跑不动500mb的程序,再加之修改起来有点麻烦,就写个大致正确的程序摆在这了。
如果要继续完善,那么需要记录次大值并对于spread函数进行修改,就这样吧,后面再来补。

#include<cstdio>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<queue>
#include<vector>

using namespace std ; 
typedef long long LL ; 
const int maxn = 500005 ; 
struct L {
    LL val , add ; 
} t[maxn] ; 
LL n , m , a[maxn] ; 

void build ( int p , int l , int r ) {
    if ( l == r ) { t[p].val = a[l] ; return ; } 
    int mid = ( l + r ) >> 1 ; 
    build ( p << 1 , l , mid ) ; 
    build ( p << 1 | 1 , mid + 1 , r ) ; 
    t[p].val = t[p<<1].val + t[p<<1|1].val ; 
    return ; 
}

void spread ( int p , int l , int r ) {
    if ( t[p].add ) {
        int mid = ( l + r ) >> 1 ; 
        t[p<<1].val = ( t[p<<1].val + t[p].add * ( mid - l + 1 ) ) , t[p<<1].add += t[p].add ; 
        t[p<<1|1].val = ( t[p<<1|1].val + t[p].add * ( r - mid ) ) , t[p<<1|1].add += t[p].add ; 
        t[p].add = 0 ; 
    }
}

void change ( int p , int l , int r , int x , int y , int z ) {
   if ( x <= l && r <= y ) { t[p].val += z * ( r - l + 1 ) ; t[p].add += z ; return ; }
   int mid = ( l + r ) >> 1 ; spread ( p , l , r ) ; 
   if ( x <= mid ) change ( p << 1 , l , mid , x , y , z ) ; 
   if ( y > mid ) change ( p << 1 | 1 , mid + 1 , r , x , y , z ) ; 
   t[p].val = t[p<<1].val + t[p<<1|1].val ; 
}

LL ask ( int p , int l , int r , int x , int y ) { 
    if ( x <= l && r <= y ) { return t[p].val ; } 
    int mid = ( l + r ) >> 1 ; spread ( p , l , r ) ; LL ans = 0 ; 
    if ( x <= mid ) ans = ans + ask ( p << 1 , l , mid , x , y ) ; 
    if ( y > mid ) ans = ans + ask ( p << 1 | 1 , mid + 1 , r , x , y ) ; 
    return ans ; 
}
int main ( ) {
    scanf ( "%lld%lld" , & n , & m ) ; 
    for ( int i = 1 ; i <= n ; i ++ ) scanf ( "%lld" , & a[i] ) ; 
    build ( 1 , 1 , n ) ; 
    for ( int i = 1 ; i <= m ; i ++ ) {
        int com , x , y ; scanf ( "%d%d%d" , & com , & x , & y ) ; 
        if ( com == 1 ) {
            LL k ; scanf ( "%lld" , & k ) ; 
            change ( 1 , 1 , n , x , y , k ) ; 
        }
        else printf ( "%lld\n" , ask ( 1 , 1 , n , x , y ) ) ; 
    }
    return 0 ; 
}
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>

using namespace std ;
const int N = 100003 ;
typedef long long ll ;
inline int read ( ) {
    char ch = getchar ( ) ; int res = 0 ;
    while ( ch > '9' || ch < '0' ) ch = getchar ( ) ;
    while ( ch >= '0' && ch <= '9' ) res = res * 10 + ch - 48 , ch = getchar ( ) ;
    return res ;
}

struct L {
    ll mul , val , add ;
} t[N<<2] ;
int n , m , a[N] , mod ;

void build ( int p , int l , int r ) {
    t[p].mul = 1 ;
    if ( l == r ) { t[p].val = a[l] ; return ; }
    int mid = ( l + r ) >> 1 ;
    build ( p << 1 , l , mid ) ;
    build ( p << 1 | 1 , mid + 1 , r ) ;
    t[p].val = ( t[p<<1|1].val + t[p<<1].val ) % mod ;
}

void spread ( int p , int l , int r ) {
    int mid = ( l + r ) >> 1 ; 
    t[p<<1].val = ( t[p<<1].val * t[p].mul + t[p].add * ( mid - l + 1 ) ) % mod ;
    t[p<<1|1].val = ( t[p<<1|1].val * t[p].mul + t[p].add * ( r - mid ) ) % mod ;
    t[p<<1].mul = ( t[p<<1].mul * t[p].mul ) % mod ;
    t[p<<1|1].mul = ( t[p<<1|1].mul * t[p].mul ) % mod ;
    t[p<<1].add = ( t[p<<1].add * t[p].mul + t[p].add ) % mod ;
    t[p<<1|1].add = ( t[p<<1|1].add * t[p].mul + t[p].add ) % mod ;
    t[p].mul = 1 ; t[p].add = 0 ;
}

void change2 ( int p , int l , int r , int x , int y , ll z ) {
    if ( x <= l && r <= y ) {
        t[p].mul = ( t[p].mul * z ) % mod ;
        t[p].add = ( t[p].add * z ) % mod ;
        t[p].val = ( t[p].val * z ) % mod ;
        return ;
    }
    spread ( p , l , r ) ;
    int mid = ( l + r ) >> 1 ;
    if ( x <= mid ) change2 ( p << 1 , l , mid , x , y , z ) ;
    if ( y > mid ) change2 ( p << 1 | 1 , mid + 1 , r , x , y , z ) ;
    t[p].val = ( t[p<<1|1].val + t[p<<1].val ) % mod ;
}

void change1 ( int p , int l , int r , int x , int y , ll z ) {
    if ( x <= l && r <= y ) {
        t[p].add = ( t[p].add + z ) % mod ;
        t[p].val = ( t[p].val + ( r - l + 1 ) * z ) % mod ;
        return ;
    }
    spread ( p , l , r ) ;
    int mid = ( l + r ) >> 1 ;
    if ( x <= mid ) change1 ( p << 1 , l , mid , x , y , z ) ;
    if ( y > mid ) change1 ( p << 1 | 1, mid + 1 , r , x , y , z ) ;
    t[p].val = ( t[p<<1|1].val + t[p<<1].val ) % mod ;
}

ll aska ( int p , int l , int r , int x , int y ) {
    if ( x <= l && r <= y ) return t[p].val ;
        spread ( p , l , r ) ;
        int mid = ( l + r ) >> 1 ;
        ll ans = 0 ;
        if ( x <= mid ) ans += aska ( p << 1 , l , mid , x , y ) ;
        if ( mid < y ) ans += aska ( p << 1 | 1 , mid + 1 , r , x , y ) ;
        ans %= mod ;
        t[p].val = ( t[p<<1|1].val + t[p<<1].val ) % mod ;
        return ans ;
}

int main ( ) {
    n = read ( ) ; m = read ( ) ; mod = read ( ) ;
    for ( int i = 1 ; i <= n ; i ++ ) scanf ( "%d" , & a[i] ) ;
    build ( 1 , 1 , n ) ;
    while ( m -- ) {
        int command = read ( ) ; ll x = read ( ) , y = read ( ) ;
        if ( command == 1 ) change2 ( 1 , 1 , n , x , y , read ( ) ) ;
        if ( command == 2 ) change1 ( 1 , 1 , n , x , y , read ( ) ) ;
        if ( command == 3 ) cout << aska ( 1 , 1 , n , x , y ) << endl ;
    }
}

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>

using namespace std ; 
const int maxn = 2000006 ; 
typedef long long LL ; 
struct L {
    LL val , maxa , maxb , add , mina ;
} t[maxn] ;
LL a[maxn>>2] ; 

void build ( int p , int l , int r ) {
    t[p].mina = 0x7fffffff ; 
    if ( l == r ) { t[p].maxa = t[p].maxb = t[p].val = a[l] ; return ; } 
    int mid = ( l + r ) >> 1 ; 
    build ( p << 1 , l , mid ) ; 
    build ( p << 1 | 1 , mid + 1 , r ) ; 
    t[p].maxa = max ( t[p<<1].maxa , t[p<<1|1].maxa ) ; 
    t[p].maxb = max ( t[p<<1].maxb , t[p<<1|1].maxb ) ; 
    t[p].val = t[p<<1].val + t[p<<1|1].val ; 
}

void spread ( int p , int l , int r ) {
    int mid = ( l + r ) >> 1 ;
    //这里有一定的问题,需要判断变为最小的影响,需要分类讨论,其余的没有问题(大概)
    t[p<<1].val = min ( t[p].mina , ( t[p<<1].val + t[p].add * ( mid - l + 1 ) ) ) ;
    t[p<<1|1].val = min ( t[p].mina , ( t[p<<1|1].val + t[p].add * ( r - mid ) ) ) ;
    //持续到这里
    t[p<<1].maxa = max ( t[p<<1].maxa + t[p].add , t[p].mina == 0x7fffffff ? 0 : t[p].mina ) ;
    t[p<<1|1].maxa = max ( t[p<<1|1].maxa + t[p].add , t[p].mina == 0x7fffffff ? 0 : t[p].mina ) ; 
    t[p<<1].maxb = max( t[p<<1].maxb , t[p<<1].maxa ) ; 
    t[p<<1|1].maxb = max ( t[p<<1|1].maxb , t[p<<1|1].maxa ) ; 
    t[p<<1].add += t[p].add ; t[p<<1|1].add += t[p].add ; 
    t[p<<1].mina = min ( t[p<<1].mina , t[p].mina ) ; t[p<<1|1].mina = min ( t[p<<1|1].mina , t[p].mina ) ; 
    t[p].mina = 0x7fffffff ; t[p].add = 0 ; 
}

void change1 ( int p , int l , int r , int x , int y , LL z ) {
    if ( x <= l && r <= y ) { 
        t[p].maxa += z ; 
        t[p].maxb = max ( t[p].maxb , t[p].maxa ) ; 
        t[p].val = ( t[p].val + z * ( r - l + 1 ) ) ; 
        t[p].add += z ; 
    }
    int mid = ( l + r ) >> 1 ; spread ( p , l , r ) ; 
    if ( x <= mid ) change1 ( p << 1 , l , mid , x , y , z ) ; 
    if ( y > mid ) change1 ( p << 1 | 1 , mid + 1 , r , x , y , z ) ; 
    t[p].val = ( t[p<<1].val + t[p<<1|1].val ) ; 
    t[p].maxa = max ( t[p<<1].maxa , t[p<<1|1].maxa ) ;
    t[p].maxb = max ( t[p<<1].maxb , t[p<<1|1].maxb ) ; 
}

void change2 ( int p , int l , int r , int x , int y , LL z ) {
    if ( x <= l && r <= y ) {
        t[p].val = min ( t[p].val , z * ( r - l + 1 ) ) ; 
        t[p].maxa = min ( z , t[p].maxa ) ;
        t[p].mina = z ; 
        t[p].maxb = max ( t[p].maxb , t[p].maxa ) ; 
    }
    int mid = ( l + r ) >> 1 ; spread ( p , l , r ) ; 
    if ( x <= mid ) change2 ( p << 1 , l , mid , x , y , z ) ; 
    if ( y > mid ) change2 ( p << 1 | 1 , mid + 1 , r , x , y , z ) ; 
    t[p].val = ( t[p<<1].val + t[p<<1|1].val ) ; 
    t[p].maxa = max ( t[p<<1].maxa , t[p<<1|1].maxa ) ; 
    t[p].maxb = max ( t[p<<1].maxb , t[p<<1|1].maxb ) ; 
}

LL ask1 ( int p , int l , int r , int x , int y ) {
    if ( x <= l && r <= y ) { return t[p].val ; }  
    int mid = ( l + r ) >> 1 ; LL ans = 0 ; spread ( p , l , r ) ; 
    if ( x <= mid ) ans += ask1 ( p << 1 , l , mid , x , y ) ; 
    if ( y > mid ) ans += ask1 ( p << 1 | 1 , mid + 1 , r , x , y ) ; 
    return ans ; 
}

LL ask2 ( int p , int l , int r , int x , int y ) {
    if ( x <= l && r <= y ) { return t[p].maxa ; }
    int mid = ( l + r ) >> 1 ; LL ans = 0x7fffffff ; spread ( p , l , r ) ; 
    if ( x <= mid ) ans = min ( ans , ask2 ( p << 1 , l , mid , x , y ) ) ;
    if ( y > mid ) ans = min ( ans , ask2 ( p << 1 | 1 , mid + 1 , r , x , y ) ) ; 
    return ans ; 
}

LL ask3 ( int p , int l , int r , int x , int y ) {
    if ( x <= l && r <= y ) { return t[p].maxb ; } 
    int mid = ( l + r ) >> 1 ; LL ans = 0x7fffffff ; spread ( p , l , r ) ; 
    if ( x <= mid ) ans = min ( ans , ask3 ( p << 1 , l , mid , x , y ) ) ; 
    if ( y > mid ) ans = min ( ans , ask3 ( p << 1 | 1 , mid + 1 , r , x , y ) ) ; 
    return ans ; 
}

int main ( ) {
    int n , m ; scanf ( "%d%d" , & n , & m ) ; 
    for ( int i = 1 ; i <= n ; i ++ ) scanf ( "%lld" , & a[i] ) ; 
    build ( 1 , 1 , n ) ; 
    while ( m -- ) {
        int op ; scanf ( "%d" , & op ) ; 
        if ( op == 1 ) {
            int l , r ; LL k ; scanf ( "%d%d%lld" , & l , & r , & k ) ; 
            change1 ( 1 , 1 , n , l , r , k ) ; 
        }
        else if ( op == 2 ) {
            int l , r ; LL k ; scanf ( "%d%d%lld" , & l , & r , & k ) ; 
            change2 ( 1 , 1 , n , l , r , k ) ; 
        }
        else if ( op == 3 ) {
            int l , r ; scanf ( "%d%d" , & l , & r ) ; 
            printf ( "%lld" , ask1 ( 1 , 1 , n , l , r ) ) ; 
        }
        else if ( op == 4 ) {
            int l , r ; scanf ( "%d%d" , & l , & r ) ;    
            printf ( "%lld" , ask2 ( 1 , 1 , n , l , r ) ) ;
        }
        else {
            int l , r ; scanf ( "%d%d" , & l , & r ) ;    
            printf ( "%lld" , ask3 ( 1 , 1 , n , l , r ) ) ;
        }
    }
    return 0 ; 
}