两个对象矩阵相乘的有效方法

2023-12-19

作为程序的一部分,我需要将两个二维矩阵相乘。这些矩阵是我自己创建的 Matrix 类的对象。我现在的代码运行良好,但我想知道是否有更高效的方法将这些矩阵相乘。

/**
 * Multiplies this matrix by {@code matrix} and returns the product as a new
 * {@code Matrix}.
 *
 * <p>The product is accumulated row by row (i-k-j loop order), so both the
 * right-hand operand and the result are traversed along their rows and no
 * temporary row/column copies are needed. Because {@code int} arithmetic wraps
 * on overflow, the result is identical to summing each dot product separately.
 *
 * <p>NOTE(review): assumes both operands are square with side {@code length}
 * (a member declared outside this snippet) - confirm against the Matrix class.
 *
 * @param matrix the right-hand operand
 * @return a new Matrix holding {@code this * matrix}
 */
public Matrix multiply(Matrix matrix) {
    // 2D backing array of the right-hand operand
    int[][] userMatrix = matrix.getMatrix();
    // backing array for the product; Java zero-initialises it
    int[][] multiplied = new int[length][length];

    for (int row = 0; row < length; row++) {
        int[] leftRow = arrayObject[row];
        int[] productRow = multiplied[row];
        for (int k = 0; k < length; k++) {
            // leftRow[k] scales the whole k-th row of the right operand
            int factor = leftRow[k];
            int[] rightRow = userMatrix[k];
            for (int col = 0; col < length; col++) {
                productRow[col] += factor * rightRow[col];
            }
        }
    }

    // wrap the int[][] in a Matrix object
    return new Matrix(multiplied, multiplied.length);
}

可以使用类似施特拉森(Strassen)算法的方案进行乘法。本质上,您将每个矩阵分解为四个子矩阵并计算一些中间值,然后根据这些较小的中间值计算出最终结果。

Schema:

matrix split

现在,不再直接计算

C_11=A_11·B_11+A_12·B_21
C_12=A_11·B_12+A_12·B_22
C_21=A_21·B_11+A_22·B_21
C_22=A_21·B_12+A_22·B_22

而是计算以下中间值

M_1 = (A_11+A_22)·(B_11+B_22)
M_2 = (A_21+A_22)·B_11
M_3 = A_11·(B_12-B_22)
M_4 = A_22·(B_21-B_11)
M_5 = (A_11+A_12)·B_22
M_6 = (A_21-A_11)·(B_11+B_12)
M_7 = (A_12-A_22)·(B_21+B_22)

并得到解决方案

C_11 = M_1+M_4-M_5+M_7
C_12 = M_3+M_5
C_21 = M_2+M_4
C_22 = M_1-M_2+M_3+M_6

继续递归地执行此操作,您将只需要 O(n^log_2(7)) ~=O(n^2.807) 次乘法(加上一些加法和减法),而不是您使用的经典 O(n^3) 变体。对于实际实现,您需要进行试验,直到找到切换到经典变体的良好截止点。

至于代码,可以试试这个实现:https://github.com/freddygv/strassen 。(注意,我并不保证它能正常工作,它只是我找到的唯一一个附有合适许可证(GPL 3.0)的实现。)

另外,一个很大的警告:我发现的几乎所有代码都隐式地假设矩阵的维数是 2 的幂,以便分割步骤可以无缝进行,直到达到基本情况。您可能需要添加一些逻辑来处理其他情况下的拆分(或退回到基本实现)。

一般来说,您应该为此使用一个库,从而避免实施+测试的痛苦。

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

/**
 * Square integer matrix multiplication using Strassen's algorithm with a
 * configurable crossover to the classic triple-loop algorithm.
 *
 * <p>All arithmetic is plain {@code int} arithmetic, so products that exceed
 * the {@code int} range wrap around silently.
 */
public class Strassen {

    /** Mode flag: multiply with Strassen's algorithm. */
    static final int STRASSEN_MULT = 0;
    /** Mode flag: multiply with the classic triple loop. */
    static final int STANDARD_MULT = 1;
    /** Default dimension at or below which the recursion uses the classic algorithm. */
    static final int CROSSOVER = 64;

    /** Multiplication mode that {@link #main} passes to {@link #run}. */
    static int MULT_MODE = STRASSEN_MULT;

    /** When true, matrices, timings and per-term traces are printed. */
    static boolean DEBUGGING = false;
    /** When true, {@link #run} times Strassen for a range of crossover values. */
    static boolean CROSSOVER_TEST = false;

    /*
        Formula:
        [[ A B ]     [[ E F ]      [[AE + BG   AF + BH]
         [ C D ]]  *  [ G H ]]  =   [CE + DG   CF + DH]]
     */

