tensorflow之word2vec_basic代码研究

源代码网址: https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/examples/tutorials/word2vec/word2vec_basic.py简书上有一篇此代码的详解,图文并茂,可直接看这篇详解: http://www.jianshu.com/p/f682066f0586

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic word2vec example."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math
import os
import random
import zipfile

import numpy as np
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

# Step 1: Download the data.
url = ‘http://mattmahoney.net/dc/‘

def maybe_download(filename, expected_bytes):
  """Download a file if not present, and make sure it‘s the right size."""
  if not os.path.exists(filename):
    filename, _ = urllib.request.urlretrieve(url + filename, filename)
  statinfo = os.stat(filename)
  if statinfo.st_size == expected_bytes:
    print(‘Found and verified‘, filename)
  else:
    print(statinfo.st_size)
    raise Exception(
        ‘Failed to verify ‘ + filename + ‘. Can you get to it with a browser?‘)
  return filename

filename = maybe_download(‘text8.zip‘, 31344016)

# Read the data into a list of strings.
def read_data(filename):
  """Extract the first file enclosed in a zip file as a list of words."""
  with zipfile.ZipFile(filename) as f:
    data = tf.compat.as_str(f.read(f.namelist()[0])).split()
  return data

vocabulary = read_data(filename)
print(‘Data size‘, len(vocabulary))

# Step 2: Build the dictionary and replace rare words with UNK token.
vocabulary_size = 50000

‘‘‘
input:
words - the original word list
n_words - the number of used words

output:
data - a list with the same length of input words
every element in the list is the value of the corresponding word in dictionary
or the position in count or dictionary
count - a matrix with n_words rows and two columns,
the first column corresponds to the word,
the second column corresponds to its frequency in input words
the first row in count is [‘UNK‘, *]
the other rows are in descending order of the sencond column
dictionary - key-value map, key is the word, value is its position in count or dictionary
reversed_dictionary - reverse the key-value in dictionary
‘‘‘

def build_dataset(words, n_words):
  """Process raw inputs into a dataset."""
  count = [[‘UNK‘, -1]]
  count.extend(collections.Counter(words).most_common(n_words - 1))
  dictionary = dict()
  for word, _ in count:
    dictionary[word] = len(dictionary)
  data = list()
  unk_count = 0
  for word in words:
    if word in dictionary:
      index = dictionary[word]
    else:
      index = 0  # dictionary[‘UNK‘]
      unk_count += 1
    data.append(index)
  count[0][1] = unk_count
  reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
  return data, count, dictionary, reversed_dictionary

data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
                                                            vocabulary_size)
del vocabulary  # Hint to reduce memory.
print(‘Most common words (+UNK)‘, count[:5])
print(‘Sample data‘, data[:10], [reverse_dictionary[i] for i in data[:10]])

data_index = 0

‘‘‘
convert data to batch and labels
the values in batch and labels are the positions of the corresponding words
‘‘‘

# Step 3: Function to generate a training batch for the skip-gram model.
def generate_batch(batch_size, num_skips, skip_window):
  global data_index
  assert batch_size % num_skips == 0
  assert num_skips <= 2 * skip_window
  batch = np.ndarray(shape=(batch_size), dtype=np.int32)
  labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
  span = 2 * skip_window + 1  # [ skip_window target skip_window ]
  buffer = collections.deque(maxlen=span)
  for _ in range(span):
    buffer.append(data[data_index])
    data_index = (data_index + 1) % len(data)
  for i in range(batch_size // num_skips):
    target = skip_window  # target label at the center of the buffer
    targets_to_avoid = [skip_window]
    for j in range(num_skips):
      while target in targets_to_avoid:
        target = random.randint(0, span - 1)
      targets_to_avoid.append(target)
      batch[i * num_skips + j] = buffer[skip_window]
      labels[i * num_skips + j, 0] = buffer[target]
    buffer.append(data[data_index])
    data_index = (data_index + 1) % len(data)
  # Backtrack a little bit to avoid skipping words in the end of a batch
  data_index = (data_index + len(data) - span) % len(data)
  return batch, labels

batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)
for i in range(8):
  print(batch[i], reverse_dictionary[batch[i]],
        ‘->‘, labels[i, 0], reverse_dictionary[labels[i, 0]])

# Step 4: Build and train a skip-gram model.

batch_size = 128
embedding_size = 128  # Dimension of the embedding vector.
skip_window = 1       # How many words to consider left and right.
num_skips = 2         # How many times to reuse an input to generate a label.

# We pick a random validation set to sample nearest neighbors. Here we limit the
# validation samples to the words that have a low numeric ID, which by
# construction are also the most frequent.
valid_size = 16     # Random set of words to evaluate similarity on.
valid_window = 100  # Only pick dev samples in the head of the distribution.
valid_examples = np.random.choice(valid_window, valid_size, replace=False)
num_sampled = 64    # Number of negative examples to sample.

