快速幂和矩阵快速幂

╰+攻爆jí腚メ 2022-06-01 08:10 421阅读 0赞

前言

新年第一篇技术类的文章,应该算是算法方面的文章的。看标题:快速幂和矩阵快速幂,好像挺高大上。其实并不是很难,快速幂就是快速求一个数的幂(一个数的 n 次方)。

快速幂

首先,来看一下幂,我们知道,假设有一个整数 x, 如果我们要求出 x^n (即为 x 的 n 次方)的值,最容易想到的办法就是循环相乘(这里不考虑整数溢出的情况下),于是我们很容易就可以写出下面的代码:

  1. int res = 1;
  2. for (int i = 0; i < n; i++) {
  3. res *= x;
  4. }

咋一看,嗯,很正常的代码。确实是挺正常的代码,其时间复杂度为 O(n)。其实这个问题的时间复杂度可以降到 O(logn) 。那么问题来了,怎么做到的? 其实,就是通过快速幂的方法。

先来举个例子:假设我们现在要求出 5^9 的值,不用我们刚刚直接循环的方法,换种思维,我们可以这样看:

  1. 5^9 = 5*5^8 = 5*((5^4)^2) = 5*(5^4)*(5^4)
  2. 5^4 = ((5^2)^2) = (5^2)*(5^2)
  3. 5^2 = 5*5

如果当前的指数是偶数,我们把指数拆成两半,得到两个相同的数,然后把这两个相同的数相乘,可以得到原来的数;
如果当前的指数是奇数,我们把指数拆成两半,得到两个相同的数,此时还剩余一个底数,把这两个相同的数和剩余的底数这三个数相乘,可以得到原来的数。

那么如果说我们按照这种思路去计算 5^9 的值的话,我们会发现只需要执行 3 次计算。相比原来的直接用循环的 9 次计算,正好是 log9 的整数部分值。Ok,那么怎么用代码写出来呢?这里先给出代码,再做解释:

  1. /**
  2. * 计算 x^n 的值,并将结果保存在 res 中
  3. */
  4. long long res = 1;
  5. // 进行快速幂运算,n 为当前的指数值,n 为 0 的时候运算结束
  6. while (n) {
  7. // 用位运算的方式判断 n 是否为奇数,速度更快,等价于 n%2
  8. if (n & 1) {
  9. // 如果 n 是奇数,那么需要将 x 存入运算结果中
  10. res *= x;
  11. }
  12. // 更新当前的 x 的值
  13. x *= x;
  14. // 用位运算的方式进行 n/2,速度更快,等价于 n/=2
  15. n >>= 1;
  16. }

首先,我们注意到,不管当前的指数值(n 的值)是奇数还是偶数,一次运算之后 n 都要拆成两半(n /= 2),所以,我们在每次运算的时候都要让当前的 x *= x ,也就是执行 x = x^2,这点相信不难理解。

第二,当 n 为奇数的时候,如果执行 n /= 2,结果会使得 n 损失一个 1。举个例子:假设此时 n = 9,9 / 2 = 4 ,即使我们之后会执行 x *= x,也只是把 n 的一半 (4) 补回来了,还少了个 1 (4+4+1 = 9)。因此此时要把少了的那一个 x 存入结果中,即为执行 res *= x;

第三,只要 n 的初始值是大于 0 的(其余的数需要特殊处理),那么在运算过程中一直执行 n >>= 1,也就是将 n 除以 2 ,n 是一定会等于 1 的,此时执行 res *= x,将最后的结果保存在 res 中,之后退出循环。

最后,整个循环每一次执行 n 都变成原来的一半,当 n 等于 0 的时候结束,时间复杂度为 O(logn)