    /**
     * Multiplies two square matrices of equal dimension with Strassen's
     * algorithm, switching to {@link #standardMult} for small or unsplittable
     * inputs.
     *
     * <p>The classic algorithm is used when the dimension is at most
     * {@code crossover}, is odd, or is smaller than 2. An odd dimension cannot
     * be split into four equal quadrants; splitting it anyway would drop the
     * last row and column and produce a wrong product.
     *
     * @param X left operand (n x n)
     * @param Y right operand (n x n)
     * @param crossover dimension at or below which the classic algorithm is used
     * @return the product {@code X * Y} as a new n x n array
     */
    public static int[][] strassenWithCrossover(int[][] X, int[][] Y, int crossover) {
        int n = X.length;
        if (n <= crossover || n < 2 || n % 2 != 0) {
            return standardMult(X, Y);
        }

        int half = n / 2;

        int[][] A = getSubMatrix(X, 0, 0);
        int[][] D = getSubMatrix(X, half, half);

        int[][] E = getSubMatrix(Y, 0, 0);
        int[][] H = getSubMatrix(Y, half, half);

        // The seven Strassen products, built directly from quadrants of X and Y.
        int[][] P1 = strassenWithCrossover(A, subtract(Y, 0, half, Y, half, half), crossover);   // A(F - H)
        int[][] P2 = strassenWithCrossover(add(X, 0, 0, X, 0, half), H, crossover);              // (A + B)H
        int[][] P3 = strassenWithCrossover(add(X, half, 0, X, half, half), E, crossover);        // (C + D)E
        int[][] P4 = strassenWithCrossover(D, subtract(Y, half, 0, Y, 0, 0), crossover);         // D(G - E)
        int[][] P5 = strassenWithCrossover(
                add(X, 0, 0, X, half, half), add(Y, 0, 0, Y, half, half), crossover);            // (A + D)(E + H)
        int[][] P6 = strassenWithCrossover(
                subtract(X, 0, half, X, half, half), add(Y, half, 0, Y, half, half), crossover); // (B - D)(G + H)
        int[][] P7 = strassenWithCrossover(
                subtract(X, 0, 0, X, half, 0), add(Y, 0, 0, Y, 0, half), crossover);             // (A - C)(E + F)

        int[][] AE_plus_BG = subtract(add(P5, P4), subtract(P2, P6)); // P5 + P4 - P2 + P6
        int[][] AF_plus_BH = add(P1, P2);
        int[][] CE_plus_DG = add(P3, P4);
        int[][] CF_plus_DH = subtract(add(P5, P1), add(P3, P7));      // P5 + P1 - P3 - P7

        int[][] ret = new int[n][n];
        assignSubMatrix(ret, 0, 0, AE_plus_BG);
        assignSubMatrix(ret, 0, half, AF_plus_BH);
        assignSubMatrix(ret, half, 0, CE_plus_DG);
        assignSubMatrix(ret, half, half, CF_plus_DH);

        return ret;
    }

    /**
     * Copies the quadrant of {@code matrix} whose top-left corner is at
     * ({@code rowStart}, {@code colStart}) into a new array of side
     * {@code matrix.length / 2}.
     */
    private static int[][] getSubMatrix(int[][] matrix, int rowStart, int colStart) {
        int half = matrix.length / 2;
        int[][] ret = new int[half][half];
        for (int row = 0; row < half; row++) {
            System.arraycopy(matrix[rowStart + row], colStart, ret[row], 0, half);
        }
        return ret;
    }

    /**
     * Writes {@code sub} into the quadrant of {@code matrix} whose top-left
     * corner is at ({@code rowStart}, {@code colStart}); the quadrant side is
     * {@code matrix.length / 2}.
     */
    private static void assignSubMatrix(int[][] matrix, int rowStart, int colStart, int[][] sub) {
        int half = matrix.length / 2;
        for (int row = 0; row < half; row++) {
            System.arraycopy(sub[row], 0, matrix[rowStart + row], colStart, half);
        }
    }

    /** Returns the element-wise sum of two square matrices sized like {@code X}. */
    private static int[][] add(int[][] X, int[][] Y) {
        int[][] ret = new int[X.length][X.length];
        for (int row = 0; row < ret.length; row++) {
            for (int col = 0; col < ret.length; col++) {
                ret[row][col] = X[row][col] + Y[row][col];
            }
        }
        return ret;
    }

    /**
     * Returns the element-wise sum of one quadrant of {@code X} and one
     * quadrant of {@code Y}, each identified by its top-left corner. The
     * quadrant side is {@code X.length / 2}.
     */
    private static int[][] add(int[][] X, int X_row_start, int X_col_start, int[][] Y, int Y_row_start, int Y_col_start) {
        int length = X.length / 2;
        int[][] ret = new int[length][length];
        for (int row = 0; row < length; row++) {
            for (int col = 0; col < length; col++) {
                ret[row][col] = X[X_row_start + row][X_col_start + col] + Y[Y_row_start + row][Y_col_start + col];
            }
        }
        return ret;
    }

