GraphSAGE 代码解析 - minibatch.py

class EdgeMinibatchIterator

    """ This minibatch iterator iterates over batches of sampled edges or
    random pairs of co-occuring edges.

    G -- networkx graph
    id2idx -- dict mapping node ids to index in feature tensor
    placeholders -- tensorflow placeholders object
    context_pairs -- if not none, then a list of co-occuring node pairs (from random walks)
    batch_size -- size of the minibatches
    max_degree -- maximum size of the downsampled adjacency lists
    n2v_retrain -- signals that the iterator is being used to add new embeddings to a n2v model
    fixed_n2v -- signals that the iterator is being used to retrain n2v with only existing nodes as context
    """

def __init__(self, G, id2idx, placeholders, context_pairs=None, batch_size=100, max_degree=25,

n2v_retrain=False, fixed_n2v=False, **kwargs) 中具体介绍以下:

1 self.nodes = np.random.permutation(G.nodes())
2 # 函数shuffle与permutation都是对原来的数组进行重新洗牌,即随机打乱原来的元素顺序
3 # shuffle直接在原来的数组上进行操作,改变原来数组的顺序,无返回值
4 # permutation不直接在原来的数组上进行操作,而是返回一个新的打乱顺序的数组,并不改变原来的数组。
1 self.adj, self.deg = self.construct_adj()

这里重点看construct_adj()函数。

 1 def construct_adj(self):
 2         adj = len(self.id2idx) *  3             np.ones((len(self.id2idx) + 1, self.max_degree))
 4         # 该矩阵记录训练数据中各节点的邻居节点的编号
 5         # 采样只取max_degree个邻居节点,采样方法见下
 6         # 同样进行了行数加一操作
 7
 8         deg = np.zeros((len(self.id2idx),))
 9         # 该矩阵记录了每个节点的度数
10
11         for nodeid in self.G.nodes():
12             if self.G.node[nodeid][‘test‘] or self.G.node[nodeid][‘val‘]:
13                 continue
14             neighbors = np.array([self.id2idx[neighbor]
15                                   for neighbor in self.G.neighbors(nodeid)
16                                   if (not self.G[nodeid][neighbor][‘train_removed‘])])
17             # Graph.neighbors() Return a list of the nodes connected to the node n.
18             # 在选取邻居节点时进行了筛选,对于G.neighbors(nodeid) 点node的邻居,
19             # 只取该node与neighbor相连的边的train_removed = False的neighbor
20             # 也就是只取不是val, test的节点。
21             # neighbors得到了邻居节点编号数列。
22
23             deg[self.id2idx[nodeid]] = len(neighbors)
24             # deg各位取值为该位对应nodeid的节点的度数,
25             # 也即经过上面筛选后得到的邻居数
26
27             if len(neighbors) == 0:
28                 continue
29             if len(neighbors) > self.max_degree:
30                 neighbors = np.random.choice(
31                     neighbors, self.max_degree, replace=False)
32             # range: neighbors; size = max_degree; replace: replace the origin matrix or not
33             # np.random.choice为选取size大小的数列
34
35             elif len(neighbors) < self.max_degree:
36                 neighbors = np.random.choice(
37                     neighbors, self.max_degree, replace=True)
38             # 经过choice随机选取,得到了固定大小max_degree = 25的直接相连的邻居数列
39
40             adj[self.id2idx[nodeid], :] = neighbors
41            # 把该node的邻居数列,赋值给adj矩阵中对应nodeid位的向量。
42         return adj, deg

construct_test_adj()  函数中,与上不同之处在于,可以直接得到邻居而无需根据val/test/train_removed筛选.

1 neighbors = np.array([self.id2idx[neighbor]
2                           for neighbor in self.G.neighbors(nodeid)])

原文地址:https://www.cnblogs.com/shiyublog/p/9902423.html

时间: 2024-08-30 18:01:56

GraphSAGE 代码解析 - minibatch.py的相关文章

GraphSAGE 代码解析

