倍增 & ST表

倍增

我们在学习二进制的时候就认识到,任意一个数均可以用二进制来表示,这就是倍增思想的原理,对于区间或距离就可以采用此思想进行优化。

例题:P699 查找位置(倍增)

1.分析:

首先我们知道任意一个正整数都能表示为二进制,如

13=1101=123+122+021+12013 = 1101 = 1*2^3 + 1*2^2 + 0*2^1 + 1*2^0 假设我们要查找第一个>=k的答案位置是13,怎么利用上面的公式计算它呢?

首先起始位置是024=160,2^4 = 16,询问这个位置 a[16] >= k?

如果不是,表示答案在16后面,那么起始位置跳到16再往下循环 如果是,表示答案在16前面,那么继续询问 23=82^3=8 的位置 当然按照题目,a[16]肯定>=k,因此不跳,继续询问a[8] >= k?

显然a[8]<k(正确位置是13),则跳到8的位置,继续询问a[8+4] >= k?

现在a[8+4]<k,则跳到12的位置,继续询问a[12+2]>=k?

现在a[12+2]>=k,不跳,继续询问a[12+1]>=k?

现在a[12+1]>=k, 不跳,考虑到1已经是最小,因此答案就是12+1=13。

2、倍增算法

首先计算lgn,为最接近n的 2x2^x 位置 i从lgn往下循环: 判定 i 位置是否满足,是就跳过(表示在答案后面) 否则,跳到i位置i 循环结束,最后跳的位置就是答案的前面一个 伪代码:

int p=0;
for(int i=lgn; i>=0; i--) {
  int delta = 1<<i;
  if(a[p+delta] >= k) continue;
  p += delta;
}
return p+1;

ac代码:

#include<iostream>
#include<cstdio>
#include<cmath>
using namespace std;

const int N = 5e5 + 5;
int a[N];
int n, m, q;
void handle(int k)
{
  int p = 0;
  int t = 0;
  while(1 << t <= n)
  {
    t++;
  }
  for(int i = t - 1; i >= 0; i--)
  {
    int delta = 1 << i;
    if(p + delta > n || a[p + delta] >= k) continue;
    p += delta;
  }
  printf("%d ", p + 1 );
}
void initInput(void)
{
  scanf("%d%d", &n, &m);
  for(int i = 1; i <= n; i++)
  {
    scanf("%d", &(a[i]));
  }
  int k;
 	 for(int i = 1; i <= m; i++)
   {
     scanf("%d", &k);
     handle(k);
   }
}

int main(void)
{
  initInput();
  
  return 0;
}

例题:Balanced Lineup G

题目描述

每天,农夫 John 的 n(1n5×104)n(1\le n\le 5\times 10^4) 头牛总是按同一序列排队。

有一天, John 决定让一些牛们玩一场飞盘比赛。他准备找一群在队列中位置连续的牛来进行比赛。但是为了避免水平悬殊,牛的身高不应该相差太大。John 准备了 (1q1.8×105)(1\le q\le 1.8\times10^5) 个可能的牛的选择和所有牛的身高 hi(1hi106,1in)h_i(1\le h_i\le 10^6,1\le i\le n)。他想知道每一组里面最高和最低的牛的身高差。

输入格式

第一行两个数n,qn,q

接下来 n 行,每行一个数 hih_i

再接下来 q 行,每行两个整数 a 和 b,表示询问第 a 头牛到第 b 头牛里的最高和最低的牛的身高差。

输出格式

输出共 q 行,对于每一组询问,输出每一组中最高和最低的牛的身高差。

输入输出样例

输入 #1

6 3
1
7
3
4
2
5
1 5
4 6
2 2

输出 #1

6
3
0

题目来源

Balanced Lineup G

分析:

考虑将题目要求转化为求区间[a,b][a, b]内的最大值与最小值之差,可以考虑采用枚举暴力的方式,不过其时间复杂度为$O(nq) = 1.8 \times 10 ^ 5 \times 5 \times 10^4 = 9 \times 10 ^9$,肯定有部分测试点超时TLE(Time Limit Exceeded)。

参考代码:

#include<iostream>
using namespace std;
const int N = 5e4 + 10;
const int INF = 0x3f3f3f3f;  // 代表无穷大
int n, q;
int c[N];

void calc(int a, int b)
{
    int maxx = -INF, minx = INF;
    for (int i = a; i <= b; ++i)
    {
        maxx = max(maxx, c[i]);
        minx = min(minx, c[i]);
    }
    cout << maxx - minx << endl;
}

