本文整理匯總了Python中torch.diag方法的典型用法代碼示例。如果您正苦於以下問題:Python torch.diag方法的具體用法?Python torch.diag怎麽用?Python torch.diag使用的例子?那麽恭喜您, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊torch的用法示例。


示例1: fuse_conv_and_bn

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def fuse_conv_and_bn(conv, bn):
    """Fold a batch-norm layer into the convolution that precedes it.

    The returned convolution computes ``bn(conv(x))`` using the batch-norm
    running statistics (``bn.running_mean`` / ``bn.running_var``), i.e. the
    behaviour batch-norm has in evaluation mode.

    :param conv: a ``torch.nn.Conv2d`` layer (with or without bias)
    :param bn: the ``torch.nn.BatchNorm2d`` layer applied to ``conv``'s output;
        its ``weight`` and ``bias`` are read, so it must be affine
    :return: a new ``torch.nn.Conv2d`` with bias holding the fused parameters
    """
    # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    with torch.no_grad():
        # init
        fusedconv = torch.nn.Conv2d(conv.in_channels,
                                    conv.out_channels,
                                    kernel_size=conv.kernel_size,
                                    stride=conv.stride,
                                    padding=conv.padding,
                                    dilation=conv.dilation,
                                    groups=conv.groups,
                                    bias=True).to(conv.weight.device)

        # per-output-channel batch-norm scale: gamma / sqrt(var + eps)
        scale = bn.weight.div(torch.sqrt(bn.eps + bn.running_var))

        # prepare filters
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(scale)
        fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

        # prepare spatial bias
        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = torch.zeros(conv.weight.size(0),
                                 dtype=conv.weight.dtype,
                                 device=conv.weight.device)
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
        # The convolution bias passes through batch-norm too, so it has to be
        # multiplied by the same scale before the batch-norm shift is added.
        fusedconv.bias.copy_(scale * b_conv + b_bn)

        return fusedconv


示例2: _mix_rbf_kernel

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def _mix_rbf_kernel(X, Y, sigma_list):
    """Evaluate a sum of RBF kernels on two equally sized sample sets.

    For every ``sigma`` in ``sigma_list`` the kernel
    ``exp(-||a - b||^2 / (2 * sigma^2))`` is evaluated on all pairs of rows of
    ``X`` and ``Y`` stacked together, and the results are summed.

    :param X: 2-D tensor with ``m`` rows
    :param Y: 2-D tensor with the same number of rows as ``X``
    :param sigma_list: non-empty iterable of bandwidths (with an empty list
        ``K`` stays a float and the final slicing fails)
    :return: ``(K_XX, K_XY, K_YY, len(sigma_list))``
    """
    assert(X.size(0) == Y.size(0))
    m = X.size(0)

    # Squared pairwise distances through ||a||^2 - 2 a.b + ||b||^2.
    Z = torch.cat((X, Y), 0)
    ZZT = torch.mm(Z, Z.t())
    diag_ZZT = torch.diag(ZZT).unsqueeze(1)
    Z_norm_sqr = diag_ZZT.expand_as(ZZT)
    exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()

    K = 0.0
    for sigma in sigma_list:
        gamma = 1.0 / (2 * sigma**2)
        K += torch.exp(-gamma * exponent)

    return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list)


示例3: construct_diag

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def construct_diag(x: torch.Tensor):
    """
    Constructs a diagonal matrix based on batched data.
    Do note that it only considers the last axis.

    :param x: The tensor
    :return: ``x`` itself for 0-d input, ``x.unsqueeze(-1)`` when the last axis
        has fewer than two elements, ``torch.diag(x)`` for 1-d input, and
        otherwise a tensor of shape ``(*x.shape, x.shape[-1])`` holding the
        last axis of ``x`` on the diagonal of the two trailing axes
    """
    if x.dim() < 1:
        return x
    elif x.shape[-1] < 2:
        return x.unsqueeze(-1)
    elif x.dim() < 2:
        return torch.diag(x)

    # The identity uses the dtype of x so the result has the same dtype as
    # the torch.diag branch above instead of being promoted to float32.
    b = torch.eye(x.size(-1), dtype=x.dtype, device=x.device)
    c = x.unsqueeze(-1).expand(*x.size(), x.size(-1))

    return c * b