graph = tf.Graph()

with graph.as_default():

  # Input data.
  train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
  train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
  valid_dataset = tf.constant(valid_examples, dtype=tf.int32)

  # Ops and variables pinned to the CPU because of missing GPU implementation
  with tf.device(‘/cpu:0‘):

‘‘‘
Generate initial embeddings using random values
the row of the embeddings is same as vocabulary size
the column of the embeddings is the dimension of the embedding vector
each row of the embedding corresponds to the word in count or dictionary with the same row id
The below embed is the embeddings of train_inputs
‘‘‘

    # Look up embeddings for inputs.
    embeddings = tf.Variable(
        tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
    embed = tf.nn.embedding_lookup(embeddings, train_inputs)

    # Construct the variables for the NCE loss
    nce_weights = tf.Variable(
        tf.truncated_normal([vocabulary_size, embedding_size],
                            stddev=1.0 / math.sqrt(embedding_size)))
    nce_biases = tf.Variable(tf.zeros([vocabulary_size]))

  # Compute the average NCE loss for the batch.
  # tf.nce_loss automatically draws a new sample of the negative labels each
  # time we evaluate the loss.
  loss = tf.reduce_mean(
      tf.nn.nce_loss(weights=nce_weights,
                     biases=nce_biases,
                     labels=train_labels,
                     inputs=embed,
                     num_sampled=num_sampled,
                     num_classes=vocabulary_size))

  # Construct the SGD optimizer using a learning rate of 1.0.
  optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)

  # Compute the cosine similarity between minibatch examples and all embeddings.
  norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
  normalized_embeddings = embeddings / norm
  valid_embeddings = tf.nn.embedding_lookup(
      normalized_embeddings, valid_dataset)
  similarity = tf.matmul(
      valid_embeddings, normalized_embeddings, transpose_b=True)

  # Add variable initializer.
  init = tf.global_variables_initializer()

# Step 5: Begin training.
num_steps = 100001

with tf.Session(graph=graph) as session:
  # We must initialize all variables before we use them.
  init.run()
  print(‘Initialized‘)

  average_loss = 0
  for step in xrange(num_steps):
    batch_inputs, batch_labels = generate_batch(
        batch_size, num_skips, skip_window)
    feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}

    # We perform one update step by evaluating the optimizer op (including it
    # in the list of returned values for session.run()
    _, loss_val = session.run([optimizer, loss], feed_dict=feed_dict)
    average_loss += loss_val

    if step % 2000 == 0:
      if step > 0:
        average_loss /= 2000
      # The average loss is an estimate of the loss over the last 2000 batches.
      print(‘Average loss at step ‘, step, ‘: ‘, average_loss)
      average_loss = 0

    # Note that this is expensive (~20% slowdown if computed every 500 steps)
    if step % 10000 == 0:
      sim = similarity.eval()
      for i in xrange(valid_size):
        valid_word = reverse_dictionary[valid_examples[i]]
        top_k = 8  # number of nearest neighbors
        nearest = (-sim[i, :]).argsort()[1:top_k + 1]
        log_str = ‘Nearest to %s:‘ % valid_word
        for k in xrange(top_k):
          close_word = reverse_dictionary[nearest[k]]
          log_str = ‘%s %s,‘ % (log_str, close_word)
        print(log_str)
  final_embeddings = normalized_embeddings.eval()

# Step 6: Visualize the embeddings.

def plot_with_labels(low_dim_embs, labels, filename=‘tsne.png‘):
  assert low_dim_embs.shape[0] >= len(labels), ‘More labels than embeddings‘
  plt.figure(figsize=(18, 18))  # in inches
  for i, label in enumerate(labels):
    x, y = low_dim_embs[i, :]
    plt.scatter(x, y)
    plt.annotate(label,
                 xy=(x, y),
                 xytext=(5, 2),
                 textcoords=‘offset points‘,
                 ha=‘right‘,
                 va=‘bottom‘)

  plt.savefig(filename)

try:
  # pylint: disable=g-import-not-at-top
  from sklearn.manifold import TSNE
  import matplotlib.pyplot as plt

  tsne = TSNE(perplexity=30, n_components=2, init=‘pca‘, n_iter=5000)
  plot_only = 500
  low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only, :])
  labels = [reverse_dictionary[i] for i in xrange(plot_only)]
  plot_with_labels(low_dim_embs, labels)

except ImportError:
  print(‘Please install sklearn, matplotlib, and scipy to show embeddings.‘)
时间: 2024-12-20 15:24:34

tensorflow之word2vec_basic代码研究的相关文章

神经网络caffe框架源码解析--softmax_layer.cpp类代码研究