    /** Returns the element-wise difference {@code X - Y} of two square matrices sized like {@code X}. */
    private static int[][] subtract(int[][] X, int[][] Y) {
        int[][] ret = new int[X.length][X.length];
        for (int row = 0; row < ret.length; row++) {
            for (int col = 0; col < ret.length; col++) {
                ret[row][col] = X[row][col] - Y[row][col];
            }
        }
        return ret;
    }

    /**
     * Returns the element-wise difference of one quadrant of {@code X} minus
     * one quadrant of {@code Y}, each identified by its top-left corner. The
     * quadrant side is {@code X.length / 2}.
     */
    private static int[][] subtract(int[][] X, int X_row_start, int X_col_start, int[][] Y, int Y_row_start, int Y_col_start) {
        int length = X.length / 2;
        int[][] ret = new int[length][length];
        for (int row = 0; row < length; row++) {
            for (int col = 0; col < length; col++) {
                ret[row][col] = X[X_row_start + row][X_col_start + col] - Y[Y_row_start + row][Y_col_start + col];
            }
        }
        return ret;
    }

    /**
     * Command-line entry point. Expects three arguments: a flag
     * (1 enables debugging output, 2 enables the crossover test), the matrix
     * dimension, and the input file path.
     */
    public static void main(String[] args) {

        if (args.length != 3) {
            System.out.println("Usage: ./strassen 0 dimension inputfile");
            System.exit(1);
        }

        int flag = Integer.parseInt(args[0]);
        int dimension = Integer.parseInt(args[1]);
        String inputfile = args[2];

        if (flag == 1) {
            DEBUGGING = true;
        } else if (flag == 2) {
            CROSSOVER_TEST = true;
        }

        Strassen me = new Strassen();

        if (CROSSOVER_TEST) {
            // The crossover test ignores the requested dimension and sweeps 2^7 .. 2^15.
            for (int i = 1 << 7; i < 1 << 16; i *= 2) {
                me.run(i, inputfile, MULT_MODE);
            }
        } else {
            me.run(dimension, inputfile, MULT_MODE);
        }
    }

    /**
     * Builds two {@code dimension x dimension} matrices and multiplies them
     * according to {@code mode}.
     *
     * <p>In crossover-test mode the matrices are filled from a fixed repeating
     * pattern and the input file is not opened. Otherwise the matrices are read
     * from {@code inputfile}, one integer per line, all of X first and then all
     * of Y. A read or parse failure is reported on stderr and the
     * multiplication proceeds with whatever was read so far (unread cells stay
     * zero).
     *
     * @param dimension side length of both matrices
     * @param inputfile path of the file holding the matrix entries
     * @param mode {@link #STANDARD_MULT} or {@link #STRASSEN_MULT}
     */
    public void run(int dimension, String inputfile, int mode) {

        long startTime;

        int[][] X = new int[dimension][dimension];
        int[][] Y = new int[dimension][dimension];

        if (CROSSOVER_TEST) {
            fillWithPattern(X, Y);
        } else {
            readMatrices(inputfile, X, Y);
        }

        if (DEBUGGING) {
            System.out.println("\n##### Reading Matrices X and Y from file ######\n");
            printMatrix(X, "X");
            printMatrix(Y, "Y");
        }

        if (mode == STANDARD_MULT) {
            int[][] Z = standardMult(X, Y);

            if (DEBUGGING) {
                System.out.println("Standard Product");
                printMatrix(Z, "Z");
            }

        } else if (mode == STRASSEN_MULT && CROSSOVER_TEST) {

            for (int crossover = 2; crossover <= dimension; crossover *= 2) {
                startTime = System.currentTimeMillis();

                // Padding is deliberately inside the timed region, as in the measurement setup.
                int[][] paddedX = pad(X);
                int[][] paddedY = pad(Y);

                strassenWithCrossover(paddedX, paddedY, crossover);

                printTimes("Strassen Product", startTime, dimension, crossover);
            }

        } else if (mode == STRASSEN_MULT) {
            startTime = System.currentTimeMillis();

            int[][] paddedX = pad(X);
            int[][] paddedY = pad(Y);

            int[][] Z = strassenWithCrossover(paddedX, paddedY, CROSSOVER);

            int[][] ZTrimmed = trim(Z, dimension);

            if (DEBUGGING) {
                printTimes("Strassen Product", startTime, dimension, CROSSOVER);
                printMatrix(ZTrimmed, "Z");
            } else {
                printDiagonal(ZTrimmed);
            }
        }
    }