示例4: test_UnscentedTransform2D

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def test_UnscentedTransform2D(self):
    """Run one predict/correct cycle of the unscented transform on a 2-D model.

    Builds a state-space model from helpers defined elsewhere in the project
    (``fmvn``, ``g``, ``fomvn``, ``gomvn`` and the process/observation
    classes) and checks the type and the mean shape of the corrected state
    distribution.
    """
    # ===== 2D model ===== #
    mat = torch.eye(2)
    scale = torch.diag(mat)

    norm = Normal(0., 1.)
    mvn = MultivariateNormal(torch.zeros(2), torch.eye(2))
    mvnlinear = AffineProcess((fmvn, g), (mat, scale), mvn, mvn)
    mvnoblinear = AffineObservations((fomvn, gomvn), (1.,), norm)

    mvnmodel = StateSpaceModel(mvnlinear, mvnoblinear)

    # ===== Perform unscented transform ===== #
    uft = UnscentedFilterTransform(mvnmodel)
    res = uft.initialize(3000)
    p = uft.predict(res)
    c = uft.correct(0., p)

    assert isinstance(c.x_dist(), MultivariateNormal) and c.x_dist().mean.shape == torch.Size([3000, 2])


示例5: __init__

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

# NOTE(review): the body of this constructor is missing from the excerpt (only
# the signature survived extraction); recover it from the original project
# before using this example.
def __init__(self,in_channel):





















示例6: forward

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def forward(self, input):
    """Transform a batch of square score matrices, one matrix at a time.

    For every batch element a Laplacian-like matrix is built from
    ``exp(input) + self.eps`` with its diagonal zeroed, its first row is
    overwritten with ``exp(diag(input[b]))`` ("roots"), the matrix is inverted
    and the result is combined with ``exp(input[b])``.
    NOTE(review): this looks like the matrix-tree-theorem marginal
    computation used for structured attention -- confirm against the caller.

    :param input: tensor indexed as ``(batch, n, n)``
    :return: tensor with the same shape as ``input``
    """
    laplacian = input.exp() + self.eps
    output = input.clone()
    # Diagonal mask on the same device as the input (no hard-coded CUDA);
    # it does not depend on the batch element, so it is built once.
    diag_mask = torch.eye(input.size(1), device=input.device).ne(0)
    for b in range(input.size(0)):
        lap = laplacian[b].masked_fill(diag_mask, 0)
        lap = -lap + torch.diag(lap.sum(0))
        # store roots on diagonal
        lap[0] = input[b].diag().exp()
        inv_laplacian = lap.inverse()

        factor = inv_laplacian.diag().unsqueeze(1)\
            .expand_as(input[b]).transpose(0, 1)
        term1 = input[b].exp().mul(factor).clone()
        term2 = input[b].exp().mul(inv_laplacian.transpose(0, 1)).clone()
        term1[:, 0] = 0
        term2[0] = 0
        output[b] = term1 - term2
        roots_output = input[b].diag().exp().mul(
            inv_laplacian.transpose(0, 1)[0])
        output[b] = output[b] + torch.diag(roots_output)
    return output


示例7: __expm__

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def __expm__(self, matrix, symmetric):
    r"""Calculates matrix exponential.

    Args:
        matrix (Tensor): Matrix to take exponential of.
        symmetric (bool): Specifies whether the matrix is symmetric.

    :rtype: (:class:`Tensor`)
    """
    if symmetric:
        # torch.symeig was removed from recent PyTorch releases; torch.linalg.eigh
        # with UPLO="U" reads the same (upper) triangle symeig used by default.
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
            e, V = torch.linalg.eigh(matrix, UPLO="U")
        else:
            e, V = torch.symeig(matrix, eigenvectors=True)
        diff_mat = V @ torch.diag(e.exp()) @ V.t()
    else:
        # `expm` comes from the enclosing module (presumably scipy.linalg.expm
        # -- TODO confirm); the result is converted with torch.Tensor(...).
        diff_mat_np = expm(matrix.cpu().numpy())
        diff_mat = torch.Tensor(diff_mat_np).to(matrix.device)
    return diff_mat