这里给出一个快速幂的完整代码:

  1. /**
  2. * Describe:实现快速幂
  3. * Author:指点
  4. * Date:2018/1/24
  5. */
  6. #include <iostream>
  7. #include <cstdlib>
  8. using namespace std;
  9. // 使用快速幂求出 x^n 的值并返回,不考虑高精度,请控制参数范围
  10. double myPow(double x, int n) {
  11. // 任何不是 0 的数的 0 次幂为 1
  12. if (x && n == 0) {
  13. return 1;
  14. } else if (x == 0 && n == 0) {
  15. exit(1);
  16. }
  17. // 如果 n 是负数,那么返回结果要进行处理
  18. bool nIsNegative = false;
  19. if (n < 0) {
  20. nIsNegative = true;
  21. n = -n;
  22. }
  23. double res = 1;
  24. while (n) {
  25. // 用位运算的方式判断 n 是否为奇数,速度更快,等价于 n%2
  26. if (n & 1) {
  27. res *= x;
  28. }
  29. x *= x;
  30. // 用位运算的方式进行 n/2,速度更快,等价于 n/=2
  31. n >>= 1;
  32. }
  33. // n 是负数?1.0/res 否则 res
  34. return nIsNegative ? 1.0/res : res;
  35. }
  36. int main() {
  37. double x;
  38. int n;
  39. while (cin >> x >> n) {
  40. cout << myPow(x, n) << endl << endl;
  41. }
  42. return 0;
  43. }

来看看结果:

这里写图片描述

理解了上面的几点,相信快速幂就难不到你了。下面来看看矩阵快速幂:

矩阵快速幂

其实矩阵快速幂的思想是和快速幂一样的,矩阵快速幂是用于快速求出一个矩阵的 n 次方的方法

首先,我们要知道,两个矩阵能不能相乘是有一定条件的:
假设有两个矩阵 A, B。如果矩阵 A 的列数等于矩阵 B 的行数,那么这两个矩阵才可以进行相乘,否则这两个矩阵是不能相乘的。
对于这里,我们要求的是一个矩阵的 n 次方,那么既然是同一个矩阵,那么只有当其为方阵(行数和列数相同的矩阵)的时候,才可以相乘。矩阵相乘结果也是一个矩阵,具体的规则为:如果矩阵 A 的列数等于矩阵 B 的行数,假设矩阵 C = A*B, 那么矩阵 C 的行数和矩阵 A 的行数相等,矩阵 C 的列数和矩阵 B 相等。矩阵 C 的第一行第一列元素等于矩阵 A 的第一行的元素和矩阵 B 的第一列的元素依次相乘再求和。矩阵 C 的第一行第二列元素等于矩阵 A 的第一行的元素和矩阵 B 的第二列的元素依次相乘再求和。。。。。。矩阵 C 的第 n 行第 m 列元素等于矩阵 A 的第 n 行的元素和矩阵 B 的第 m 列的元素依次相乘再求和。依次类推。

这里给出一个求出两矩阵相乘的结果的函数:

  1. // 计算矩阵 a(m*s 规模) 和矩阵 b(s*n 规模) 相乘的结果,并将结果返回
  2. int **matrixMultiply(int **a, int **b, int m, int s, int n) {
  3. // 初始化储存结果的数组
  4. int **result = new int*[m];
  5. for (int i = 0; i < m; i++) {
  6. result[i] = new int[n];
  7. memset(result[i], 0, sizeof(int)*n);
  8. }
  9. // 进行矩阵相乘计算
  10. for (int i = 0; i < m; i++) {
  11. for (int j = 0; j < n; j++) {
  12. for (int k = 0; k < s; k++) {
  13. result[i][j] += a[i][k]*b[k][j];
  14. }
  15. }
  16. }
  17. return result;
  18. }

