import time
import numpy as np
from numba import cuda,float32


@cuda.jit
def NasobeniMaticeCuda(A, B, C):
    """Naive matrix-product kernel: C = A @ B, one thread per element of C.

    Launch on a 2D grid covering C. Thread (row, col), as given by
    cuda.grid(2), computes the dot product of row ``row`` of A with column
    ``col`` of B and stores it in C[row, col]. Threads that fall outside C
    do nothing. A, B and C are indexed as 2D arrays; C is written in place.
    No shared memory is used, so every product term is read from A and B
    directly.
    """
    row, col = cuda.grid(2)

    # The launch grid may be larger than C; surplus threads exit here.
    if row >= C.shape[0] or col >= C.shape[1]:
        return

    acc = 0.0
    for k in range(A.shape[1]):
        acc += A[row, k] * B[k, col]
    C[row, col] = acc



# Tiled matrix multiplication using per-block shared memory.
@cuda.jit
def NasobeniMaticeCuda2(A, B, C):
    """Compute C = A @ B with a shared-memory tiled CUDA kernel.

    Each thread computes one element C[x, y], where (x, y) = cuda.grid(2).
    The inner dimension is processed in TPB-wide tiles. For each tile, every
    thread of a TPB x TPB block copies one element of A and one of B into
    shared memory. The block then synchronizes, and each thread accumulates
    TPB partial products for its own element.

    Launch with blocks of shape (TPB, TPB) and a 2D grid that covers C.
    Threads outside C still take part in loading tiles, which is required
    because all threads must reach cuda.syncthreads(), but they write nothing.

    Parameters
    ----------
    A : 2D array, shape (m, n)
    B : 2D array, shape (n, p)
    C : 2D array, shape (m, p); written in place

    Notes
    -----
    TPB comes from the module-level global ``ThreadsPerBlock``. The
    shared-array shape must be a compile-time constant, and numba freezes
    global values when the kernel is compiled. Tiles are staged as float32,
    so float64 inputs are rounded to single precision before multiplying.
    """
    TPB = ThreadsPerBlock
    sA = cuda.shared.array(shape=(TPB, TPB), dtype=float32)
    sB = cuda.shared.array(shape=(TPB, TPB), dtype=float32)

    x, y = cuda.grid(2)
    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y

    # Tiles needed to span the inner dimension (ceiling division). This is
    # independent of the grid size, so rectangular inputs work too.
    n_tiles = (A.shape[1] + TPB - 1) // TPB

    tmp = 0.
    for i in range(n_tiles):
        a_col = ty + i * TPB
        b_row = tx + i * TPB
        # No early return: every thread must reach both barriers below.
        # Out-of-range slots are zero-filled so they add nothing to the sum
        # and no thread reads past the ends of A or B.
        if x < A.shape[0] and a_col < A.shape[1]:
            sA[tx, ty] = A[x, a_col]
        else:
            sA[tx, ty] = 0.
        if y < B.shape[1] and b_row < B.shape[0]:
            sB[tx, ty] = B[b_row, y]
        else:
            sB[tx, ty] = 0.

        # Tile fully loaded before any thread reads it.
        cuda.syncthreads()

        for j in range(TPB):
            tmp += sA[tx, j] * sB[j, ty]

        # All threads done reading before the next iteration overwrites it.
        cuda.syncthreads()

    if x < C.shape[0] and y < C.shape[1]:
        C[x, y] = tmp


N = 1000
# Edge of a square CUDA block (ThreadsPerBlock x ThreadsPerBlock threads).
# NasobeniMaticeCuda2 reads this global as its shared-memory tile size, so it
# must be a plain int defined before the first launch. 8 divides N exactly.
ThreadsPerBlock = 8
# Blocks per grid axis, derived from the block size by ceiling division so
# the grid covers all N rows and columns. Deriving the block size instead,
# as int(N / blocks), truncates and leaves part of C uncovered.
blocksPerGrid = (N + ThreadsPerBlock - 1) // ThreadsPerBlock

#A(m,n)
A = np.random.randint(low=1, high=5, size=(N, N))
#B(n,p)
B = np.random.randint(low=1, high=5, size=(N, N))

#------------------------------------------------------------------------------
# CPU reference: C(m,p) = A @ B
tic = time.perf_counter()
C = np.matmul(A, B)
toc = time.perf_counter()
print('elapsed time matmul function', toc - tic)
#------------------------------------------------------------------------------

#------------------------------------------------------------------------------
gridShape = (blocksPerGrid, blocksPerGrid)
blockShape = (ThreadsPerBlock, ThreadsPerBlock)

# Host buffer for the GPU result (float64, np.zeros default).
Cgpu = np.zeros((A.shape[0], B.shape[1]))

# Untimed warm-up launch: numba compiles the kernel on its first call, and
# that compilation time would otherwise be counted in the measurement below.
Ad = cuda.to_device(A)
Bd = cuda.to_device(B)
Cd = cuda.to_device(Cgpu)
NasobeniMaticeCuda2[gridShape, blockShape](Ad, Bd, Cd)
cuda.synchronize()

# Timed run: host->device transfers, kernel, and device->host copy.
tic = time.perf_counter()
Ad = cuda.to_device(A)
Bd = cuda.to_device(B)
Cd = cuda.to_device(Cgpu)
# Alternative naive kernel:
# NasobeniMaticeCuda[gridShape, blockShape](Ad, Bd, Cd)
NasobeniMaticeCuda2[gridShape, blockShape](Ad, Bd, Cd)
Cgpu = Cd.copy_to_host()
toc = time.perf_counter()
print('elapsed time CUDA function', toc - tic)
#------------------------------------------------------------------------------

# Largest absolute element-wise error. Unlike summing signed differences,
# errors of opposite sign cannot cancel each other out.
print('max abs difference', np.abs(Cgpu - C).max())

