Word2vec: TensorFlow in Practice

A visual walkthrough of this article's code: link


Introduction

Earlier I analyzed the theory behind Word2vec and how to use the word2vec toolkit. So how is Word2vec implemented in the deep learning framework TensorFlow? In what follows we walk through the minimalistic word2vec implementation that ships with TensorFlow.


Minimalistic Implementation
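All of the snippets below come from TensorFlow's minimalistic word2vec_basic.py example and share a common set of imports. A minimal sketch of those imports (adjust to your environment; the example targets TensorFlow 1.x and uses six for Python 2/3 compatibility):

import collections
import math
import os
import random
import zipfile
from tempfile import gettempdir

import numpy as np
import tensorflow as tf
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin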

Step 1: Download the data

Here we download the text dataset directly from http://mattmahoney.net/dc/; the code is as follows:

url = 'http://mattmahoney.net/dc/'


# pylint: disable=redefined-outer-name
def maybe_download(filename, expected_bytes):
  """Download a file if not present, and make sure it's the right size."""
  local_filename = os.path.join(gettempdir(), filename)
  if not os.path.exists(local_filename):
    local_filename, _ = urllib.request.urlretrieve(url + filename,
                                                   local_filename)
  statinfo = os.stat(local_filename)
  if statinfo.st_size == expected_bytes:
    print('Found and verified', filename)
  else:
    print(statinfo.st_size)
    raise Exception('Failed to verify ' + local_filename +
                    '. Can you get to it with a browser?')
  return local_filename


filename = maybe_download('text8.zip', 31344016)

The dataset is roughly 30 MB and is downloaded to the system temp directory (/tmp/ by default).
After downloading, the text is read into a list of word strings:

# Read the data into a list of strings.
def read_data(filename):
  """Extract the first file enclosed in a zip file as a list of words."""
  with zipfile.ZipFile(filename) as f:
    data = tf.compat.as_str(f.read(f.namelist()[0])).split()
  return data


vocabulary = read_data(filename)
print('Data size', len(vocabulary))
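As a quick sanity check, you can peek at the raw token list before building the vocabulary (the exact values depend on the text8 snapshot you downloaded, so treat the output as illustrative):

# Inspect the raw corpus: a flat Python list of lowercase word strings.
print(vocabulary[:8])   # first few tokens of the corpus
print(len(set(vocabulary)), 'unique words before truncating to vocabulary_size')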

Step 2: Build the vocabulary and replace rare words with 'UNK'

# Step 2: Build the dictionary and replace rare words with UNK token.
vocabulary_size = 50000


def build_dataset(words, n_words):
  """Process raw inputs into a dataset."""
  count = [['UNK', -1]]
  count.extend(collections.Counter(words).most_common(n_words - 1))
  dictionary = dict()
  for word, _ in count:
    dictionary[word] = len(dictionary)
  data = list()
  unk_count = 0
  for word in words:
    index = dictionary.get(word, 0)
    if index == 0:  # dictionary['UNK']
      unk_count += 1
    data.append(index)
  count[0][1] = unk_count
  reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
  return data, count, dictionary, reversed_dictionary


# Filling 4 global variables:
# data - list of codes (integers from 0 to vocabulary_size-1).
#   This is the original text but words are replaced by their codes
# count - map of words(strings) to count of occurrences
# dictionary - map of words(strings) to their codes(integers)
# reverse_dictionary - maps codes(integers) to words(strings)
data, count, dictionary, reverse_dictionary = build_dataset(
    vocabulary, vocabulary_size)
del vocabulary  # Hint to reduce memory.
print('Most common words (+UNK)', count[:5])
print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]])

data_index = 0

The most frequent words here are 'the', 'of', 'and', and 'one':

Most common words (+UNK) [['UNK', 418391], ('the', 1061396), ('of', 593677), ('and', 416629), ('one', 411764)]
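To make the index mapping concrete, here is a tiny, self-contained demo of the same logic on a made-up corpus (the words and the n_words value are hypothetical, chosen only for illustration):

import collections

toy_words = ['the', 'cat', 'sat', 'on', 'the', 'mat', 'the', 'cat']
n_words = 3  # keep 'UNK' plus the 2 most frequent words

count = [['UNK', -1]]
count.extend(collections.Counter(toy_words).most_common(n_words - 1))
dictionary = {word: i for i, (word, _) in enumerate(count)}
data = [dictionary.get(w, 0) for w in toy_words]   # 0 is the 'UNK' code
count[0][1] = data.count(0)

print(count)       # [['UNK', 3], ('the', 3), ('cat', 2)]
print(dictionary)  # {'UNK': 0, 'the': 1, 'cat': 2}
print(data)        # [1, 2, 0, 0, 1, 0, 1, 2]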

Step 3: Generate a training batch for the skip-gram model

# Step 3: Function to generate a training batch for the skip-gram model.
def generate_batch(batch_size, num_skips, skip_window):
  global data_index
  assert batch_size % num_skips == 0
  assert num_skips <= 2 * skip_window
  batch = np.ndarray(shape=(batch_size), dtype=np.int32)
  labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
  span = 2 * skip_window + 1  # [ skip_window target skip_window ]
  buffer = collections.deque(maxlen=span)
  if data_index + span > len(data):
    data_index = 0
  buffer.extend(data[data_index:data_index + span])
  data_index += span
  for i in range(batch_size // num_skips):
    context_words = [w for w in range(span) if w != skip_window]
    words_to_use = random.sample(context_words, num_skips)
    for j, context_word in enumerate(words_to_use):
      batch[i * num_skips + j] = buffer[skip_window]
      labels[i * num_skips + j, 0] = buffer[context_word]
    if data_index == len(data):
      # A deque does not support slice assignment, so refill it with extend();
      # because maxlen=span this replaces the whole window.
      buffer.extend(data[0:span])
      data_index = span
    else:
      buffer.append(data[data_index])
      data_index += 1
  # Backtrack a little bit to avoid skipping words in the end of a batch
  data_index = (data_index + len(data) - span) % len(data)
  return batch, labels


batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)
for i in range(8):
  print(batch[i], reverse_dictionary[batch[i]],
        '->', labels[i, 0], reverse_dictionary[labels[i, 0]])
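With skip_window=1 and num_skips=2, every centre word in the batch is paired with both of its immediate neighbours. The standalone sketch below reproduces that pairing logic on a toy sequence of word codes (toy_data is hypothetical; generate_batch above additionally manages the global data_index cursor):

import random

toy_data = [10, 11, 12, 13, 14]   # hypothetical word codes
skip_window, num_skips = 1, 2
span = 2 * skip_window + 1        # [ skip_window target skip_window ]

pairs = []
for pos in range(skip_window, len(toy_data) - skip_window):
    target = toy_data[pos]
    context = toy_data[pos - skip_window:pos] + toy_data[pos + 1:pos + 1 + skip_window]
    for c in random.sample(context, num_skips):
        pairs.append((target, c))   # (input word, context word to predict)

print(pairs)   # e.g. [(11, 10), (11, 12), (12, 11), (12, 13), (13, 14), (13, 12)]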

Step 4: Build and train the skip-gram model

# Step 4: Build and train a skip-gram model.

batch_size = 128
embedding_size = 128  # Dimension of the embedding vector.
skip_window = 1       # How many words to consider left and right.
num_skips = 2         # How many times to reuse an input to generate a label.
num_sampled = 64      # Number of negative examples to sample.

# We pick a random validation set to sample nearest neighbors. Here we limit the
# validation samples to the words that have a low numeric ID, which by
# construction are also the most frequent. These 3 variables are used only for
# displaying model accuracy, they don't affect calculation.
valid_size = 16     # Random set of words to evaluate similarity on.
valid_window = 100  # Only pick dev samples in the head of the distribution.
valid_examples = np.random.choice(valid_window, valid_size, replace=False)

graph = tf.Graph()

with graph.as_default():

  # Input data.
  train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
  train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
  valid_dataset = tf.constant(valid_examples, dtype=tf.int32)

  # Ops and variables pinned to the CPU because of missing GPU implementation
  with tf.device('/cpu:0'):
    # Look up embeddings for inputs.
    embeddings = tf.Variable(
        tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
    embed = tf.nn.embedding_lookup(embeddings, train_inputs)

    # Construct the variables for the NCE loss
    nce_weights = tf.Variable(
        tf.truncated_normal([vocabulary_size, embedding_size],
                            stddev=1.0 / math.sqrt(embedding_size)))
    nce_biases = tf.Variable(tf.zeros([vocabulary_size]))

  # Compute the average NCE loss for the batch.
  # tf.nce_loss automatically draws a new sample of the negative labels each
  # time we evaluate the loss.
  # Explanation of the meaning of NCE loss:
  #   http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/
  loss = tf.reduce_mean(
      tf.nn.nce_loss(weights=nce_weights,
                     biases=nce_biases,
                     labels=train_labels,
                     inputs=embed,
                     num_sampled=num_sampled,
                     num_classes=vocabulary_size))

  # Construct the SGD optimizer using a learning rate of 1.0.
  optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)

  # Compute the cosine similarity between minibatch examples and all embeddings.
  norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
  normalized_embeddings = embeddings / norm
  valid_embeddings = tf.nn.embedding_lookup(
      normalized_embeddings, valid_dataset)
  similarity = tf.matmul(
      valid_embeddings, normalized_embeddings, transpose_b=True)

  # Add variable initializer.
  init = tf.global_variables_initializer()
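The loss above is noise-contrastive estimation (NCE): instead of a full softmax over all 50,000 words, each positive (input, label) pair is contrasted against num_sampled randomly drawn negative words. The NumPy sketch below conveys the intuition for a single example; it is a simplified negative-sampling view that omits NCE's correction for the noise distribution and is not what tf.nn.nce_loss computes internally:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def approx_nce_loss(embed_vec, label_id, negative_ids, nce_weights, nce_biases):
    """Score the true context word high and a handful of sampled words low."""
    def logit(word_id):
        return np.dot(nce_weights[word_id], embed_vec) + nce_biases[word_id]

    loss = -np.log(sigmoid(logit(label_id)))        # true (input, label) pair
    for neg in negative_ids:
        loss -= np.log(1.0 - sigmoid(logit(neg)))   # sampled negative words
    return loss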

Step 5: Begin training

# Step 5: Begin training.
num_steps = 100001

with tf.Session(graph=graph) as session:
  # We must initialize all variables before we use them.
  init.run()
  print('Initialized')

  average_loss = 0
  for step in xrange(num_steps):
    batch_inputs, batch_labels = generate_batch(
        batch_size, num_skips, skip_window)
    feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}

    # We perform one update step by evaluating the optimizer op (including it
    # in the list of returned values for session.run())
    _, loss_val = session.run([optimizer, loss], feed_dict=feed_dict)
    average_loss += loss_val

    if step % 2000 == 0:
      if step > 0:
        average_loss /= 2000
      # The average loss is an estimate of the loss over the last 2000 batches.
      print('Average loss at step ', step, ': ', average_loss)
      average_loss = 0

    # Note that this is expensive (~20% slowdown if computed every 500 steps)
    if step % 10000 == 0:
      sim = similarity.eval()
      for i in xrange(valid_size):
        valid_word = reverse_dictionary[valid_examples[i]]
        top_k = 8  # number of nearest neighbors
        nearest = (-sim[i, :]).argsort()[1:top_k + 1]
        log_str = 'Nearest to %s:' % valid_word
        for k in xrange(top_k):
          close_word = reverse_dictionary[nearest[k]]
          log_str = '%s %s,' % (log_str, close_word)
        print(log_str)
  final_embeddings = normalized_embeddings.eval()

Every 10,000 steps, the words nearest to each validation word under the current embeddings are printed.
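After training, final_embeddings is just a NumPy array of L2-normalized vectors, so the same kind of query can be run outside the session. A minimal sketch (nearest_words is a hypothetical helper, not part of the original script; it reuses dictionary and reverse_dictionary from Step 2):

import numpy as np

def nearest_words(query, k=8):
    """Return the k nearest neighbours of `query` under the trained embeddings."""
    idx = dictionary.get(query, 0)                 # unknown words fall back to 'UNK'
    # Rows are already L2-normalized, so a dot product is cosine similarity.
    sims = np.dot(final_embeddings, final_embeddings[idx])
    order = (-sims).argsort()[1:k + 1]             # skip the query word itself
    return [reverse_dictionary[i] for i in order]

print(nearest_words('one'))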

Step 6: Visualization

# Step 6: Visualize the embeddings.


# pylint: disable=missing-docstring
# Function to draw visualization of distance between embeddings.
def plot_with_labels(low_dim_embs, labels, filename):
  assert low_dim_embs.shape[0] >= len(labels), 'More labels than embeddings'
  plt.figure(figsize=(18, 18))  # in inches
  for i, label in enumerate(labels):
    x, y = low_dim_embs[i, :]
    plt.scatter(x, y)
    plt.annotate(label,
                 xy=(x, y),
                 xytext=(5, 2),
                 textcoords='offset points',
                 ha='right',
                 va='bottom')
  plt.savefig(filename)


try:
  # pylint: disable=g-import-not-at-top
  from sklearn.manifold import TSNE
  import matplotlib.pyplot as plt

  tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000, method='exact')
  plot_only = 500
  low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only, :])
  labels = [reverse_dictionary[i] for i in xrange(plot_only)]
  plot_with_labels(low_dim_embs, labels, os.path.join(gettempdir(), 'tsne.png'))

except ImportError as ex:
  print('Please install sklearn, matplotlib, and scipy to show embeddings.')
  print(ex)

Visualization result: the t-SNE plot of the 500 most frequent words is written to tsne.png in the temp directory.