这里用的是二级指针作为参数和返回值来表示对应的矩阵。来测试一下这个函数:

  1. /**
  2. * Describe:实现矩阵相乘
  3. * Author:指点
  4. * Date:2018/1/24
  5. */
  6. #include <iostream>
  7. #include <cstring>
  8. using namespace std;
  9. // 计算矩阵 a(m*s 规模) 和矩阵 b(s*n 规模) 相乘的结果,并将结果返回
  10. int **matrixMultiply(int **a, int **b, int m, int s, int n) {
  11. // 初始化储存结果的数组
  12. int **result = new int*[m];
  13. for (int i = 0; i < m; i++) {
  14. result[i] = new int[n];
  15. memset(result[i], 0, sizeof(int)*n);
  16. }
  17. // 进行矩阵相乘计算
  18. for (int i = 0; i < m; i++) {
  19. for (int j = 0; j < n; j++) {
  20. for (int k = 0; k < s; k++) {
  21. result[i][j] += a[i][k]*b[k][j];
  22. }
  23. }
  24. }
  25. return result;
  26. }
  27. int main() {
  28. int m = 2, s = 3, n = 2;
  29. // 初始化 a 、b 两个矩阵
  30. int **a = new int*[m];
  31. for (int i = 0; i < m; i++) {
  32. a[i] = new int[s];
  33. }
  34. int **b = new int*[s];
  35. for (int i = 0; i < s; i++) {
  36. b[i] = new int[n];
  37. }
  38. cout << "a 矩阵:" << endl;
  39. for (int i = 0; i < m; i++) {
  40. for (int j = 0; j < s; j++) {
  41. a[i][j] = i + j;
  42. cout << a[i][j] << " ";
  43. }
  44. cout << endl;
  45. }
  46. cout << "b 矩阵:" << endl;
  47. for (int i = 0; i < s; i++) {
  48. for (int j = 0; j < n; j++) {
  49. b[i][j] = i + j;
  50. cout << b[i][j] << " ";
  51. }
  52. cout << endl;
  53. }
  54. int **res = matrixMultiply(a, b, 2, 3, 2);
  55. // 结果是一个 2 行 2 列的数组
  56. cout << "相乘的结果矩阵:" << endl;
  57. for (int i = 0; i < 2; i++) {
  58. for (int j = 0; j < 2; j++) {
  59. cout << res[i][j] << " ";
  60. }
  61. cout << endl;
  62. }
  63. // 释放申请的内存空间
  64. if (a != NULL) {
  65. for (int i = 0; i < m; i++) {
  66. delete[] a[i];
  67. }
  68. delete[] a;
  69. a = NULL;
  70. }
  71. if (b != NULL) {
  72. for (int i = 0; i < s; i++) {
  73. delete[] b[i];
  74. }
  75. delete[] b;
  76. b = NULL;
  77. }
  78. if (res != NULL) {
  79. for (int i = 0; i < m; i++) {
  80. delete[] res[i];
  81. }
  82. delete[] res;
  83. res = NULL;
  84. }
  85. return 0;
  86. }

来看一下结果:

这里写图片描述