安装Docker与程序运行 1. requirements.txt Problem: Downloading https://files.pythonhosted.org/packages/69/cb/f5be453359271714c01b9bd06126eaf2e368f1fddfff30818754b5ac2328/funcsigs-1.0.2-py2.py3-none-any.whl Collecting futures==3.2.0 (from -r requirements.txt

OpenStack之虚机热迁移代码解析

OpenStack之虚机热迁移代码解析 话说虚机迁移分为冷迁移以及热迁移,所谓热迁移用度娘的话说即是:热迁移(Live Migration,又叫动态迁移.实时迁移),即虚机保存/恢复(Save/Restore):将整个虚拟机的运行状态完整保存下来,同时可以快速的恢复到原有硬件平台甚至是不同硬件平台上.恢复以后,虚机仍旧平滑运行,用户不会察觉到任何差异.OpenStack的虚机迁移是基于Libvirt实现的,下面来看看Openstack虚机热迁移的具体代码实现. 首先,由API入口进入到nova/

#YOLO_v3代码解析以及相关注意事项

1. 项目介绍 $~~~~~~~$本次YOLO_v3的项目来源于机器之心翻译的项目---从零开始PyTorch项目:YOLO v3目标检测实现以及从零开始 PyTorch 项目:YOLO v3 目标检测实现(下)两部分组成,原版的博客在此Series: YOLO object detector in PyTorch,原始博客的GitHub地址为:ayooshkathuria/pytorch-yolo-v3,最后附上论文的地址:YOLOv3: An Incremental Improvement

Faster RCNN算法代码解析

一. Faster-RCNN代码解释 先看看代码结构: Data: This directory holds (after you download them): Caffe models pre-trained on ImageNet Faster R-CNN models Symlinks to datasets demo 5张图片 scripts 下载模型的脚本 Experiments: logs scripts/faster_rcnn_alt_opt.sh cfgs/faster_rcn

ffmpeg代码解析

void avdevice_register_all(void){    static int initialized;    if (initialized)        return;    initialized = 1;    /* devices */    REGISTER_INOUTDEV(ALSA,             alsa);    REGISTER_INDEV   (AVFOUNDATION,     avfoundation);    REGISTER_INDEV

[nRF51822] 10、基础实验代码解析大全 &#183; 实验15 - RTC

一.实验内容: 配置NRF51822 的RTC0 的TICK 频率为8Hz,COMPARE0 匹配事件触发周期为3 秒,并使能了TICK 和COMPARE0 中断. TICK 中断中驱动指示灯D1 翻转状态, 即指示灯D1 以8Hz 的速率翻转状态 COMPARE0 中断中点亮指示灯D2 二.nRF51822的内部RTC结构: NRF51822 有两个RTC 时钟:RTC0,RTC1.两个RTC 均为24 位,使用LFCLK 低频时钟,并带有12 位分频器,可产生TICK.compare 和溢出

(转)Java二进制指令代码解析

转自http://www.blogjava.net/DLevin/archive/2011/09/13/358497.html Java二进制指令代码解析 Java源码在运行之前都要编译成为字节码格式(如.class文件),然后由ClassLoader将字节码载入运行.在字节码文件中,指令代码只是其中的一部分,里面还记录了字节码文件的编译版本.常量池.访问权限.所有成员变量和成员方法等信息(详见Java字节码格式详解).本文主要简单介绍不同Java指令的功能以及在代码中如何解析二进制指令. Ja

Storm中的LocalState 代码解析

官方的解释这个类为: /** * A simple, durable, atomic K/V database. *Very inefficient*, should only be * used for occasional reads/writes. Every read/write hits disk. */ 简单来理解就是这个类每次读写都会将一个Map<Object, Object>的对象序列化存储到磁盘中,读的时候将其反序列化. 构造函数指定的参数就是你在磁盘中存储的目录,同时也作为

Java二进制指令代码解析

http://www.blogjava.net/DLevin/archive/2011/09/13/358497.html http://blog.csdn.net/sum_rain/article/details/39892219 http://www.blogjava.net/DLevin/archive/2011/09/13/358497.html Java二进制指令代码解析 小注:去年在看<深入解析JVM>书的时候做的一些记录,同时参考了<Java虚拟机规范>.只是对指令的