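// 解题思路 (approach): the sum of the first n Fibonacci numbers is F(1) + ... + F(n) = F(n+2) - 1,
//   so the program prints (F(n+2) - 1) mod F(m) mod p. F(n+2) is computed with fast 2x2 matrix
//   exponentiation, reduced by the relevant modulus.
// 注意事项 (notes): all arithmetic is unsigned 64-bit. Guard the "- 1" against unsigned underflow
//   and the modular multiply/add against overflow (the modulus F(m) can exceed 2^63).
// 参考代码 (reference code):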
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
// All arithmetic in this file is done in unsigned 64-bit integers.
using LL = unsigned long long;
// Problem inputs, filled in by main() from stdin.
LL n;
LL m;
LL mod;
class M {
public:
LL data[2][2];
M() { memset(data, 0, sizeof(data)); }
};
void solve1() {
LL a = 1;
LL b = 1;
if (m >= n + 2) {
for (LL i = 3; i <= n + 2; ++i) {
LL t = a;
a = b;
b += t;
}
printf("%llu\n", b % mod - 1);
} else {//m<n+2
LL fibM, fibN_2 = 0;
for (LL i = 3; i <= n + 2; ++i) {
LL t = a;
a = b;
b += t;
if (i == m) fibM = b;
}
fibN_2 = b;
printf("%llu %llu\n", fibN_2, fibN_2 % fibM % mod - 1);
}
}
// Returns a newly heap-allocated 2x2 product m1 * m2; the caller owns it.
// Entries use plain unsigned 64-bit arithmetic, so they wrap modulo 2^64 on overflow.
M *mul(M *m1, M *m2) {
    M *prod = new M();  // starts as the zero matrix
    for (int r = 0; r < 2; ++r) {
        for (int c = 0; c < 2; ++c) {
            for (int k = 0; k < 2; ++k) {
                prod->data[r][c] += m1->data[r][k] * m2->data[k][c];
            }
        }
    }
    return prod;
}
// Returns (a * b) mod `mod` without 128-bit arithmetic, using binary
// ("Russian peasant") multiplication: one doubling per bit of the smaller factor.
// Every intermediate stays below `mod`, so the result is correct for any
// mod in [1, 2^64 - 1]. The moduli used here can exceed 2^63 (e.g. F(93)), where the
// old "(x + a) % mod" and "(a * 2) % mod" silently wrapped. Requires mod != 0.
LL mm(LL a, LL b, LL mod) {
    // Put the smaller factor in b so the loop runs over the fewest bits.
    if (a < b) {
        LL t = a;
        a = b;
        b = t;
    }
    a %= mod;  // a is now in [0, mod)
    LL x = 0;  // accumulated result, always in [0, mod)
    while (b != 0) {
        if ((b & 1) != 0) {
            // x + a may not fit in 64 bits; compare against mod - a instead.
            x = (x >= mod - a) ? x - (mod - a) : x + a;
        }
        a = (a >= mod - a) ? a - (mod - a) : a + a;  // a = 2a mod `mod`, overflow-free
        b >>= 1;
    }
    return x;
}
// Returns a newly heap-allocated 2x2 product m1 * m2 with every entry reduced
// modulo `mod`; the caller owns it. Each product goes through mm(), which returns a
// value in [0, mod). The two partial products of an entry are then added without
// overflow: `mod` may exceed 2^63 (solve2 passes F(m) as the modulus), and there
// the old "(s + t) % mod" wrapped modulo 2^64 before reducing.
M *mul(M *m1, M *m2, LL mod) {
    M *ans = new M();
    for (int r = 0; r < 2; ++r) {
        for (int c = 0; c < 2; ++c) {
            LL s = mm(m1->data[r][0], m2->data[0][c], mod);  // in [0, mod)
            LL t = mm(m1->data[r][1], m2->data[1][c], mod);  // in [0, mod)
            ans->data[r][c] = (s >= mod - t) ? s - (mod - t) : s + t;
        }
    }
    return ans;
}
// Returns m^n by binary exponentiation (O(log n) matrix multiplications), in plain
// unsigned 64-bit arithmetic, so entries wrap modulo 2^64 on overflow.
// The returned matrix is newly heap-allocated and owned by the caller. The input
// matrix m is neither modified nor freed. Every intermediate created here is freed.
M *mPow(M *m, LL n) {
    M *E = new M();  // running result, starts as the identity
    E->data[0][0] = 1;
    E->data[1][1] = 1;
    M *base = m;  // current square m^(2^k); the caller's m is never deleted here
    while (n != 0) {
        // Parenthesized: "n & 1 == 1" parses as "n & (1 == 1)" and only worked by accident.
        if ((n & 1) != 0) {
            M *next = mul(E, base);
            delete E;  // free the previous partial result instead of leaking it
            E = next;
        }
        M *sq = mul(base, base);
        if (base != m) delete base;
        base = sq;
        n >>= 1;
    }
    if (base != m) delete base;
    return E;
}
// Returns m^n with every product reduced modulo `mod` (requires mod >= 1), by binary
// exponentiation (O(log n) matrix multiplications).
// The returned matrix is newly heap-allocated and owned by the caller. The input
// matrix m is neither modified nor freed. Every intermediate created here is freed.
M *mPow(M *m, LL n, LL mod) {
    M *E = new M();  // running result, starts as the identity
    // "1 % mod" keeps entries in [0, mod) even for mod == 1.
    E->data[0][0] = 1 % mod;
    E->data[1][1] = 1 % mod;
    M *base = m;  // current square m^(2^k); the caller's m is never deleted here
    while (n != 0) {
        if ((n & 1) != 0) {
            M *next = mul(E, base, mod);
            delete E;  // free the previous partial result instead of leaking it
            E = next;
        }
        M *sq = mul(base, base, mod);
        if (base != m) delete base;
        base = sq;
        n >>= 1;
    }
    if (base != m) delete base;
    return E;
}
// Returns F(i) (F(0) = 0, F(1) = F(2) = 1) as the first entry of [1, 1] * B^(i-2),
// where B = [[1, 1], [1, 0]]. There is no modulus, so the result wraps modulo 2^64
// once F(i) no longer fits in 64 bits (i > 93).
LL fib(LL i) {
    // Small indices are answered directly: for i < 2 the unsigned "i - 2" would wrap
    // to a huge exponent and produce garbage.
    if (i == 0) return 0;
    if (i <= 2) return 1;
    M *A = new M();  // row vector [F(2), F(1)] = [1, 1]
    A->data[0][0] = 1;
    A->data[0][1] = 1;
    M *B = new M();  // Fibonacci step matrix
    B->data[0][0] = 1;
    B->data[0][1] = 1;
    B->data[1][0] = 1;
    M *P = mPow(B, i - 2);
    M *ans = mul(A, P);
    LL result = ans->data[0][0];
    // Every matrix above was allocated for this call only; free them.
    delete A;
    delete B;
    delete P;
    delete ans;
    return result;
}
// Returns F(i) mod `mod` (F(0) = 0, F(1) = F(2) = 1; requires mod >= 1) as the first
// entry of [1, 1] * B^(i-2) mod `mod`, where B = [[1, 1], [1, 0]].
LL fib(LL i, LL mod) {
    // Small indices are answered directly: for i < 2 the unsigned "i - 2" would wrap
    // to a huge exponent and produce garbage.
    if (i == 0) return 0;
    if (i <= 2) return 1 % mod;
    M *A = new M();  // row vector [F(2), F(1)] = [1, 1]
    A->data[0][0] = 1;
    A->data[0][1] = 1;
    M *B = new M();  // Fibonacci step matrix
    B->data[0][0] = 1;
    B->data[0][1] = 1;
    B->data[1][0] = 1;
    M *P = mPow(B, i - 2, mod);
    M *ans = mul(A, P, mod);
    LL result = ans->data[0][0];
    // Every matrix above was allocated for this call only; free them.
    delete A;
    delete B;
    delete P;
    delete ans;
    return result;
}
void solve2() {
if (m >= n + 2) {
printf("%llu\n", fib(n + 2, mod) - 1);
} else {//m<n+2
LL fibm = fib(m);
printf("%llu\n", fib(n + 2, fibm) % mod - 1);
}
}
// Entry point: reads n, m and the modulus from stdin, then prints the answer via solve2().
int main(int argc, const char *argv[])
{
    (void)argc;  // command-line arguments are not used
    (void)argv;
    if (scanf("%llu %llu %llu", &n, &m, &mod) != 3) {
        return 1;  // malformed input: the globals would be partly unset
    }
    // mod == 0 and m == 0 (F(0) == 0 used as a modulus) would divide by zero.
    if (mod == 0 || m == 0) {
        return 1;
    }
    solve2();
    return 0;
}
// 0.0分 (page rating: 0.0 points)
// 6 人评分 (rated by 6 users)