Ok,给定数据测试正确,有了这个函数,我们写矩阵快速幂的代码就简单了,我们把矩阵看成一个数,矩阵乘法的函数我们已经写好了,那么我们仿照快速幂的写法,实现矩阵快速幂:

  1. /**
  2. * Describe:实现矩阵快速幂
  3. * Author:指点
  4. * Date:2018/1/24
  5. */
  6. #include <iostream>
  7. #include <cstring>
  8. using namespace std;
  9. // 删除数组空间的函数,数组行数:m
  10. void deleteArray(int **a, int m) {
  11. if (a != NULL) {
  12. for (int i = 0; i < m; i++) {
  13. delete[] a[i];
  14. }
  15. delete[] a;
  16. a = NULL;
  17. }
  18. }
  19. // 计算矩阵 a(m*s 规模) 和矩阵 b(s*n 规模) 相乘的结果,并将结果返回
  20. int **matrixMultiply(int **a, int **b, int m, int s, int n) {
  21. // 初始化储存结果的数组
  22. int **result = new int*[m];
  23. for (int i = 0; i < m; i++) {
  24. result[i] = new int[n];
  25. memset(result[i], 0, sizeof(int)*n);
  26. }
  27. // 进行矩阵相乘计算
  28. for (int i = 0; i < m; i++) {
  29. for (int j = 0; j < n; j++) {
  30. for (int k = 0; k < s; k++) {
  31. result[i][j] += a[i][k]*b[k][j];
  32. }
  33. }
  34. }
  35. return result;
  36. }
  37. // 用快速幂求出矩阵 a(m*m 规模,只有方阵才可以自我相乘) 的 n 次方,并将结果返回
  38. int **myMatrixPow(int **a, int m, int n) {
  39. // 初始化保存结果的矩阵
  40. int **res = new int*[m];
  41. for (int i = 0; i < m; i++) {
  42. res[i] = new int[m];
  43. memset(res[i], 0, sizeof(int)*m);
  44. // 保存结果的矩阵初始应该是一个单位矩阵(正向斜对角线值为 1,其余为 0)
  45. res[i][i] = 1;
  46. }
  47. // 保存要删除的数组空间的指针
  48. int **oldPoint = NULL;
  49. while (n) {
  50. if (n & 1) {
  51. // 保存 res 指针当前的内存地址
  52. oldPoint = res;
  53. // res 指向储存矩阵相乘结果的数组的地址
  54. res = matrixMultiply(res, a, m, m, m);
  55. // 删除 res 指针原有的内存空间
  56. deleteArray(oldPoint, m);
  57. }
  58. // 保存 a 指针当前的内存地址
  59. oldPoint = a;
  60. // a 指向储存矩阵相乘结果的数组的地址
  61. a = matrixMultiply(a, a, m, m, m);
  62. // 删除 a 指针原有的内存空间
  63. deleteArray(oldPoint, m);
  64. n >>= 1;
  65. }
  66. return res;
  67. }
  68. int main() {
  69. int m = 2;
  70. // 初始化 a 方阵
  71. int **a = new int*[m];
  72. for (int i = 0; i < m; i++) {
  73. a[i] = new int[m];
  74. }
  75. cout << "a 矩阵:" << endl;
  76. for (int i = 0; i < m; i++) {
  77. for (int j = 0; j < m; j++) {
  78. a[i][j] = i + j;
  79. cout << a[i][j] << " ";
  80. }
  81. cout << endl;
  82. }
  83. cout << endl;
  84. for (int i = 0; i < 10; i++) {
  85. // 计算结果
  86. int **res = myMatrixPow(a, m, i);
  87. cout << "a 矩阵的 " << i << " 次方计算结果:" << endl;
  88. for (int i = 0; i < 2; i++) {
  89. for (int j = 0; j < 2; j++) {
  90. cout << res[i][j] << " ";
  91. }
  92. cout << endl;
  93. }
  94. // 释放 res 指针的内存空间
  95. deleteArray(res, m);
  96. }
  97. // 最后释放 a 指针的内存空间
  98. deleteArray(a, m);
  99. return 0;
  100. }

关键函数就是 myMatrixPow ,我想有了快速幂的基础,这个函数也不难理解了。代码里面有较多的指针操作,所以专门写了一个函数 deleteArray 来释放程序运行过程中所申请的堆内存空间,其实不主动释放,等程序结束后让操作系统回收也是可以的,不过个人有点强迫症…..哈哈。看代码不难理解利用矩阵快速幂求方阵的幂的时间复杂度为O(m^3*logn),m为方阵的行数和列数(方阵相乘的复杂度为 O(m^3),快速幂的复杂度为 O(logn) )。
好了, 来看一下结果:

这里写图片描述

如果有兴趣的话,你可以自己验算一下结果的正确性。

应用

那么看了这么多,快速幂有啥子用呢?
首先对于求一个数的 n 次方,可以用 O(logn) 的时间复杂度来求出结果,这肯定是一个用途,那么矩阵快速幂呢?
不知道你还记不记得斐波那契数列的递推公式,斐波那契数列的递推公式可以写成:

  1. 如果 n > 2,那么 f(n) = f(n-1) + f(n-2);
  2. 如果 n = 2 或者 n = 1,那么 f(n) = 1

那么如果现在要求 f(n) 的值呢,根据递推公式我们可以很快的写出下面的代码(不考虑整数溢出的情况):

  1. typedef long long ll;
  2. ll getFibo(int n) {
  3. if (n == 1 || n == 2) {
  4. return 1;
  5. }
  6. return getFibo(n-1) + getFibo(n-2);
  7. }