示例8: regularizer_orth2

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def regularizer_orth2(m):
    """Nudge the singular values of a convolution weight towards their mean.

    Applies regularization to the training by performing the
    orthogonalization technique described in the paper.
    This function is to be called by the torch.nn.Module.apply() method,
    which applies it to every layer of the model.

    usage: net.apply(regularizer_orth2)

    Only modules whose class name contains ``'Conv'`` are touched; their
    weight is unpacked as ``(c_out, c_in, f1, f2)`` and replaced in place.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        w = m.weight.data.clone()
        c_out, c_in, f1, f2 = w.size()
        # dtype = m.weight.data.type()
        w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
        u, s, v = torch.svd(w)
        s_mean = s.mean()
        # Shrink large and grow small singular values by a fixed tiny step.
        s[s > 1.5*s_mean] = s[s > 1.5*s_mean] - 1e-4
        s[s < 0.5*s_mean] = s[s < 0.5*s_mean] + 1e-4
        w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
        # permute() returns a strided view; store a contiguous weight again.
        m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1).contiguous()  # .type(dtype)




示例9: get_loadings

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def get_loadings(self) -> np.ndarray:
    """Extract per-gene weights (for each Z, shape is genes by dim(Z)) in the linear decoder.

    When ``self.use_batch_norm`` is ``True`` the weight of the first layer is
    rescaled row-wise by the batch-norm factor ``gamma / sqrt(running_var + eps)``.
    When ``self.n_batch > 1`` the last ``n_batch`` columns are dropped.
    """
    # This is BW, where B is diag(b) batch norm, W is weight matrix
    if self.use_batch_norm is True:
        w = self.decoder.factor_regressor.fc_layers[0][0].weight
        bn = self.decoder.factor_regressor.fc_layers[0][1]
        sigma = torch.sqrt(bn.running_var + bn.eps)
        gamma = bn.weight
        b = gamma / sigma
        bI = torch.diag(b)
        loadings = torch.matmul(bI, w)
    else:
        loadings = self.decoder.factor_regressor.fc_layers[0][0].weight
    loadings = loadings.detach().cpu().numpy()
    if self.n_batch > 1:
        # presumably the trailing columns are batch covariates -- TODO confirm
        loadings = loadings[:, : -self.n_batch]

    return loadings


示例10: _regularizer

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def _regularizer(self, mu_z, std_z):
    """Return the KL term scaled by ``self.beta`` plus the DIP-VAE covariance penalty.

    :param mu_z: posterior means
    :param std_z: posterior standard deviations
    :return: ``beta * KL + dipvae_regularizer``
    :raises NotImplementedError: if ``self.mode`` is neither ``"i"`` nor ``"ii"``
    """
    # KL(N(mu, std^2) || N(prior_mean, prior_stdv^2)); the squared term is
    # (mu - prior_mean)^2, which only reduces to mu^2 for a zero prior mean.
    kld_z = self.z_prior_stdv.log() - std_z.log() + (std_z ** 2 + (mu_z - self.z_prior_mean).pow(2)) /\
        (2 * self.z_prior_stdv.pow(2)) - 0.5
    regularizer_loss = kld_z.sum()
    regularizer_loss = self.beta * regularizer_loss

    cov_mu_z = self._get_covariance_mu_z(mu_z)
    if self.mode == "i":
        dipvae_regularizer_loss = self._get_dipvae_regularizer(cov_mu_z, self.lambda_offdiag, self.lambda_diag)
    elif self.mode == "ii":
        # NOTE(review): if std_z is 2-D (batch, latent), torch.diag extracts a
        # diagonal instead of building a matrix -- confirm the intended shape.
        cov_z = cov_mu_z + torch.mean(torch.diag(std_z**2), dim=0)
        dipvae_regularizer_loss = self._get_dipvae_regularizer(cov_z, self.lambda_offdiag, self.lambda_diag)
    else:
        raise NotImplementedError("Unsupported dipvae mode.")

    return regularizer_loss + dipvae_regularizer_loss


示例11: __init__

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def __init__(self, num_inputs):
    """Set up an LU-parameterised square weight of size ``num_inputs``.

    A ``num_inputs x num_inputs`` matrix is factorised with
    ``sp.linalg.lu``; ``L``, ``U`` and the log-magnitude of ``U``'s diagonal
    become parameters, while the permutation, the sign of the diagonal, the
    triangular masks and an identity matrix are kept as plain tensors.

    :param num_inputs: size of the square weight matrix
    """
    super(LUInvertibleMM, self).__init__()
    self.W = torch.Tensor(num_inputs, num_inputs)
    # torch.Tensor(n, n) only allocates memory; give W well-conditioned
    # values before it is factorised (log|diag(U)| is taken below).
    nn.init.orthogonal_(self.W)
    self.L_mask = torch.tril(torch.ones(self.W.size()), -1)
    self.U_mask = self.L_mask.t().clone()

    P, L, U = sp.linalg.lu(self.W.numpy())
    self.P = torch.from_numpy(P)
    self.L = nn.Parameter(torch.from_numpy(L))
    self.U = nn.Parameter(torch.from_numpy(U))

    S = np.diag(U)
    sign_S = np.sign(S)
    log_S = np.log(abs(S))
    self.sign_S = torch.from_numpy(sign_S)
    self.log_S = nn.Parameter(torch.from_numpy(log_S))

    self.I = torch.eye(self.L.size(0))


示例12: forward

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def forward(self, inputs, cond_inputs=None, mode='direct'):
    """Apply the LU-parameterised linear map or its inverse.

    ``W = P @ (L * L_mask + I) @ (U * U_mask + diag(sign_S * exp(log_S)))``.

    :param inputs: tensor whose rows are multiplied by ``W`` (or ``W^-1``)
    :param cond_inputs: not used by this method
    :param mode: ``'direct'`` applies ``W``; any other value applies the inverse
    :return: ``(outputs, logdet)`` where ``logdet`` has shape
        ``(inputs.size(0), 1)`` and holds ``sum(log_S)`` (negated for the inverse)
    """
    # The helper tensors are plain attributes, so they are moved lazily to
    # the device of the parameters.
    if str(self.L_mask.device) != str(self.L.device):
        self.L_mask = self.L_mask.to(self.L.device)
        self.U_mask = self.U_mask.to(self.L.device)
        self.I = self.I.to(self.L.device)
        self.P = self.P.to(self.L.device)
        self.sign_S = self.sign_S.to(self.L.device)

    L = self.L * self.L_mask + self.I
    U = self.U * self.U_mask + torch.diag(
        self.sign_S * torch.exp(self.log_S))
    W = self.P @ L @ U

    if mode == 'direct':
        return inputs @ W, self.log_S.sum().unsqueeze(0).unsqueeze(
            0).repeat(inputs.size(0), 1)
    else:
        return inputs @ torch.inverse(
            W), -self.log_S.sum().unsqueeze(0).unsqueeze(0).repeat(
                inputs.size(0), 1)


示例13: test_np

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def test_np():
    """Check that ``block`` assembles the same matrix as ``np.bmat``.

    A 4x4 block matrix is built twice from random sub-matrices: once with
    explicit zero/identity blocks through ``np.bmat`` and once through
    ``block`` using ``0`` and ``'I'`` as placeholders.
    """
    # Fixed seed so a failure is reproducible.
    npr.seed(0)

    nx, nineq, neq = 4, 6, 7
    Q = npr.randn(nx, nx)
    G = npr.randn(nineq, nx)
    A = npr.randn(neq, nx)
    D = np.diag(npr.rand(nineq))

    K_ = np.bmat((
        (Q, np.zeros((nx, nineq)), G.T, A.T),
        (np.zeros((nineq, nx)), D, np.eye(nineq), np.zeros((nineq, neq))),
        (G, np.eye(nineq), np.zeros((nineq, nineq + neq))),
        (A, np.zeros((neq, nineq + nineq + neq)))
    ))

    K = block((
        (Q, 0, G.T, A.T),
        (0, D, 'I', 0),
        (G, 'I', 0, 0),
        (A, 0, 0, 0)
    ))

    assert np.allclose(K_, K)


示例14: loss_l2

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def loss_l2(self, l2=0):
    """L2 loss centered around mu_init, scaled optionally per-source.

    In other words, diagonal Tikhonov regularization,
        ||D(mu - mu_init)||^2
    where D is diagonal.

    Args:
        - l2: A float or np.array representing the per-source regularization
            strengths to use

    Returns:
        The squared Frobenius norm of ``D @ (self.mu - self.mu_init)``.
    """
    diff = self.mu - self.mu_init
    if isinstance(l2, (int, float)):
        D = l2 * torch.eye(self.d, dtype=diff.dtype, device=diff.device)
    else:
        # numpy arrays default to float64; cast to the dtype/device of mu so
        # the matrix product below does not fail on a dtype mismatch.
        D = torch.diag(torch.as_tensor(l2, dtype=diff.dtype, device=diff.device))

    # Note that mu is a matrix and this is the *Frobenius norm*
    return torch.norm(D @ diff) ** 2


示例15: _set_class_balance

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def _set_class_balance(self, class_balance, Y_dev):
    """Set a prior for the class balance

    In order of preference:
    1) Use user-provided class_balance
    2) Estimate balance from Y_dev
    3) Assume uniform class distribution

    Stores the prior as the numpy vector ``self.p`` and as the float32
    diagonal matrix ``self.P``.
    """
    if class_balance is not None:
        self.p = np.array(class_balance)
    elif Y_dev is not None:
        class_counts = Counter(Y_dev)
        # Counts ordered by label value; labels absent from Y_dev get no entry.
        sorted_counts = np.array([count for _, count in sorted(class_counts.items())])
        self.p = sorted_counts / sum(sorted_counts)
    else:
        self.p = (1 / self.k) * np.ones(self.k)
    self.P = torch.diag(torch.from_numpy(self.p)).float()


示例16: __init__

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def __init__(self, params, eps=1e-2):
    """Pre-build the constant tensors of the newsvendor quadratic program.

    All tensors are created on the default CUDA device.

    :param params: mapping with the keys ``'d'``, ``'c_quad'``, ``'b_quad'``,
        ``'h_quad'``, ``'c_lin'``, ``'b_lin'`` and ``'h_lin'``; ``params['d']``
        is negated and passed to ``np.concatenate``, so it is expected to be
        a numpy array
    :param eps: factor of the identity stored in ``self.eps_eye``
    """
    super(SolveNewsvendor, self).__init__()
    k = len(params['d'])
    # Diagonal quadratic cost: one entry for c, then k for b and k for h.
    self.Q = Variable(torch.diag(torch.Tensor(
        [params['c_quad']] + [params['b_quad']]*k + [params['h_quad']]*k)) \
        .cuda())
    self.p = Variable(torch.Tensor(
        [params['c_lin']] + [params['b_lin']]*k + [params['h_lin']]*k) \
        .cuda())
    self.G = Variable(torch.cat([
        torch.cat([-torch.ones(k,1), -torch.eye(k), torch.zeros(k,k)], 1),
        torch.cat([torch.ones(k,1), torch.zeros(k,k), -torch.eye(k)], 1),
        -torch.eye(1 + 2*k)], 0).cuda())
    self.h = Variable(torch.Tensor(
        np.concatenate([-params['d'], params['d'], np.zeros(1+ 2*k)])).cuda())

    self.one = Variable(torch.Tensor([1])).cuda()
    self.eps_eye = eps * Variable(torch.eye(1 + 2*k).cuda()).unsqueeze(0)


示例17: forward

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def forward(self, y):
    """Solve one quadratic program per row of ``y`` and return its first variable.

    The constant cost terms ``self.Q`` / ``self.p`` are rescaled per sample by
    ``[1, y, y]`` and handed to ``QPFunction`` together with the shared
    inequality constraints ``self.G`` / ``self.h`` (no equality constraints).

    :param y: 2-D tensor ``(nBatch, k)``
    :return: tensor ``(nBatch, 1)`` with the first column of the QP solution
    """
    nBatch, k = y.size()

    # [1, y_i, y_i] per sample; its batched diagonal is the scale for Q.
    # diag_embed replaces the per-sample torch.diag loop and also handles an
    # empty batch.
    p_scale = torch.cat([y.new_ones(nBatch, 1), y, y], 1)
    Q_scale = torch.diag_embed(p_scale)
    Q = self.Q.unsqueeze(0).expand_as(Q_scale).mul(Q_scale)
    p = self.p.unsqueeze(0).expand_as(p_scale).mul(p_scale)

    G = self.G.unsqueeze(0).expand(nBatch, self.G.size(0), self.G.size(1))
    h = self.h.unsqueeze(0).expand(nBatch, self.h.size(0))
    # Empty tensor passed twice to QPFunction in place of equality constraints.
    e = torch.empty(0, dtype=torch.double, device=y.device)

    out = QPFunction(verbose=False)\
        (Q.double(), p.double(), G.double(), h.double(), e, e).float()

    return out[:,:1]


示例18: forward

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def forward(self, y):
    """Solve one quadratic program per row of ``y`` and return its first variable.

    Same construction as the plain newsvendor layer, except that ``1e-8`` is
    added to ``y`` in the quadratic scale so the diagonal of ``Q`` does not
    become exactly zero where ``y`` is zero.

    :param y: 2-D tensor ``(nBatch, k)``
    :return: tensor ``(nBatch, 1)`` with the first column of the QP solution
    """
    nBatch, k = y.size()

    eps2 = 1e-8
    ones = y.new_ones(nBatch, 1)
    # Batched diagonal of [1, y_i + eps2, y_i + eps2]; diag_embed replaces the
    # per-sample torch.diag loop and also handles an empty batch.
    Q_scale = torch.diag_embed(torch.cat([ones, y + eps2, y + eps2], 1))
    Q = self.Q.unsqueeze(0).expand_as(Q_scale).mul(Q_scale)
    p_scale = torch.cat([ones, y, y], 1)
    p = self.p.unsqueeze(0).expand_as(p_scale).mul(p_scale)

    G = self.G.unsqueeze(0).expand(nBatch, self.G.size(0), self.G.size(1))
    h = self.h.unsqueeze(0).expand(nBatch, self.h.size(0))
    # Empty tensor passed twice to QPFunction in place of equality constraints.
    e = torch.empty(0, dtype=torch.double, device=y.device)

    out = QPFunction(verbose=False)\
        (Q.double(), p.double(), G.double(), h.double(), e, e).float()

    return out[:,:1]


示例19: symsqrt

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def symsqrt(a, cond=None, return_rank=False, dtype=torch.float32):
    """Symmetric square root of a positive semi-definite matrix.

    See https://github.com/pytorch/pytorch/issues/25481

    :param a: symmetric matrix
    :param cond: relative cutoff for eigenvalues; ``None`` or ``-1`` selects a
        default that depends on ``dtype``
    :param return_rank: also return the number of eigenvalues kept
    :param dtype: only used to pick the default ``cond``
    :return: ``B`` with ``B @ B ~= a``, or ``(B, rank)`` if ``return_rank``
    """
    # torch.symeig was removed from recent PyTorch releases; torch.linalg.eigh
    # with UPLO="U" reads the same (upper) triangle symeig used by default.
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
        s, u = torch.linalg.eigh(a, UPLO="U")
    else:
        s, u = torch.symeig(a, eigenvectors=True)
    cond_dict = {torch.float32: 1e3 * 1.1920929e-07, torch.float64: 1E6 * 2.220446049250313e-16}

    if cond in [None, -1]:
        cond = cond_dict[dtype]

    # Drop eigenvalues that are negligible relative to the largest one.
    above_cutoff = (abs(s) > cond * torch.max(abs(s)))

    psigma_diag = torch.sqrt(s[above_cutoff])
    u = u[:, above_cutoff]

    B = u @ torch.diag(psigma_diag) @ u.t()
    if return_rank:
        return B, len(psigma_diag)
    else:
        return B


示例20: forward

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def forward(self, input):
    """Transform a batch of square score matrices, one matrix at a time.

    For every batch element a Laplacian-like matrix is built from
    ``exp(input) + self.eps`` with its diagonal zeroed, its first row is
    overwritten with ``exp(diag(input[b]))`` ("roots"), the matrix is inverted
    and the result is combined with ``exp(input[b])``.
    NOTE(review): this looks like the matrix-tree-theorem marginal
    computation used for structured attention -- confirm against the caller.

    :param input: tensor indexed as ``(batch, n, n)``
    :return: tensor with the same shape as ``input``
    """
    laplacian = input.exp() + self.eps
    output = input.clone()
    # Diagonal mask on the same device as the input (no hard-coded CUDA);
    # it does not depend on the batch element, so it is built once.
    diag_mask = torch.eye(input.size(1), device=input.device).ne(0)
    for b in range(input.size(0)):
        lap = laplacian[b].masked_fill(diag_mask, 0)
        lap = -lap + torch.diag(lap.sum(0))
        # store roots on diagonal
        lap[0] = input[b].diag().exp()
        inv_laplacian = lap.inverse()

        factor = inv_laplacian.diag().unsqueeze(1)\
            .expand_as(input[b]).transpose(0, 1)
        term1 = input[b].exp().mul(factor).clone()
        term2 = input[b].exp().mul(inv_laplacian.transpose(0, 1)).clone()
        term1[:, 0] = 0
        term2[0] = 0
        output[b] = term1 - term2
        roots_output = input[b].diag().exp().mul(
            inv_laplacian.transpose(0, 1)[0])
        output[b] = output[b] + torch.diag(roots_output)
    return output


示例21: normalize_adj_tensor

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def normalize_adj_tensor(adj, sparse=False):
    """Normalize adjacency tensor matrix.

    Computes ``D^-1/2 (A + I) D^-1/2`` where ``D`` holds the row sums of
    ``A + I``.

    :param adj: square adjacency tensor
    :param sparse: if true, convert through the scipy helpers of this module
        and return a sparse tensor; otherwise work on the dense tensor
    :return: the normalized adjacency on the device of ``adj``
    """
    # Use the tensor's own device so non-default GPUs are handled as well.
    device = adj.device
    if sparse:
        # TODO if this is too slow, uncomment the following code,
        # but you need to install torch_scatter
        # return normalize_sparse_tensor(adj)
        adj = to_scipy(adj)
        mx = normalize_adj(adj)
        return sparse_mx_to_torch_sparse_tensor(mx).to(device)
    else:
        mx = adj + torch.eye(adj.shape[0], device=device)
        rowsum = mx.sum(1)
        r_inv = rowsum.pow(-1/2).flatten()
        # Rows that sum to zero would give inf; treat them as zero instead.
        r_inv[torch.isinf(r_inv)] = 0.
        r_mat_inv = torch.diag(r_inv)
        mx = r_mat_inv @ mx
        mx = mx @ r_mat_inv
    return mx


示例22: degree_normalize_adj_tensor

​點讚 6

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def degree_normalize_adj_tensor(adj, sparse=True):
    """Row-normalize an adjacency tensor: ``D^-1 (A + I)``.

    ``D`` holds the row sums of ``A + I``.

    :param adj: square adjacency tensor
    :param sparse: if true (the default), convert through the scipy helpers
        of this module and return a sparse tensor; otherwise work on the
        dense tensor
    :return: the normalized adjacency on the device of ``adj``
    """
    # Use the tensor's own device so non-default GPUs are handled as well.
    device = adj.device
    if sparse:
        # return degree_normalize_sparse_tensor(adj)
        adj = to_scipy(adj)
        mx = degree_normalize_adj(adj)
        return sparse_mx_to_torch_sparse_tensor(mx).to(device)
    else:
        mx = adj + torch.eye(adj.shape[0], device=device)
        rowsum = mx.sum(1)
        r_inv = rowsum.pow(-1).flatten()
        # Rows that sum to zero would give inf; treat them as zero instead.
        r_inv[torch.isinf(r_inv)] = 0.
        r_mat_inv = torch.diag(r_inv)
        mx = r_mat_inv @ mx
    return mx


示例23: prox_nuclear_truncated_2

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def prox_nuclear_truncated_2(self, data, alpha, k=50):
    """Soft-threshold the leading singular values of ``data``.

    A rank-``k`` truncated SVD is taken with tensorly, ``alpha`` is subtracted
    from the singular values (clamped at zero) and the matrix is rebuilt.
    The sum of the un-thresholded singular values is stored in
    ``self.nuclear_norm``.

    :param data: 2-D tensor
    :param alpha: threshold subtracted from every singular value
    :param k: number of singular triplets requested from the truncated SVD
    :return: the reconstructed tensor, on the device of ``data``
    """
    import tensorly as tl
    # NOTE(review): the input is a torch tensor, so the torch backend is
    # selected here; this line is absent from the excerpt -- confirm upstream.
    tl.set_backend('pytorch')
    U, S, V = tl.truncated_svd(data.cpu(), n_eigenvecs=k)
    device = data.device
    U = torch.as_tensor(U, dtype=torch.float32, device=device)
    S = torch.as_tensor(S, dtype=torch.float32, device=device)
    V = torch.as_tensor(V, dtype=torch.float32, device=device)
    self.nuclear_norm = S.sum()
    # print("nuclear norm: %.4f" % self.nuclear_norm)

    S = torch.clamp(S-alpha, min=0)
    # U @ diag(S) @ V, with the diagonal product written as a column scaling.
    # The previous sparse-diagonal construction passed two positional
    # arguments to torch.tensor and could not run.
    # diag_S = torch.diag(torch.clamp(S-alpha, min=0))
    return torch.matmul(U * S, V)


示例24: feature_smoothing

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def feature_smoothing(self, adj, X):
    """Return ``trace(X^T L X)`` for the normalized Laplacian of ``adj``.

    ``adj`` is symmetrized first; ``L = D^-1/2 (D - A) D^-1/2`` where ``D``
    holds the row sums of the symmetrized adjacency, with ``1e-3`` added to
    the row sums before they are inverted.

    :param adj: square adjacency tensor
    :param X: feature matrix with one row per node
    :return: scalar tensor
    """
    adj = (adj.t() + adj)/2
    rowsum = adj.sum(1)
    r_inv = rowsum.flatten()
    D = torch.diag(r_inv)
    L = D - adj

    # Small offset keeps the inverse square root finite for zero-degree rows.
    r_inv = r_inv + 1e-3
    r_inv = r_inv.pow(-1/2).flatten()
    r_inv[torch.isinf(r_inv)] = 0.
    r_mat_inv = torch.diag(r_inv)
    # L = r_mat_inv @ L
    L = r_mat_inv @ L @ r_mat_inv

    XLXT = torch.matmul(torch.matmul(X.t(), L), X)
    loss_smooth_feat = torch.trace(XLXT)
    return loss_smooth_feat


示例25: zca_matrix

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def zca_matrix(data_tensor):
    """
    Helper function: compute ZCA whitening matrix across a dataset ~ (N, C, H, W).

    :param data_tensor: dataset tensor; everything after the first axis is
        flattened into one feature axis
    :return: ``U @ diag(1 / S) @ V^T`` where ``U, S, V = svd(X^T X)``
    """
    # 1. flatten dataset:
    X = data_tensor.view(data_tensor.shape[0], -1)

    # 2. zero-center the matrix:
    # `rescale` is defined elsewhere in this module; presumably it maps the
    # values into [-1, 1] -- TODO confirm that this actually centers the data.
    X = rescale(X, -1., 1.)

    # 3. compute covariances:
    cov = torch.t(X) @ X

    # 4. compute ZCA(X) == U @ (diag(1/S)) @ torch.t(V) where U, S, V = SVD(cov):
    # NOTE(review): ZCA whitening is usually written with 1/sqrt(S + eps);
    # the code uses 1/S with no epsilon -- confirm this is intended.
    U, S, V = torch.svd(cov)

    return (U @ torch.diag(torch.reciprocal(S)) @ torch.t(V))


示例26: forward

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def forward(self, xs):
    """
    Forward pass through all invertible coupling layers.

    The input is passed through ``layer1`` .. ``layer4`` in order and then
    every feature is multiplied by ``exp(self.scaling_diag)``.

    Args:
    * xs: float tensor of shape (B,dim).

    Returns:
    * ys: float tensor of shape (B,dim).
    """
    ys = self.layer1(xs)
    ys = self.layer2(ys)
    ys = self.layer3(ys)
    ys = self.layer4(ys)
    ys = torch.matmul(ys, torch.diag(torch.exp(self.scaling_diag)))
    return ys


示例27: greedy_decoder

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def greedy_decoder(arc_matrix, mask=None):
    """
    Greedy decoding: takes the arc score graph and returns, for every token,
    the highest scoring head. The result is not guaranteed to form a valid tree.

    :param arc_matrix: [batch, seq_len, seq_len] input score matrix
    :param mask: [batch, seq_len] padding mask of the input, 1 where there is
        content and 0 otherwise. If ``None``, every position is treated as
        content. Default: ``None``
    :return heads: [batch, seq_len] predicted head (parent) index for each
        element; padded positions are set to 0
    """
    _, seq_len, _ = arc_matrix.shape
    # A token can never be its own head: put -inf on the diagonal.
    matrix = arc_matrix + torch.diag(arc_matrix.new_full((seq_len,), float('-inf')))
    if mask is not None:
        # Padded positions may not be chosen as heads.
        flip_mask = mask.eq(False)
        matrix.masked_fill_(flip_mask.unsqueeze(1), float('-inf'))
    _, heads = torch.max(matrix, dim=2)
    if mask is not None:
        heads *= mask.long()
    return heads


示例28: preprocess

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def preprocess(A):
    """Symmetrically normalize a batch of adjacency matrices.

    Returns ``D^-1/2 A D^-1/2`` for every matrix in the batch, where ``D`` is
    the diagonal matrix of row sums (node degrees). Rows that sum to zero
    produce ``inf`` entries in ``D^-1/2``.

    :param A: tensor indexed as ``(batch, n, n)``
    :return: tensor of the same shape, on the device of ``A``
    """
    # Get the degrees for each node
    degrees = torch.sum(A, dim=2)

    # Create diagonal matrix D^-1/2 from the degrees of the nodes.
    # diag_embed builds the whole batch at once on A's device, replacing the
    # per-sample torch.diag loop into a pre-allocated CUDA buffer.
    D = torch.diag_embed(torch.pow(degrees, -0.5))

    # Return the normalized adjacency
    A_normal = torch.matmul(torch.matmul(D, A), D)
    return A_normal

# a sequential GCN model, GCN with n layers


示例29: _mmd2

# 需要導入模塊: import torch [as 別名]

# 或者: from torch import diag [as 別名]

def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    """Combine three kernel matrices into a squared-MMD style statistic.

    :param K_XX: ``(m, m)`` kernel matrix of the first sample with itself
    :param K_XY: kernel matrix between the two samples
    :param K_YY: ``(m, m)`` kernel matrix of the second sample with itself
    :param const_diagonal: if not ``False``, the value assumed for every
        diagonal entry of ``K_XX`` and ``K_YY`` instead of reading them
    :param biased: if true, include the diagonals and divide by ``m * m``;
        otherwise exclude them and divide by ``m * (m - 1)``
    :return: the statistic as a scalar
    """
    m = K_XX.size(0)    # assume X, Y are same shape

    # Get the various sums of kernels that we'll use
    # Kts drop the diagonal, but we don't need to compute them explicitly
    if const_diagonal is not False:
        diag_X = diag_Y = const_diagonal
        sum_diag_X = sum_diag_Y = m * const_diagonal
    else:
        diag_X = torch.diag(K_XX)                       # (m,)
        diag_Y = torch.diag(K_YY)                       # (m,)
        sum_diag_X = torch.sum(diag_X)
        sum_diag_Y = torch.sum(diag_Y)

    Kt_XX_sums = K_XX.sum(dim=1) - diag_X             # \tilde{K}_XX * e = K_XX * e - diag_X
    Kt_YY_sums = K_YY.sum(dim=1) - diag_Y             # \tilde{K}_YY * e = K_YY * e - diag_Y
    K_XY_sums_0 = K_XY.sum(dim=0)                     # K_{XY}^T * e

    Kt_XX_sum = Kt_XX_sums.sum()                       # e^T * \tilde{K}_XX * e
    Kt_YY_sum = Kt_YY_sums.sum()                       # e^T * \tilde{K}_YY * e
    K_XY_sum = K_XY_sums_0.sum()                       # e^T * K_{XY} * e

    if biased:
        mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
            + (Kt_YY_sum + sum_diag_Y) / (m * m)
            - 2.0 * K_XY_sum / (m * m))
    else:
        mmd2 = (Kt_XX_sum / (m * (m - 1))
            + Kt_YY_sum / (m * (m - 1))
            - 2.0 * K_XY_sum / (m * m))

    return mmd2



