C - Palindromic in Both Bases
观察数据范围 n <= 1e12, 暴力的话不可行。可以通过枚举一个回文串的前半部分, 发现最多枚举到1e6就可以了,暴力判断这个数在A进制下是否是回文数即可
#include <bits/stdc++.h>
#define int long long
bool check(int v, int base) {
std::string s;
while(v > 0) {
int d = v % base;
s += d + '0';
v /= base;
}
std::string t = s;
reverse(s.begin(), s.end());
return s == t;
}
void solve() {
int a, n;
std::cin >> a >> n;
int ans = 0;
std::string len = std::to_string(n);
for(int i = 1; i <= n; i++) {
std::string t = std::to_string(i);
// 这里判断数字长度超过了, 后续肯定都不符合了, 退出即可
if(t.size() * 2 - 1 > len.size()) break;
std::string s = t;
reverse(s.begin(), s.end());
std::string c1 = t + s, c2 = t + s.substr(1);
int x = stoll(c1), y = stoll(c2);
if(x <= n && check(x, a)) ans += x;
if(y <= n && check(y, a)) ans += y;
}
std::cout << ans << '\n';
}
signed main() {
int t = 1;
// std::cin >> t;
while(t--) solve();
return 0;
}
D - Transmission Mission
观察题意可知, 可以把问题转化为 :
给数组排序, 把每两个相邻的数字的差加起来。然后把最大的(m - 1)个差减去, 就是所求的结果
#include <bits/stdc++.h>
#define int long long
void solve() {
int n, m;
std::cin >> n >> m;
std::vector<int> a(n);
for (int i = 0; i < n; i++) {
std::cin >> a[i];
}
if(n == m) {
std::cout << 0 << '\n';
return;
}
sort(a.begin(), a.end());
int ans = 0;
std::priority_queue<int, std::vector<int>, std::less<int>> pq;
for (int i = 1; i < n; i++) {
ans += a[i] - a[i - 1];
pq.push(a[i] - a[i - 1]);
}
m--;
while(m--) {
ans -= pq.top();
pq.pop();
}
std::cout << ans << '\n';
}
signed main() {
int t = 1;
// std::cin >> t;
while(t--) solve();
return 0;
}
E - Count A%B=C
观察题意, a % b == c , && a != b != c, 可以发现, a 一定大于 b,并且 a 不能整除 b, 设b = x, a ∈ [x+1, n], 合法取值个数需要去掉 a % b == 0 的情况, 所以 x 对结果的贡献就是 n - x - floor((n - x) / x) = n - x + 1 - floor(n / x).
结果就是下面
可以转化为
第二部分暴力计算会超时
优化思路
观察 n / i 的值,可以发现:
对于许多连续的 i,n / i 的值是相同的。
我们可以找到这些连续区间,并一次性计算它们的贡献,而不是逐个计算。
详细见代码
#include <bits/stdc++.h>
#define int long long
const int MOD = 998244353;
int qui(int a, int b) {
int res = 1;
while(b) {
if(b & 1) res = res * a % MOD;
b >>= 1;
a = a * a % MOD;
}
return res;
}
void solve() {
int n;
std::cin >> n;
int ans = ((n % MOD) * ((n + 1) % MOD)) % MOD;
int t = (n % MOD) * ((n + 1) % MOD) % MOD;
t = t * (qui(2, MOD - 2)) % MOD;
ans = (ans - t + MOD) % MOD;
for (long long l = 1, r; l <= n; l = r + 1) {
r = n / (n / l);
long long cnt = r - l + 1;
ans = (ans - (n / l) * cnt % MOD + MOD) % MOD;
}
std::cout << ans << '\n';
}
signed main() {
int t = 1;
// std::cin >> t;
while(t--) solve();
return 0;
}