这个代码的时间复杂度大约是 O(2^n),其执行过程就是一颗二叉树,里面进行了很多的重复运算。
当然也有循环版本的(不考虑整数溢出的情况):

  1. typedef long long ll;
  2. ll getFibo(int n) {
  3. ll first = 1, second = 1, res = 0;
  4. for (int i = 3; i <= n; i++) {
  5. res = first + second;
  6. first = second;
  7. second = res;
  8. }
  9. return res;
  10. }

这个代码的时间复杂度为 O(n),比递归的方法好。
这两种方法都可以求解,但是可以有更高效的方法,就是利用矩阵快速幂。
不过咋一看这怎么和矩阵快速幂联系到一起呢?要用矩阵快速幂,我们得先有矩阵:

假设我们现在有一个一行两列的矩阵:A【f(n-2), f(n-1)】,我们设定一个 2*2 的矩阵 T,使得矩阵 A*T 相乘的结果等于另外一个一行两列的矩阵 C:【f(n-1), f(n)】。
我们根据给定条件和斐波那契的递推公式可以很容易构造出矩阵 T:

0 1
1 1

构造过程就是矩阵 A*T 的计算过程:

【f(n-2)*0 + f(n-1)*1 = f(n-1), f(n-2)*1 + f(n-1)*1 = f(n)】

Ok,那么我们知道 【f(n-2), f(n-1)】* T = 【f(n-1), f(n)】,
那么可以推出:【f(n-3), f(n-2)】* T*T = 【f(n-1), f(n)】,【f(n-4), f(n-3)】* T*T*T = 【f(n-1), f(n)】…….
也就是:【f(1), f(2)】 * T^(n-2) = 【f(n-1), f(n)】,
f(1)=1, f(2)=1, 也就是:【1, 1】*T^(n-2) = 【f(n-1), f(n)】

现在在看一下我们是不是有了 T^(n-2) 这个矩阵求幂的条件,那么我们就可以用矩阵快速幂来求解这道题了:

  1. /**
  2. * Describe:利用矩阵快速幂求斐波那契数列的第 n 项值
  3. * Author:指点
  4. * Date:2018/1/24
  5. */
  6. #include <iostream>
  7. #include <cstring>
  8. using namespace std;
  9. typedef long long ll;
  10. // f(1) 和 f(2) 的值
  11. const ll START[] = {
  12. 1, 1};
  13. // 矩阵 T
  14. ll **T = NULL;
  15. // 删除数组空间的函数,数组行数:m
  16. void deleteArray(ll **a, int m) {
  17. if (a != NULL) {
  18. for (int i = 0; i < m; i++) {
  19. delete[] a[i];
  20. }
  21. delete[] a;
  22. a = NULL;
  23. }
  24. }
  25. // 计算矩阵 a(m*s 规模) 和矩阵 b(s*n 规模) 相乘的结果,并将结果返回
  26. ll **matrixMultiply(ll **a, ll **b, int m, int s, int n) {
  27. // 初始化储存结果的数组
  28. ll **result = new ll*[m];
  29. for (int i = 0; i < m; i++) {
  30. result[i] = new ll[n];
  31. memset(result[i], 0, sizeof(ll)*n);
  32. }
  33. // 进行矩阵相乘计算
  34. for (int i = 0; i < m; i++) {
  35. for (int j = 0; j < n; j++) {
  36. for (int k = 0; k < s; k++) {
  37. result[i][j] += a[i][k]*b[k][j];
  38. }
  39. }
  40. }
  41. return result;
  42. }
  43. // 求出矩阵 a(m*m 规模,只有方阵才可以自我相乘) 的 n 次方,并将结果返回
  44. ll **myMatrixPow(ll **a, int m, int n) {
  45. // 初始化保存结果的矩阵
  46. ll **res = new ll*[m];
  47. for (int i = 0; i < m; i++) {
  48. res[i] = new ll[m];
  49. memset(res[i], 0, sizeof(ll)*m);
  50. // 保存结果的矩阵初始应该是一个单位矩阵(正向斜对角线值为 1,其余为 0)
  51. res[i][i] = 1;
  52. }
  53. // 保存要删除的数组空间的指针
  54. ll **oldPoint = NULL;
  55. while (n) {
  56. if (n & 1) {
  57. // 保存 res 指针当前的内存地址
  58. oldPoint = res;
  59. // res 指向储存矩阵相乘结果的数组的地址
  60. res = matrixMultiply(res, a, m, m, m);
  61. // 删除 res 指针原有的内存空间
  62. deleteArray(oldPoint, m);
  63. }
  64. // 保存 a 指针当前的内存地址
  65. oldPoint = a;
  66. // a 指向储存矩阵相乘结果的数组的地址
  67. a = matrixMultiply(a, a, m, m, m);
  68. // 删除 a 指针原有的内存空间
  69. deleteArray(oldPoint, m);
  70. n >>= 1;
  71. }
  72. return res;
  73. }
  74. // 求出斐波那契数列的第 n 项的值,不考虑整数溢出,请控制数字范围
  75. ll getFibo(int n) {
  76. if (n == 1 || n == 2) {
  77. return 1;
  78. }
  79. ll res = 0;
  80. // 求出 T 矩阵的 n-2 次方(T^n-2)的值,并将结果保存在 T 指针中
  81. T = myMatrixPow(T, 2, n-2);
  82. // 求出最后的 f(n) 的值(res += START[i]*T[i][0] 为 f(n-1) 的值,res += START*T[i][1] 为 f(n) 的值)
  83. for (int i = 0; i < 2; i++) {
  84. res += START[i]*T[i][1];
  85. }
  86. return res;
  87. }
  88. int main() {
  89. // 初始化矩阵 T,元素值通过计算求得
  90. T = new ll*[2];
  91. T[0] = new ll[2];
  92. T[1] = new ll[2];
  93. for (int i = 1; i < 80; i++) {
  94. /**
  95. * 矩阵 T 元素值:
  96. * 0 1
  97. * 1 1
  98. */
  99. T[0][0] = 0;
  100. T[0][1] = T[1][0] = T[1][1] = 1;
  101. cout << "第" << i << "项斐波那契数列的值:";
  102. cout << getFibo(i) << endl;
  103. }
  104. deleteArray(T, 2);
  105. return 0;
  106. }

