「题目链接」

点击打开链接

「题目概括」

给定一个$n\times m$的矩阵,每个格子内有$a(i,j)$个互不相同的元素。
每行最多选择一个元素,且必须满足所有元素中不存在有超过一半的元素分布在同一个列。
数据范围:$n\leq 100, m\leq 2000$。

「思路要点」

考虑用总方案数-不合法方案数。
总方案数=$(\Pi sum_i+1)-1$
因为只有一列会被判定为不合法,所以我们就枚举当前$w$列不合法。
定义状态为$f(i,j,k)$表示前$i$行中$w$列选择了$j$次,其他列选择了$k$次。
状态转移方程为$$f(i,j,k)=f(i-1,j,k)+a(i,w)\times f(i-1,j-1,k)+(sum(i)-a(i,w))\times f(i-1,j,k-1)$$
时间复杂度为$\mathcal O(m\times n^3)$。
考虑优化,发现我们只关心差值,所以重新定义状态为$f(i,d)$表示前$i$行,上述$j-k$为$d$。
状态转移方程为$$f(i,d)=f(i-1,d)+a(i,w)\times f(i,d-1)+(sum(i)-a(i,w))\times f(i,d+1)$$
时间复杂度为$\mathcal O(m\times n^2)$。

「代码」

// 84pts
// time complex O(m*n^3)
#include <bits/stdc++.h>

using namespace std;

template <class T>
void re(T& x) {
    x = 0; char ch = 0; int f = 1;
    for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1;
    for (; isdigit(ch); ch = getchar()) x = x * 10 + (ch ^ 48);
    x *= f;
}

typedef long long ll;

const int MOD = 998244353;
const int MAXN = 105;
const int MAXM = 2005;

ll sum[MAXN];
int a[MAXN][MAXM];
int f[MAXN][MAXN][MAXN];
int n, m;

void add(int& x, int y) {
    x += y;
    if (x >= MOD) x -= MOD;
}
void dec(int& x, int y) {
    x -= y;
    if (x < 0) x += MOD;
}

int main() {
    freopen("meal.in", "r", stdin), freopen("meal.out", "w", stdout);
    re(n), re(m);
    for (int i = 1; i <= n; i++) {
        sum[i] = 0ll;
        for (int j = 1; j <= m; j++) re(a[i][j]), sum[i] += a[i][j];
    }
    int ans = 1;
    for (int i = 1; i <= n; i++) ans = (ll)ans * (sum[i] + 1) % MOD;
    dec(ans, 1);
    for (int w = 1; w <= m; w++) {
        memset(f, 0, sizeof f);
        f[0][0][0] = 1;
        for (int i = 1; i <= n; i++) {
            for (int j = 0; j <= n; j++) {
                for (int k = 0; k <= n; k++) {
                    f[i][j][k] = f[i - 1][j][k];
                    if (j) add(f[i][j][k], (ll)a[i][w] * f[i - 1][j - 1][k] % MOD);
                    if (k) add(f[i][j][k], (sum[i] - a[i][w]) * f[i - 1][j][k - 1] % MOD);
                }
            }
        }
        for (int j = 0; j <= n; j++) 
            for (int k = 0; k < j; k++) 
                dec(ans, f[n][j][k]);
    }
    printf("%d\n", ans);
    return 0; 
}
// 100pts
// time complex O(m*n^2)
#include <bits/stdc++.h>

using namespace std;

template <class T>
void re(T& x) {
    x = 0; char ch = 0; int f = 1;
    for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1;
    for (; isdigit(ch); ch = getchar()) x = x * 10 + (ch ^ 48);
    x *= f;
}

typedef long long ll;

const int MOD = 998244353;
const int MAXN = 105;
const int MAXM = 2005;
const int SN = 101;

int sum[MAXN];
int a[MAXN][MAXM];
int f[MAXN][MAXN * 3];
int n, m;

void add(int& x, int y) {
    x += y;
    if (x >= MOD) x -= MOD;
}
void dec(int& x, int y) {
    x -= y;
    if (x < 0) x += MOD;
}

int main() {
    freopen("meal.in", "r", stdin), freopen("meal.out", "w", stdout);
    re(n), re(m);
    for (int i = 1; i <= n; i++) {
        sum[i] = 0ll;
        for (int j = 1; j <= m; j++) re(a[i][j]), add(sum[i], a[i][j]);
    }
    int ans = 1;
    for (int i = 1; i <= n; i++) ans = (ll)ans * (sum[i] + 1) % MOD;
    dec(ans, 1);
    for (int w = 1; w <= m; w++) {
        memset(f, 0, sizeof f);
        f[0][SN] = 1;
        for (int i = 1; i <= n; i++) {
            for (int d = SN - n; d <= SN + n; d++) {
                f[i][d] = f[i - 1][d];
                add(f[i][d], (ll)f[i - 1][d - 1] * a[i][w] % MOD);
                add(f[i][d], (ll)f[i - 1][d + 1] * (sum[i] - a[i][w] + MOD) % MOD);
            }
        }
        for (int j = SN + 1; j <= SN + n; j++) dec(ans, f[n][j]); 
    }
    printf("%d\n", ans);
    return 0; 
}
Last modification:December 15th, 2019 at 01:04 pm
如果觉得我的文章对你有用,请随意赞赏