void initInput()
{
    // 提高cin的速度
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> q;
    for (int i = 1; i <= n; i++)
    {
        cin >> c[i];
    }
    for (int i = 0; i < q; i++)
    {
        int a, b;
        cin >> a >> b;
        calc(a, b);
    } 
}

int main(void)
{
    initInput();
    return 0;
}

经测试,60%的测试点TLE。我们考虑一下如何优化这个过程。

fMax[i][j]fMax[i][j]表示[i, j] 区间内的最大值,则有

  • i == j 时,fMax[i][j]=c[j]fMax[i][j] = c[j];
  • i != j 时,fMax[i][j]=max(fMax[i][j1],c[j])fMax[i][j] = \max(fMax[i][j-1], c[j]);

同理,fMin[i][j]fMin[i][j]也如此。

参考代码:

#include<iostream>
using namespace std;
const int N = 5e4 + 10;
const int INF = 0x3f3f3f3f;  // 代表无穷大
int n, q;
int c[N];

int fMax[N][N]; // fMax[i][j] 表示[i,j]范围内的最大值
int fMin[N][N]; // fMin[i][j] 表示[i,j]范围内的最小值

void handlePre()
{
    for (int i = 1; i <= n; i++)
    {
        for (int j = i; j <= n; ++j)
        {
            if(i == j)
            {
                fMax[i][j] = fMin[i][j] = c[j];
            }
            else
            {
                fMax[i][j] = max(fMax[i][j - 1], c[j]);
                fMin[i][j] = min(fMin[i][j - 1], c[j]);
            }
        }
    }
}
void calc2(int a, int b)
{
    cout << fMax[a][b] - fMin[a][b] << endl;
}
void calc(int a, int b)
{
    int maxx = -INF, minx = INF;
    for (int i = a; i <= b; ++i)
    {
        maxx = max(maxx, c[i]);
        minx = min(minx, c[i]);
    }
    cout << maxx - minx << endl;
}

void initInput()
{
    // 提高cin的速度
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> q;
    
    for (int i = 1; i <= n; i++)
    {
        cin >> c[i];
    }
    handlePre();
    for (int i = 0; i < q; i++)
    {
        int a, b;
        cin >> a >> b;
        // calc(a, b);
        calc2(a, b);
    } 
}

int main(void)
{
    initInput();
    return 0;
}

则其时间复杂度为O(n2+q)=25×108+1.8×105O(n^2 + q) = 25 \times 10 ^ 8 + 1.8 \times 10 ^ 5,同样存在超时的问题,空间复杂度为O(n2)=25×108O(n^2) = 25 \times 10 ^ 8,存在超内存的问题。这样的预处理还是复杂度太高。考虑进一步优化,将前置处理的复杂度降低下来。可以采用倍增思想将预处理分析如下:

