






# Test code follows:
from ctypes import *

import scipy.sparse as spsp

import numpy as np

import multiprocessing as mp

# Load the shared library libmkl_rt.so; the functions below call into it
# through this module-level handle.

mkl = cdll.LoadLibrary("libmkl_rt.so")

def get_csr_handle2(data, indices, indptr, shape):
    """Bundle the three CSR arrays as raw ctypes pointers plus the shape.

    No data is copied: each pointer refers to the memory of the array it
    was taken from.

    Args:
        data: array of stored values, exposed as ``POINTER(c_float)``
            (assumes float32 storage -- TODO confirm at call sites).
        indices: array of column indices, exposed as ``POINTER(c_int)``
            (assumes int32 storage -- TODO confirm at call sites).
        indptr: array of row pointers, exposed as ``POINTER(c_int)``.
        shape: matrix shape, passed through unchanged.

    Returns:
        The tuple ``(data_ptr, indices_ptr, indptr_ptr, shape)``.
    """
    float_ptr_type = POINTER(c_float)
    int_ptr_type = POINTER(c_int)

    values_ptr = data.ctypes.data_as(float_ptr_type)
    columns_ptr, rows_ptr = [
        arr.ctypes.data_as(int_ptr_type) for arr in (indices, indptr)
    ]
    return (values_ptr, columns_ptr, rows_ptr, shape)

def get_csr_handle(A, clear=False):
    """Return the ctypes handle tuple for the CSR arrays of ``A``.

    Args:
        A: object exposing ``data``, ``indices``, ``indptr`` and ``shape``
            (the caller in this file passes a ``scipy.sparse.csr_matrix``).
        clear: when truthy, overwrite ``A.indptr``, ``A.indices`` and
            ``A.data`` with zeros in place before taking the pointers.

    Returns:
        Whatever ``get_csr_handle2`` returns for A's arrays and shape, i.e.
        ``(data_ptr, indices_ptr, indptr_ptr, shape)``.
    """
    if clear:
        # Slice assignment zeroes the existing buffers rather than rebinding
        # the attributes, so the pointers taken below refer to A's own arrays.
        for buf in (A.indptr, A.indices, A.data):
            buf[:] = 0
    return get_csr_handle2(A.data, A.indices, A.indptr, A.shape)

def csr_t_dot_csr(A_handle, C_handle, nz=None):
    """Calculate ``(A.T).dot(A)`` with ``mkl_scsrmultcsr`` and put it into C.

    Both handles are ``(data_ptr, indices_ptr, indptr_ptr, shape)`` tuples as
    produced by ``get_csr_handle`` / ``get_csr_handle2``.

    Requirements carried over from the original notes on this routine:
      * the CSR arrays use one-based indexing;
      * both ``C.data`` and ``A.data`` must be ``np.float32``;
      * ``C.data`` and ``C.indices`` must have room for at least ``nz``
        entries (``nz`` is passed to MKL as the length of those arrays);
      * ``C.indptr`` must have at least ``1 + (number of rows of A)``
        entries.  NOTE(review): with trans='T' the result presumably has
        ``n`` rows, so ``n + 1`` may be the real requirement -- confirm
        against the MKL documentation.

    Args:
        A_handle: handle of the input matrix A; its shape entry is read as
            ``(m, n)``.
        C_handle: handle of the preallocated output matrix C; its shape
            entry is unpacked but not used.
        nz: capacity passed to MKL as ``nzmax``.  Defaults to ``n * n``; a
            lower value can probably be used for sparse results.

    Returns:
        A one-element list ``[(ret, info)]``.  ``info`` is the value left in
        the ``info`` argument after the call.  ``ret`` is whatever ctypes
        reports as the return value.  NOTE(review): MKL documents this
        routine as returning void, so ``ret`` is presumably meaningless --
        confirm before relying on it.
    """
    a_pointer, ja_pointer, ia_pointer, A_shape = A_handle
    # C's shape travels with the handle but is not needed for the call.
    c_pointer, jc_pointer, ic_pointer, _ = C_handle
    m, n = A_shape

    if nz is None:
        # Number of nonzero elements expected in C.
        nz = n * n

    # Every scalar argument is passed by reference, so each value gets a
    # ctypes object of its own.
    trans = c_char(b'T')  # bytes literal: accepted by both Python 2 and 3
    request = c_int(0)    # nzmax is only used when request=0
    sort = c_int(0)       # NOTE(review): meaning of 0 not documented here
    rows_a = c_int(m)     # Number of rows of matrix A
    cols_a = c_int(n)     # Number of columns of matrix A
    cols_b = c_int(n)     # Number of columns of matrix B (B is A here);
                          # should be n when trans='T'
    # Length of arrays c and jc.  The routine stops calculating if the
    # number of elements in C exceeds this value.
    nzmax = c_int(nz)
    info = c_int(-3)      # start value; read back after the call

    ret = mkl.mkl_scsrmultcsr(
        byref(trans), byref(request), byref(sort),
        byref(rows_a), byref(cols_a), byref(cols_b),
        a_pointer, ja_pointer, ia_pointer,
        a_pointer, ja_pointer, ia_pointer,  # B shares A's arrays
        c_pointer, jc_pointer, ic_pointer,
        byref(nzmax), byref(info)
    )
    return [(ret, info.value)]

def test():
    """Request 12 MKL threads, then print the maximum MKL reports."""
    num_cpu = 12
    # The thread count is handed over by reference, as in the original call.
    thread_count = c_int(num_cpu)
    mkl.mkl_set_num_threads(byref(thread_count))  # try to set number of mkl threads
    # Single-argument parenthesized print emits the same text on Python 2
    # and Python 3.
    print("mkl get max thread: %s" % (mkl.mkl_get_max_threads(),))


def test_csr_t_dot_csr():
    """Call ``csr_t_dot_csr`` in an endless loop on a random sparse matrix.

    NOTE(review): the final loop never terminates and its results are
    discarded; this looks like a deliberate stress/reproduction loop --
    confirm before adding an exit condition.
    """
    # Random 12 x 750000 matrix of 0/1 values, with 1 drawn at probability 0.01.
    AA = np.random.choice([0, 1], size=(12, 750000), replace=True,
                          p=[0.99, 0.01])
    A_original = spsp.csr_matrix(AA)

    # Feed the CSC arrays to the CSR constructor: the same three arrays read
    # as CSR describe the transpose of A_original.  The shape is left for
    # scipy to infer.
    A = A_original.astype(np.float32).tocsc()
    A = spsp.csr_matrix((A.data, A.indices, A.indptr))
    A.indptr += 1   # convert to 1-based indexing
    A.indices += 1  # convert to 1-based indexing
    A_ptrs = get_csr_handle(A)

    # An all-ones 12 x 12 matrix is used to obtain output buffers;
    # clear=True then zeroes them in place before the pointers are taken.
    C = spsp.csr_matrix(np.ones((12, 12)), dtype=np.float32)
    C_ptrs = get_csr_handle(C, clear=True)

    print("=call mkl function=")
    while True:
        csr_t_dot_csr(A_ptrs, C_ptrs)

if __name__ == "__main__":


