Finding MNEMON:Reviving Memories of Node Embeddings论文阅读笔记
Published:
研究图神经网络上的数据窃取攻击问题,当前工作主要关注隐藏图上信息或者理解图神经网络的安全性和隐私问题,没有关注上游模型和下游任务pipeline整合时的隐私风险。即当上游模型提供节点表征向量时,下游可以通过这些向量反向窃取上游的图结构信息。
攻击场景
攻击者知晓节点表征向量矩阵,以及图的背景知识,比如数据来源。攻击者无法与上游模型交互,也不知道上游模型的结构和权重信息,也没有额外的数据集信息。仅能通过节点表征矩阵和背景知识对图的结构信息进行推断。攻击目标是得到一个在图的特性上与原始图类似的图结构。
攻击方法
整个攻击方法分为三步,首先是通过图的一定背景信息获取图的平均度数和规模,然后使用图标准学习(Graph Metric Learning)来学习多头注意力权重,并在每个节点嵌入矩阵的基础上裁剪距离函数。最后是利用图自编码器对图的结构进行复原,并去除错误推理出来的边。
图背景知识获取
这一步根据统一数据来源的图在结构上也表现出相似性的特点,利用统一数据来源的其他图结构来预测目标图的平均度数和规模。在确定图的来源后,使用图采样方法从统一数据来源在采样一个图,使用这个图的平均度数作为目标图的平均度数。本文使用的是spikyball sampling方法。
图标准学习
这一步的目的是通过表征向量计算两个节点之间的距离,并初始化一个图结构。图标准学习分为构建和采样两步。
构建环节使用距离函数计算每个节点对之间存在边的概率,公式如下
\[p_{uv}=e^{-\tau\delta(h_u,h_v)}\]其中\(\delta(\cdot, \cdot)\)就是距离函数,本文中开始使用了余弦距离,后为了缓解采样步骤导致的假阳性边情况,使用了一个可学习的多头注意力机制来计算节点对之间的距离,具体而言通过一下公式计算
\[\delta(u,v)=1-\frac{1}{m}\sum_{i=1}^m\cos(w_i\circ h_v, w_i\circ h_u)\]其中\(w\)是可学习的注意力权重,在后续与自编码器一起更新。
采样环节使用了Gumbel-Top-\(k\)方法,具体而言首先给概率矩阵\(P=\{p_{uv}\}\)中的每个元素都添加一个Gumbel随机扰动,然后从矩阵中选择\(k\)个概率最高的样本构成邻接矩阵\(A^0\)。
图结构复原
结构复原部分使用了图自编码器,利用第二步初始化的邻接矩阵和节点的表征矩阵作为输入,编码器是GCN,解码器是矩阵内乘。最后使用了三个自学习的图正则项来训练GAE。
图拉普拉斯正则项\(\mathcal{L}_{lap}\)假设学到的图结构应该在嵌入相似的节点之间是平滑的,即嵌入向量越相似节点越可能相连,因此定义损失如下
\[\mathcal{L}_{lap}(A^{t+1}, H^0)=\frac{1}{2n^2}\sum_{v,u}A_{vu}^{t+1}\lVert h_v^0-h_u^0\rVert=\frac{1}{2}tr(H^{0^T}L^{t+1}H^0)\]图稀疏性正则\(\mathcal{L}_{spa}\)用于保证学到的图结构符合稀疏性要求,即图中的边数不会过多。图稀疏性正则鼓励每个节点至少有一条边,但是惩罚度数过高的节点,公式如下
\[\mathcal{L}_{spa}(A^{t+1},\alpha,\beta)=-\alpha 1^T\log(A^{t+1}1)+\frac{\beta}{2}\lVert A^{t+1}\rVert\]图重构损失\(\mathcal{L}_{rec}\)就是重构的邻接矩阵与输入的邻接矩阵的二值交叉熵,用于让GAE学习的隐层表征可以有效重构输入的邻接矩阵。
\[\mathcal{L}_{rec}=\frac{1}{2n^2}\lVert A^t\circ\log(A^{t+1})+(1-A^t)\circ\log(1-A^{t+1})\rVert _F^2\]最后把三个损失相加进行联合训练,在训练过程中分两步,首先固定GAE,训练图标准学习中的注意力权重,然后固定注意力权重,训练GAE。最后训练得到的邻接矩阵会与最开始初始化的邻接矩阵相加\(A^{t+1}=(1-\eta)A^0+\eta A^{t+1}\),作为去噪函数,用于从初识的邻接矩阵中去除一些推理错误的边。然后使用修剪函数\(clip(x)=min(max(0,x),1)\)将值控制在\([0,1]\)之间,最后使用伯努利二值策略(将每个元素都作为伯努利分布的参数,然后独立采样)得到二值化的邻接矩阵,用于进行下一轮训练。
这个需要这么复杂吗,主要问题在于无法与上游模型交互的话不知道相似度多高就算有边了,原来GAE里面就是交叉熵,然后二值化就是超过0.5的就算存在,这样是不是也可以
感觉看的几篇做图上的数据窃取的(或者说用机器学习做数据窃取的)其实就是做了一个预测任务,只能说最后的结果有很高的可信度,在统计上精确度等很高,但是没法说自己就是发现了正确的结构?