fMax[i][j]fMax[i][j]表示[i,i+2j1][i, i + 2^j -1] 区间内的最大值,即表示 从i位置开始的连续2j2^j个数的最大值(倍增思想),则有

  • fMax[i][0]=c[i]fMax[i][0] = c[i] ;(从不开始的1个的最大值,显然是c[i]c[i]本身。
  • $fMax[i][j] = \max(fMax[i][j-1], fMax[i+2^{j -1}][j - 1])$;

fMax[i][j1]fMax[i][j-1]表示i开始连续的2j12^{j - 1}的最大值,其区间为 [i,i+2j11][i,i + 2^{j - 1} - 1],则fMax[i+2j1][j1]fMax[i+2^{j -1}][j - 1]表示i+2j1i+2^{j - 1}开始连续的2j12^{j - 1}个数的最大值,其区间为[i+2j1,i+2j][i + 2 ^ {j - 1},i + 2^j]

对于每一个起点i,都有O(log2N)O(\log{_2}{N})个区间,查询每个区间都可以O(1)O(1)实现,则预处理的时间与空间复杂度都为O(nlogn)O(n\log{n}),此题的算法总复杂度就变为了$O(n \log{n} + q) = 5 \times 10^4 \times \log(5 \times 10 ^4) + 1.8 \times 10 ^ 5 \approx 10^5$。

同理,fMin[i][j]fMin[i][j]也如此。

预处理代码参考如下:

void handlePre2() // 预处理2
{
    for (int i = 1; i <= n; i++)
    {
        lg[i] = lg[i / 2] + 1; // 处理log(n)
        fMax[i][0] = c[i];
        fMin[i][0] = c[i];
    }
    for (int j = 1; j <= lg[n];  j++)
    {
        for (int i = 1; i + (1 <<( j - 1))<= n; ++i)
        {
            fMax[i][j] = max(fMax[i][j - 1], fMax[i + (1 << (j - 1))][j - 1]);
            fMin[i][j] = min(fMin[i][j - 1], fMin[i + (1 << (j - 1))][j - 1]);
        }
    }
}

那么怎么实现每个任意区间的查询呢?

假设查询的区间为[i,j][i, j],显然我们可以将此区间分为两个重叠或刚好不重叠的区间[i,i+2(log2(ji+1))][i, i + 2^{(log_2{(j-i+1)})}][j2(log2(ji+1),j][j-2^{(\log_2 (j - i +1)}, j]

ji+1j-i + 1为区间的长度,显然两个区间重叠时不影响从这两个区间中找到最值,例如求最大值则为$ max(fMax[i][lg[j - i + 1]], fMax[j - (1 << lg[j -i + 1]) + 1][lg[j - i + 1]])$。

查询代码参考如下:

void query(int a, int b)
{
    int maxx =0, minx =0;
    int len = b - a + 1;
    maxx = max(fMax[a][lg[len]], fMax[b - (1 << lg[len]) + 1][lg[len]]);
    minx = min(fMin[a][lg[len]], fMin[b - (1 << lg[len]) + 1][lg[len]]);
    cout << maxx - minx << endl;
}

最终的参考代码:

#include <iostream>
using namespace std;
const int N = 5e4 + 5;
const int INF = 0x3f3f3f3f; // 代表无穷大
int n, q;
int c[N];

int fMax[N][300]; // fMax[i][j] 表示[i,i+2^j]范围内的最大值
int fMin[N][300]; // fMin[i][j] 表示[i,i+2^j]范围内的最小值

int lg[N] = {-1};

void handlePre() // 预处理
{
    for (int i = 1; i <= n; i++)
    {
        for (int j = i; j <= n; ++j)
        {
            if (i == j)
            {
                fMax[i][j] = fMin[i][j] = c[j];
            }
            else
            {
                fMax[i][j] = max(fMax[i][j - 1], c[j]);
                fMin[i][j] = min(fMin[i][j - 1], c[j]);
            }
        }
    }
}
void handlePre2() // 预处理2
{
    for (int i = 1; i <= n; i++)
    {
        lg[i] = lg[i / 2] + 1; // 处理log(n)
        fMax[i][0] = c[i];
        fMin[i][0] = c[i];
    }
    for (int j = 1; j <= lg[n]; j++)
    {
        for (int i = 1; i + (1 << (j - 1)) <= n; ++i)
        {
            fMax[i][j] = max(fMax[i][j - 1], fMax[i + (1 << (j - 1))][j - 1]);
            fMin[i][j] = min(fMin[i][j - 1], fMin[i + (1 << (j - 1))][j - 1]);
        }
    }
}
void query(int a, int b)
{
    int maxx = 0, minx = 0;
    int len = b - a + 1;
    maxx = max(fMax[a][lg[len]], fMax[b - (1 << lg[len]) + 1][lg[len]]);
    minx = min(fMin[a][lg[len]], fMin[b - (1 << lg[len]) + 1][lg[len]]);
    cout << maxx - minx << endl;
}
void calc2(int a, int b)
{
    cout << fMax[a][b] - fMin[a][b] << endl;
}
void calc(int a, int b)
{
    int maxx = -INF, minx = INF;
    for (int i = a; i <= b; ++i)
    {
        maxx = max(maxx, c[i]);
        minx = min(minx, c[i]);
    }
    cout << maxx - minx << endl;
}

void initInput()
{
    // 提高cin的速度
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> q;

    for (int i = 1; i <= n; i++)
    {
        cin >> c[i];
    }
    handlePre2();
    for (int i = 0; i < q; i++)
    {
        int a, b;
        cin >> a >> b;
        // calc(a, b);
        // calc2(a, b);
        query(a, b);
    }
}

int main(void)
{
    initInput();
    return 0;
}

解决区间最值问题的这个解决方法中就利用了倍增的思想,我们依次从20,21,22,...,2log2n2^0,2^1,2^2,...,2^{\log _2 n}来打表(此表又称为ST表)找到其中的最值,这是与二分的思想相似的一种思想。