// Copyright 2013 Yangqing Jia // #include <algorithm> #include <vector> #include "caffe/layer.hpp" #include "caffe/vision_layers.hpp" #include "caffe/util/math_functions.hpp" using std::max; namespace caffe { /**

神经网络caffe框架源码解析--data_layer.cpp类代码研究

dataLayer作为整个网络的输入层, 数据从leveldb中取.leveldb的数据是通过图片转换过来的. 网络建立的时候, datalayer主要是负责设置一些参数,比如batchsize,channels,height,width等. 这次会通过读leveldb一个数据块来获取这些信息. 然后启动一个线程来预先从leveldb拉取一批数据,这些数据是图像数据和图像标签. 正向传播的时候, datalayer就把预先拉取好数据拷贝到指定的cpu或者gpu的内存. 然后启动新线程再预先拉取数

dedecms代码研究二

dedecms代码研究(2)从index开始 现在继续,今天讲的主要是dedecms的入口代码. 先打开index.PHP看看里面是什么吧.打开根目录下的index.php嗯 映入眼帘的是一个if语句.检查/data/common.inc.php是否存在.如果不存在就跳转到安装界面. if(!file_exists(dirname(__FILE__).'/data/common.inc.php')) { header('Location:install/index.php'); exit();

Spring代码研究-前言

好久没有写过博客了 看看上篇,也是唯一的博客是3年前刚工作写的,似乎过去了很久 一次面试,面试官突然问我,为什么要用Spring,我一时语塞,不知道从何说起 呜呜弄弄,Spring提供的DI/IOC,AOP,MVC以及对与Hibernate,JDBC的支持,很方便使用,可以使我们非常方便的编程,把更多的经历放在业务逻辑的设计上 并不自信,因为我觉得我说的不好 当然,这也是这篇博客,以及后续Spring代码分析研究博客产生的原因 工作三年,工作做过ExtJS,Flex,Twaver Flex,An

如何使用TensorFlow Hub和代码示例

任何深度学习框架,为了获得成功,必须提供一系列最先进的模型,以及在流行和广泛接受的数据集上训练的权重,即与训练模型. TensorFlow现在已经提出了一个更好的框架,称为TensorFlow Hub,它非常易于使用且组织良好.使用TensorFlow Hub,您可以通过几行代码导入大型和流行的模型,自信地执行广泛使用的传输学习活动.TensorFlow Hub非常灵活,可以托管您的模型以供其他用户使用.TensorFlow Hub中的这些模型称为模块.在本文中,让我们看看如何使用TensorF

CWMP开源代码研究5——CWMP程序设计思想

声明:本文涉及的开源程序代码学习和研究,严禁用于商业目的. 如有任何问题,欢迎和我交流.(企鹅号:408797506) 本文介绍自己用过的ACS,其中包括开源版(提供下载包)和商业版(仅提供安装包下载,没有源码) 参考: 1) http://www.docin.com/p-1306443672.html 2) http://www.easycwmp.org/ 一. 背景   程序设计的思想来自于easycwmp官网,看过或者用过easycwmp的工程师应该都知道,该开源代码还有商业版,而且价格不

CWMP开源代码研究7——cwmp移植

原创作品,转载请注明出处,严禁非法转载.如有错误,请留言! email:[email protected] 声明:本系列涉及的开源程序代码学习和研究,严禁用于商业目的. 如有任何问题,欢迎和我交流.(企鹅号:408797506) 本篇用到的文件包下载路径:http://download.csdn.net/detail/eryunyong/9735149 一. 环境1.GNU/Linux Centos6.5操作系统2.gcc二. 依赖包的安装1. expat-2.1.01)下载安装包expat-2

dedecms代码研究一

dedecms相信大家一定都知道这个cms系统,功能比较强大,有比较完善的内容发布,还有内容静态化系统,还有就是它有自己独特的标签系统和模板系统.而模板系统也是其他cms系统比较难模仿的的东西,这个东西还是需要一点开发功力和技巧的. 本系列文章就研究一下dedecms的这套系统,挖掘一下看看里面有什么好东西. 建议大家先了解一下dedecms的功能.自己先动手用一下,对系统功能有个大概了解. 本文先带领大家了解一下dedecms的代码和功能架构. 其实,dedecms在架构上没什么应用架构模式可

dedecms代码研究五

上一次留几个疑问: 1)DedeTagParse类LoadTemplet方法. 2)MakeOneTag到底在搞什么. 从DedeTagParse开始前面,我们一直在dedecms的外围,被各种全局变量和各种调用所迷惑,我们抓住了一个关键的线索DedeTagParse类,研究明白它,就可以弄清楚很多东西了. 看看这个NB的DedeTagParse类吧. 嗯,先看构造函数,没什么特别的,就是设置了一堆初始化参数. 接下来就找LoadTemplet方法吧. 找到后,我们发现LoadTemplet方法