和矩阵快速幂差不多的代码,如果你理解了矩阵快速幂的思想的话,我想这代码也很好理解,这里我们可以看到,用这种方法求斐波那契数列的时间复杂度约为 O(2^3*logn),也就是求矩阵的幂的时间复杂度。忽略常数,即为O(logn)。
有图有真相,最后来看一下结果:

这里写图片描述

其实类似于斐波那契数列这种利用递推式来求值的问题都可以通过矩阵快速幂来解决,这其中主要的问题就是怎么构造那个矩阵。关于这点,可以参考下这篇文章:
http://www.cnblogs.com/frog112111/archive/2013/05/19/3087648.html

如果说练习题的话,可以试试下面的:
http://poj.org/problem?id=3070
http://lx.lanqiao.cn/problem.page?gpid=T396()

Ok,如果博客中有什么不正确的地方,请多多指点,如果觉得本文对您有帮助,请不要吝啬您的赞。
谢谢观看。。。

发表评论

表情:
评论列表 (有 0 条评论,421人围观)

还没有评论,来说两句吧...

相关阅读

    相关 矩阵快速

    昨天晚上矩阵小王子给我们讲了一下矩阵快速幂,学习了一下,写了一个模板。 1:思想 矩阵快速幂的思想就是跟数的快速幂一样,假如我们要求2^11,次方,我们可以把 11 写成

    相关 矩阵快速

    A为一个方阵,则Tr A表示A的迹(就是主对角线上各项的和),现要求Tr(A^k)%9973。  Input 数据的第一行是一个T,表示有T组数据。  每组数据的

    相关 快速矩阵快速

    前言 新年第一篇技术类的文章,应该算是算法方面的文章的。看标题:快速幂和矩阵快速幂,好像挺高大上。其实并不是很难,快速幂就是快速求一个数的幂(一个数的 n 次方)。

    相关 矩阵快速

    通常我们使用的快速幂是以二为底的,这次就遇到了一道以10为底的快速幂题目; 先说下快速幂 long long power(long long a,long long