    /** Fills X and then Y, row by row, from one fixed repeating sequence of 0/1/2 values. */
    private static void fillWithPattern(int[][] X, int[][] Y) {
        int[] elements = {0, 1, 2, 0, 2, 1, 1, 0, 2, 1, 2, 0, 2, 1, 0, 2, 0, 1};
        int pos = 0;
        for (int[][] target : new int[][][] {X, Y}) {
            for (int[] row : target) {
                for (int col = 0; col < row.length; col++) {
                    row[col] = elements[pos];
                    pos = (pos + 1) % elements.length;
                }
            }
        }
    }

    /**
     * Reads X and then Y from {@code inputfile}, one integer per line. Failures
     * are reported on stderr and leave the remaining cells untouched; the
     * reader is always closed.
     */
    private static void readMatrices(String inputfile, int[][] X, int[][] Y) {
        try (BufferedReader br = new BufferedReader(new FileReader(inputfile))) {
            for (int[][] target : new int[][][] {X, Y}) {
                for (int[] row : target) {
                    for (int col = 0; col < row.length; col++) {
                        // readLine() returns null at end of file; parseInt(null)
                        // then throws NumberFormatException, handled below.
                        row[col] = Integer.parseInt(br.readLine());
                    }
                }
            }
        } catch (IOException | NumberFormatException e) {
            System.err.println("Caught Exception: " + e.getMessage());
        }
    }

    /** Prints the crossover used and the wall-clock time elapsed since {@code startTime}. */
    private static void printTimes(String mode, long startTime, int dimension, int crossover) {
        System.out.println(mode + " Crossover = " + crossover);
        long time = System.currentTimeMillis() - startTime;
        System.out.printf("Finished Matrix Multiplication of %d dimensions in %d milliseconds, or %.2f minutes\n", dimension, time, ((double) time) / 60 / 1000);
        System.out.println();
    }

    /**
     * Returns {@code matrix} zero-padded to the next power-of-two dimension.
     * When the dimension already is a power of two, the same array is returned
     * (not a copy).
     */
    private static int[][] pad(int[][] matrix) {
        int newDim = nextPowerOf2(matrix.length);
        if (newDim == matrix.length) {
            return matrix;
        }
        int[][] ret = new int[newDim][newDim];
        for (int row = 0; row < matrix.length; row++) {
            System.arraycopy(matrix[row], 0, ret[row], 0, matrix.length);
        }
        return ret;
    }

    /** Returns a copy of the top-left {@code dim x dim} corner of {@code matrix}. */
    private static int[][] trim(int[][] matrix, int dim) {
        int[][] ret = new int[dim][dim];
        for (int row = 0; row < dim; row++) {
            System.arraycopy(matrix[row], 0, ret[row], 0, dim);
        }
        return ret;
    }

    /**
     * Returns the smallest power of two that is greater than or equal to
     * {@code length}; 0 and 1 are returned unchanged. Uses integer bit
     * operations instead of floating-point logarithms.
     */
    private static int nextPowerOf2(int length) {
        if (length <= 1) {
            return length;
        }
        int highest = Integer.highestOneBit(length);
        return (highest == length) ? length : highest << 1;
    }

    /**
     * Classic O(n^3) matrix multiplication. The dimension is taken from
     * {@code B.length}; both operands are expected to be square with that side.
     *
     * @return the product {@code A * B} as a new array
     */
    public static int[][] standardMult(int[][] A, int[][] B) {

        int dim = B.length;
        int[][] C = new int[dim][dim];
        for (int i = 0; i < dim; i++) {
            for (int j = 0; j < dim; j++) {
                for (int k = 0; k < dim; k++) {
                    if (DEBUGGING) {
                        // Trace the cell being accumulated, C[i][j], before this term is added.
                        System.out.println(C[i][j] + " += " + A[i][k] + " * " + B[k][j]);
                    }

                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        return C;
    }

    /** Prints every row of the square matrix {@code A}, one bracketed row per line. */
    public static void printMatrix(int[][] A) {
        int dim = A.length;

        for (int i = 0; i < dim; i++) {
            System.out.print(" [ ");

            for (int j = 0; j < dim; j++) {
                System.out.print(A[i][j] + " ");
            }
            System.out.println("]");
        }
        System.out.println();
    }

    /** Prints a heading containing {@code name}, followed by the complete matrix. */
    public static void printMatrix(int[][] A, String name) {
        System.out.println("Printing matrix " + name);
        printMatrix(A);
    }

    /** Prints the diagonal entries of {@code A}, one per line. */
    public static void printDiagonal(int[][] A) {
        for (int i = 0; i < A.length; i++) {
            System.out.println(A[i][i]);
        }
    }

}
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

两个对象矩阵相乘的有效方法 的相关文章

随机推荐