//#include "mini_blas.h"

/*
 * Naive single-precision GEMM kernels.
 *
 * All matrices are indexed row-major with an explicit leading dimension
 * (lda / ldb / ldc = distance in elements between consecutive rows).
 * The four cpu_gemm_XX kernels only ACCUMULATE ALPHA * op(A) * op(B) into C;
 * they never read BETA.  Scaling of C by BETA is done once by cpu_gemm()
 * before it dispatches to a kernel.  TA, TB and BETA are accepted by the
 * kernels only so that all five functions share one signature.
 */

/*
 * C += ALPHA * A * B           (neither operand transposed)
 *
 * A is read as A[i*lda + k] (M rows, K columns),
 * B is read as B[k*ldb + j] (K rows, N columns),
 * C is updated at C[i*ldc + j] (M rows, N columns).
 */
void cpu_gemm_nn(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    int i, j, k;

    (void)TA;
    (void)TB;
    (void)BETA;

    for (i = 0; i < M; ++i) {
        for (k = 0; k < K; ++k) {
            /* ALPHA * A[i][k] is invariant over j: hoist it so the inner
             * loop walks one row of B and one row of C contiguously. */
            float a_part = ALPHA * A[i * lda + k];
            for (j = 0; j < N; ++j) {
                C[i * ldc + j] += a_part * B[k * ldb + j];
            }
        }
    }
}

/*
 * C += ALPHA * A * B^T         (B transposed)
 *
 * A is read as A[i*lda + k] (M rows, K columns),
 * B is read as B[j*ldb + k] (stored as N rows, K columns),
 * C is updated at C[i*ldc + j].
 */
void cpu_gemm_nt(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    int i, j, k;

    (void)TA;
    (void)TB;
    (void)BETA;

    for (i = 0; i < M; ++i) {
        for (j = 0; j < N; ++j) {
            /* Row i of A dotted with row j of B; accumulated locally and
             * added to C once per output element. */
            float sum = 0;
            for (k = 0; k < K; ++k) {
                sum += ALPHA * A[i * lda + k] * B[k + j * ldb];
            }
            C[i * ldc + j] += sum;
        }
    }
}

/*
 * C += ALPHA * A^T * B         (A transposed)
 *
 * A is read as A[k*lda + i] (stored as K rows, M columns),
 * B is read as B[k*ldb + j] (K rows, N columns),
 * C is updated at C[i*ldc + j].
 */
void cpu_gemm_tn(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    int i, j, k;

    (void)TA;
    (void)TB;
    (void)BETA;

    for (i = 0; i < M; ++i) {
        for (k = 0; k < K; ++k) {
            /* ALPHA * A^T[i][k] is invariant over j. */
            float a_part = ALPHA * A[k * lda + i];
            for (j = 0; j < N; ++j) {
                C[i * ldc + j] += a_part * B[k * ldb + j];
            }
        }
    }
}

/*
 * C += ALPHA * A^T * B^T       (both operands transposed)
 *
 * A is read as A[k*lda + i] (stored as K rows, M columns),
 * B is read as B[j*ldb + k] (stored as N rows, K columns),
 * C is updated at C[i*ldc + j].
 */
void cpu_gemm_tt(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    int i, j, k;

    (void)TA;
    (void)TB;
    (void)BETA;

    for (i = 0; i < M; ++i) {
        for (j = 0; j < N; ++j) {
            /* Accumulates term-by-term directly into C (unlike the _nt
             * kernel, which sums locally first); kept this way so the
             * floating-point rounding sequence is unchanged. */
            for (k = 0; k < K; ++k) {
                C[i * ldc + j] += ALPHA * A[i + k * lda] * B[k + j * ldb];
            }
        }
    }
}

/*
 * General matrix multiply:  C = BETA * C + ALPHA * op(A) * op(B)
 *
 * op(X) is X when the matching flag (TA for A, TB for B) is zero and X^T
 * when it is non-zero.  The result C has M rows and N columns; K is the
 * shared inner dimension.  lda, ldb and ldc are the row strides (in
 * elements) of A, B and C as stored.
 *
 * Only the M x N window of C addressed through ldc is touched; any padding
 * elements between rows (when ldc > N) are left as they were.
 *
 * When BETA is exactly zero C is overwritten rather than multiplied, so the
 * previous contents of C (including NaN/Inf or uninitialized values) cannot
 * reach the result.  When BETA is exactly one the scaling pass is skipped.
 */
void cpu_gemm(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    int i, j;

    if (BETA == 0.0f) {
        /* NaN * 0 is NaN, so multiplying would let stale garbage in C
         * poison the output; store zeros instead. */
        for (i = 0; i < M; ++i) {
            for (j = 0; j < N; ++j) {
                C[i * ldc + j] = 0.0f;
            }
        }
    } else if (BETA != 1.0f) {
        for (i = 0; i < M; ++i) {
            for (j = 0; j < N; ++j) {
                C[i * ldc + j] *= BETA;
            }
        }
    }

    /* C now holds BETA * C; the kernels add ALPHA * op(A) * op(B) to it. */
    if (!TA && !TB) {
        cpu_gemm_nn(TA, TB, M, N, K, ALPHA, A, lda, B, ldb, BETA, C, ldc);
    } else if (TA && !TB) {
        cpu_gemm_tn(TA, TB, M, N, K, ALPHA, A, lda, B, ldb, BETA, C, ldc);
    } else if (!TA && TB) {
        cpu_gemm_nt(TA, TB, M, N, K, ALPHA, A, lda, B, ldb, BETA, C, ldc);
    } else {
        cpu_gemm_tt(TA, TB, M, N, K, ALPHA, A, lda, B, ldb, BETA, C, ldc);
    }
}