矩阵学习笔记2 矩阵加速递推【矩阵】
点击量:174
在学习笔记1中的矩阵乘法公式中,我们看出,每个位置的值都是多个积之和。那么如果矩阵可以加速这样的快速幂/乘法,那么就同理可以加速加法了。因此对于那些有递推关系的式子,我们就可以用矩阵加速。
如果我们有线性的递推式子,那么我们就可以把加法变为乘法。因为在矩阵中每一种线性组合都会出现,那么我们只需要构造一个矩阵A,使得$ [ f(n) \ … \ ]=[ f(n-1) \ …\ ] \times A$,这里省略号里要尽量使相邻两个下标在数值上也相邻,这样递推式好构造一些。
e.g.
比如斐波那契数列,$ f_n=f_{n-1}+f_{n-2}$。我们就可以试着把矩阵$ [ f_n \ f_{n-1} ] $与$ [ f_{n-1} \ f_{n-2} ] $之间差的那个矩阵补出来,两矩阵大小一定要相同,不然无法递推。我们来逆推一下,就会发现,因为$ f_n=f_{n-1}+f_{n-2}$,所以我们构造时$ f_n$的位置应该使它成为$ f_{n-1}+f_{n-2}$的结果,同理,$ f{n-1}$的位置(当前第二列)让它直接换成递推过来的矩阵中第一列。于是我们的矩阵就是这样
$ [ f(n) \ f(n-1) ] = [ f(n-1) \ f(n-2) ] \times \begin{bmatrix}1 & 1 \\ 1 & 0\end{bmatrix}$
对于此时的第一行第一列f(n),是由等式右边左侧矩阵的第一列乘上右侧矩阵的第一行,就是f(n-1)+f(n-2)。
e.g.
再来看洛谷上的模板题
a[1]=a[2]=a[3]=1 a[x]=a[x-3]+a[x-1] (x>3) 求a数列的第n项对1000000007(10^9+7)取余的值。
这样的递推式不是连续的,但我们可以把它构造成$ [ f(n) \ f(n-1) \ f(n-2) ] = [ f(n-1) \ f(n-2) \ f(n-3) ] \times A$,根据a[x]=a[x-3]+a[x-1]求出A即可,我们推算出$ A=\begin{bmatrix}1 & 1 & 0\\ 0 & 0 & 1\\ 1 & 0 & 0\end{bmatrix}$,这就是矩阵加速递推的应用了。
斐波那契数列Code:
#include<cstdio>
#include<cstring>
const int p=1e9+7;
struct matrix
{
int x,y;
long long a[5][5];
matrix(int x,int y)
{
this->x=x;
this->y=y;
memset(a,0,sizeof(a));
}
matrix()
{
memset(a,0,sizeof(a));
}
matrix cross(matrix b)//矩乘
{
matrix m(x,b.y);
for(int i=1;i<=x;i++)
for(int j=1;j<=b.y;j++)
{
for(int k=1;k<=y;k++)
m.a[i][j]+=a[i][k]*b.a[k][j]%p;
m.a[i][j]%p;
}
return m;
}
void square()//自乘
{
matrix m(x,y);
for(int i=1;i<=x;i++)
for(int j=1;j<=x;j++)
{
for(int k=1;k<=x;k++)
m.a[i][j]+=a[i][k]*a[k][j]%p;
m.a[i][j]%=p;
}
*this=m;
}
matrix qpow(long long t)//快速幂
{
matrix m=*this;
matrix ans(x,y);
for(int i=1;i<=x;i++)
ans.a[i][i]=1;
while(t)
{
if(t&1)
ans=ans.cross(m);
m.square();
t>>=1;
}
return ans;
}
};
int main()
{
long long N;
scanf("%lld",&N);
matrix m(2,2);//初始化
m.a[1][1]=1;
m.a[1][2]=1;
m.a[2][1]=1;
matrix n(1,2);
n.a[1][1]=1;
n.a[1][2]=1;
m=m.qpow(N-1);
m=m.cross(n);
printf("%lld\n",m.a[1][1]);
return 0;
}
矩阵加速数列Code:
#include<cstdio>
#include<cstring>
const int p=1e9+7;
struct matrix
{
long long a[5][5];
int x,y;
matrix(int x,int y)
{
this->x=x;
this->y=y;
memset(a,0,sizeof(a));
}
matrix()
{
memset(a,0,sizeof(a));
}
matrix cross(matrix a,matrix b)//矩乘
{
matrix m(a.x,b.y);
for(int i=1;i<=m.x;i++)
for(int j=1;j<=m.y;j++)
{
for(int k=1;k<=a.y;k++)
m.a[i][j]+=a.a[i][k]*b.a[k][j]%p;
m.a[i][j]%=p;
}
return m;
}
void square()//自乘
{
matrix m(x,y);
for(int i=1;i<=x;i++)
for(int j=1;j<=y;j++)
{
for(int k=1;k<=x;k++)
m.a[i][j]+=a[i][k]*a[k][j]%p;
m.a[i][j]%=p;
}
*this=m;
}
matrix qpow(long long t)//快速幂
{
matrix m=*this;
matrix ans(x,y);
for(int i=1;i<=x;i++)
ans.a[i][i]=1;
while(t)
{
if(t&1)
ans=cross(ans,m);
m.square();
t>>=1;
}
return ans;
}
};
int main()
{
int t,u;
scanf("%d",&t);
matrix m(3,3);
matrix n(1,3);//初始化
m.a[1][1]=1;
m.a[1][2]=1;
m.a[2][3]=1;
m.a[3][1]=1;
n.a[1][1]=1;
n.a[1][2]=1;
n.a[1][3]=1;
for(int i=1;i<=t;i++)
{
scanf("%d",&u);
if(u<=3)
{
puts("1");
continue;
}
matrix ans=m.qpow(u-3);
ans=ans.cross(n,ans);
printf("%d\n",ans.a[1][1]);
}
return 0;
}
说点什么