Mini Survey of Machine Unlearning

1 minute read

Published:

Machine Unlearning 相关笔记

是在阅读Machine Unlearning of Features and Labels和Graph Unlearning之后的笔记,在阅读了其他机器遗忘相关的论文后会扩展。

机器遗忘背景和研究价值

目前的机器学习方法利用用户的信息进行训练,然后在推荐,搜索,广告等场景为用户提供个性化的结果。利用用户的敏感信息训练的模型可能在刻意设计的攻击下泄露这些信息,或者在推理时表现出某些危险的行为。场景包括:

  1. 为符合隐私法规的要求,需要按照用户要求对于某些敏感数据,或者某些用户的数据对于模型的影响进行消除。
  2. 大模型的无意识记忆(记住训练数据中的某些输入并在推理中复现的现象)可能会泄露敏感信息,需要对无意识记忆的内容进行消除。
  3. 大模型的学习结果可能会有某些不符合价值观/道德标准的东西,在发现时也需要消除这部分训练数据的影响。

由此出现了隐私保护和阻止信息泄露的需求。最简单直接的方法自然是将需要遗忘的数据从训练数据中删除,然后重新训练模型————这也是目前对于遗忘效果进行数学描述的对比基准————但这种方法的问题也很明显:很多时候数据量或模型规模过大,重新训练成本很高;存储成本原因很多训练数据并不会无限保留;某些数据处于不断更新中,不是完全可用。

其实还有一种情况论文里没提,就是很多预训练模型的用户是无法进行重训练的,一方面由于根本无法接触模型参数和预训练的数据,另一方面由于无法承担模型的预训练成本。

因此需要一种高效,快速的方法来将指定数据的影响从模型中剔除的方法,让模型“忘记”指定数据的信息,也就是模型遗忘的研究内容。

和对抗攻击以及中毒攻击的区别: 从效果上说,对抗攻击或者中毒攻击有和机器遗忘类似的削弱数据影响的地方。在根本上说,这些攻击方法和机器遗忘的目标是相悖的:机器遗忘需要在去掉指定数据的影响后尽可能保留模型推理精度,但对抗攻击和中毒方法则以尽可能降低目标模型精度为目标。因此直接使用对抗或中毒攻击方法的结果很可能是以毁掉模型的代价遗忘指定的数据。

和模型微调的区别: 根据machine unlearning of feature and labels的描述,机器遗忘的目标与微调的目标有细微区别,微调的目标在于使模型的预测贴近微调数据的标签,而遗忘的目标在于使模型的预测贴近对冲的目标外远离原有数据的预测结果。

图取消学习是机器取消学习的扩展,专门针对图结构数据而设计。它涉及完全删除已删除数据的所有痕迹,例如图中的节点和其所有连接的边。图遗忘的目标是从图中删除特定的节点和边以及其他信息,同时保留其整体结构。这种级别的数据删除涉及数据点之间的相连关系,因而比传统的机器遗忘要复杂得多。

图结构数据的复杂性对图的遗忘提出了挑战。图中节点和边之间的关系错综复杂,形成了一个密集的互连信息网络。随着数据隐私法规和个人控制其数据的权利变得越来越重要,从图结构数据中全面删除敏感信息的能力成为至关重要的道德保障。确保有效、彻底地消除图表中被遗忘的数据可以让个人更好地控制自己的信息,增强对机器学习系统的信心,并维护负责任的数据管理的道德标准。

目前进展

按照看过的论文,目前的机器遗忘方法大概可以分为精确遗忘和近似遗忘两种。

分片法继承重训练的思想,并对训练成本过高的问题进行改进。改进方法是将训练数据分片(shard),在每个片上分别训练子模型,最终将每个子模型的结果进行聚合得到模型的推理结果。在得到遗忘需求时确认需要遗忘的数据所在分片,在片上去除对应数据后重新训练对应的子模型。总体来说是通过限制重训练的范围来提升训练效率,降低训练成本。 Graph Unlearning针对图数据设计了基于社团发现和图嵌入聚类的两种分片方式避免随机划分导致的模型效果恶化,同时限制每个分片的规模来避免数据规模不均衡导致的训练成本不均,最终使用自适应的聚合方式得到模型的最终分类结果。在遗忘时修改对应分片的数据(删除节点&边),然后重训练对应的子模型。

抵消法从影响函数的角度出发,设计对冲的样本加入训练,从而抵消目标数据对于模型的影响。

todo:看Machine Unlearning of Features and Labels的参考文献9和21,看这个影响函数是什么东西,要不看不懂这个损失函数。

影响函数方法的基本原理

整个方法从遗忘的目标出发:使遗忘后的模型尽可能与完全重训的模型一致。因此理想状态下,遗忘算法修改后的模型参数\(w^-\)应当与用剩余数据\(\mathcal{D}'\)完全重训的模型一样,即满足:

\[w^-=argmin_w L(w;\mathcal{D}')\]

当损失函数收敛时遗忘完成,得到遗忘后的模型,此时满足损失函数梯度为0,即

\[\nabla L(w;\mathcal{D}')=0\]

使用一阶泰勒展开处理上式,得到

\[\nabla L(w;\mathcal{D}')\approx \nabla L(w^*;\mathcal{D}')+\nabla^2L(w^*;\mathcal{D}')(w^--w^*)=0\]

其中\(w^*\)是遗忘前的模型参数,最终化简就得到

\[w^-=w^*-\overbrace{\nabla^2L(w^*;\mathcal{D}')}^{H_{w^*}^{-1}}\overbrace{\nabla L(w^*;\mathcal{D}')}^{\Delta}\]

其中\(I=H_{w^*}^{-1}\Delta\)称为影响函数。

因为原来的模型是在完整数据上的最优模型,并且样本之间没有关联关系,所以有

\[\begin{aligned} \nabla L(w^*;\mathcal{D}')&=\nabla L(w^*;\mathcal{D})-\nabla L(w^*;z_n) \\ &=0-\nabla L(w^*;z_n) \\ &=-\nabla\mathcal{l}((w^*)^Tx_n,y_n)-\lambda w \end{aligned}\]

所以在机器遗忘论文里将其中的负号提出来跟影响函数中的负号抵消,直接将在遗忘数据处的损失表示为

\[\Delta=\lambda w^*+\nabla\mathcal{l}((w^*)^Tx_n,y_n)\]

最终参数更新方式为

\[w^-=w^*+H_{w^*}^{-1}\Delta\]

对于图上的遗忘问题而言,最大的区别在于样本之间是有关联的,一个数据点的删除会导致其他数据点的预测结果也发生变化。因此\(\nabla L(w^*;\mathcal{D}')\)会发生变化,不会只是将要遗忘的样本的梯度减掉了。根据完全重训的思路,这个变化应当是所有样本的损失的变化之和,即certified graph unlearning里面提到的

\[\underbrace{\nabla L(w^*;\mathcal{D})}_{original\space prediction}-\underbrace{\nabla L(w^*;\mathcal{D}')}_{retrain\space prediction}\]

再加上要删除的样本点在\(\mathcal{D}'\)里没有,所以整个\(\Delta\)部分可以写为

\[\Delta=\lambda w^*+\nabla\mathcal{l}((w^*)^Tx_n,y_n)+\sum_{i=1}^{n-1}(\nabla\mathcal{l}((w^*)^Tx_i,y_i)-\nabla\mathcal{l}((w^*)^Tx'_i,y_i))\]

即可以将遗忘节点和其他节点的影响在传播路径处进行分离,从而完成对于遗忘节点信息的对冲操作。

潜在方向

对于Graph Unlearning这篇论文,使用分片的方式可能不是一个最好的方法,因为分片训练会在一定程度导致模型精度降低:究其根本依然是限制信息的传播范围来降低重训练成本。使用抵消法设计对冲样本进行训练可能是更好的方法。

对于机器遗忘这个方向而言,目前的方法都需要预先知道需要删除的数据样本是什么。这在需求1的场景中很容易获得,需求2的场景中可能比较容易定位,但是在需求3的场景中很难找到所有需要删除的数据样本,寻找过程本身也很依赖人的主观判断。其次如前文所言,当初的训练数据因为存储成本或者更新已经无法获得。所以需要一种不需要提前获得目标数据就可以进行有效遗忘的方法。

这种方法的效果如何衡量,以及在不精确知道需要遗忘的数据的时候如何去遗忘数据?

以及论文中所提的,已有方法在面对大量需要遗忘的数据的时候(大量的数据样本/特征/标签)普遍存在效率低下的问题。分片法的原因很直观,其底层依然是重训练方法,减少训练成本的分片方法在大量数据需要重训练时并没有作用:所有片都需要重训练的话那就是从零重训练。抵消法的原因还需要分析。

相关工作

2020

  1. 【ICML】Certified Data Removal from Machine Learning Models
    1. 机器遗忘问题,主要研究损失函数全部可微的线性模型上的遗忘,也扩展了一下到非线性模型
    2. 提出了可验证遗忘:从数据中移除的模型不能从从未观察过数据开始的模型中区分出来。与差分隐私的定义类似
    3. 为使用可微凸损失函数训练的L2正则化线性模型开发一种可验证遗忘机制,通过计算遗忘样本点处的Hessian值和损失梯度,使用牛顿更新遗忘机制进行一次牛顿更新完成模型的遗忘过程。

2022

  1. 【ACL】In-Context Unlearning: Language Models As Few Shot Unlearners
    1. in-context leanring的machine unlearning方法,在inference阶段的in-context prompt这里加入对要遗忘的样本的翻转标签作为example input,然后让模型进行推理
    2. 文章中说的是最后再加个query input让模型做其他推理任务,但其实如果要验证的话不就是问这个forgotten example,那就等于没过模型直接在prompt里找的,治标不治本
  2. 【NeurIPS】Certified Graph Unlearning
    1. 也是对SGC这样的线性GNN进行的分析,目前还没看出和GraphEditor的区别
    2. 根据删除的数据计算梯度更新方向,然后直接更新梯度

2023

  1. 【ICCV】SAFE: Machine Unlearning With Shard Graphs
    1. 机器学习遗忘问题,分片重训式的遗忘方法,在最小化期望成本的同时,在多样化的数据集上适应大型模型,以消除训练样本对已训练模型的影响。
    2. 并不是图上的遗忘问题,而是利用图对结构数据上的遗忘方法进行改进。已有方法增加分片的数量减少了遗忘的期望成本,但同时增加了推理成本,并且由于在独立模型训练过程中丢失了样本之间的协同信息,降低了模型的最终精度。
    3. SAFE引入了分片图的概念,允许在训练过程中结合来自其他分片的有限信息,以牺牲期望遗忘代价的小幅度增加和准确率的显著提高为代价,同时仍然实现了遗忘后剩余影响的完全去除。
    4. 不完全针对图数据,使用类别平衡下采样方法得到粗分片,然后再粗分片中划分限定规模的细分片,根据图的数据使用关系构建分片图连边。在每个分片图上训练一个模型,遗忘时根据要求重训要遗忘的数据所在的分片图上的模型,以及所有使用这个分片数据的分片图上的模型。
  2. 【AISTATS】Efficiently Forgetting What You Have Learned in Graph Representation Learning via Projection
    1. 线性GNN上的节点遗忘问题,提出精确遗忘方法projector,在节点特征处将权重参数投影到一个与想要遗忘的节点特征不相关的子空间来遗忘一组节点特征
    2. 如果权重参数也在节点特征的线性生成空间中(也是节点特征的线性组合),那权重的梯度仍然是节点特征的线性组合,因此可以从与遗忘节点的特征无关的子空间中寻找一个权重组合,从而完成遗忘
    3. 方法首先利用节点特征计算正交投影的系数,然后计算遗忘后的权重参数。projector可以看作对剩余的节点进行重新加权,使得模型的行为尽可能地接近遗忘之前的模型,而不携带任何关于遗忘节点特征的信息。因此projector可以看作在某种重要性采样下的重训练方法。
  3. 【WWW】Unlearning Graph Classifiers with Limited Data Resources
    1. 图散射变换(GST)的遗忘问题,遗忘场景是多图中的单个图中节点遗忘
    2. 因为图散射变换是不可训练的,最后使用了一个可训练的分类器完成图分类,节点特征通过图散射变换得到,最后遗忘的是这个图分类器。
    3. 遗忘方式与可验证遗忘类似,也是通过hessian值和损失梯度直接计算影响,然后通过一次牛顿更新完成参数调整
  4. 【WWW】GIF: A General Graph Unlearning Strategy via Influence Function
    1. 图神经网络的遗忘问题,探索为图遗忘量身定制的影响函数
    2. 现有的工作要么采用再训练范式,要么进行近似遗忘,没有考虑连接邻居之间的相关性,或者对GNN结构施加约束,因此难以实现令人满意的性能-复杂度权衡。
    3. 传统的影响函数假设剩余节点在遗忘前后的预测是一致的,所以所有损失的总和就是原来所有节点的损失和减去要遗忘的节点的损失和,但在图中对于遗忘节点的移除会影响多跳内的剩余节点。
    4. 本文遵循传统影响函数的数据扰动策略,将图的结构依赖性纳入参数改变量的考虑范围,推导出图上适用的影响函数。根据图上不同的遗忘任务(节点遗忘,边遗忘,特征遗忘)确立不同的影响范围,并计算出对应的参数改变量
  5. 【USENIX Security】Inductive Graph Unlearning
    1. 和graph unlearning一样是分片重训类型的方法,区别在于新的划分方式,一个节点相邻节点特征重建方法,以及最后根据图核计算相似度的加权聚合方法
    2. 计算k个碎片,应用修复函数恢复边,训练k个模型并聚合。
    3. 针对推理式方法的改进应该就是这个相邻节点特征重建以及最后的相似度加权聚合方法
  6. 【ICLR】GraphEditor: An Efficient Graph Representation Learning and Unlearning Approach
    1. 线性GNN上的机器遗忘问题,提出了GraphEditor,一种高效的图表示学习和去学习方法,支持线性GNN的节点/边删除、节点/边添加和节点特征更新。
    2. 考虑非凸设定,计算GNN输出与标签的闭式解来更新模型。方法先求出闭式解,在遗忘时首先删除要遗忘的节点和所有相关节点的信息,然后再将不用删除的相关节点的信息更新回来(这两步都给了闭式解,只要明确了相关节点的表征就可以计算)
  7. 【ICLR】Fast Yet Effective Graph Unlearning through Influence Analysis
    1. 图神经网络上的机器遗忘问题,针对边遗忘问题提出了一种使用影响函数的直接更新方法
    2. 基础思路是使用影响函数确认删除特定边之后的影响范围,然后通过提升权重(upweight)的方式消除对于模型参数的影响
    3. 影响函数通过原图中每个节点损失的二阶梯度(Hessian matrix),剩余图上节点的损失梯度,原图中每个节点损失的梯度来计算。最后在原模型上一次性减去影响函数值来更新模型
  8. 【ICLR】Efficient Model Updates For Approximate Unlearning of Graph-Structured Data
    1. SGC和GPR等线性GNN上的机器遗忘问题,给出了进行遗忘的理论基础合格遗忘(certified removal),指遗忘后的模型性能与完全重训的模型性能需要满足差分隐私的要求,要近似一致
    2. 首先介绍了三种图去学习的数据删除请求:节点特征遗忘、边遗忘和节点遗忘。然后推导了SGC上所有3种移除情况的近似图遗忘机制及其GPR推广的理论保证,给出了三种任务下的梯度残差范数上界。遗忘方式与可验证遗忘的方式相同,通过牛顿更新遗忘机制完成。
    3. 遗忘节点的度在遗忘过程中起着重要作用,而传播步数对于不同的遗忘场景可能重要,也可能不重要。
  9. 【ICICS】Graph Unlearning Using Knowledge Distillation
    1. 使用仅在保留数据上训练的教师网络来指导蒸馏目标学生网络,通过教师网络使用的数据集隔离要遗忘的数据,等于重训一个轻量级的模型来蒸馏目标的重量级模型
  10. 【Arxiv】Distill to Delete: Unlearning in Graph Networks with Knowledge Distillation
    1. 图神经网络的遗忘问题,提出使用知识蒸馏来进行遗忘
    2. 提出遗忘的目标是连贯性(在要遗忘的数据集上和完全重训的模型有类似的结果)、完整性(在保留的数据集上和完全重训的模型有类似的结果)
    3. 使用了两个模型推理进行蒸馏,destroyer模型通过支持中性或负面知识来遗忘目标知识,preserver模型通过支持正面的知识或相连节点的嵌入来保留知识。destroyer模型负责通过遗忘数据集来达成连贯性目标,preserver模型负责通过保留数据集来达成完整性目标
  11. 【ICLR】GNNDELETE: A General Strategy For Unlearning In Graph Neural Networks
    1. 图神经网络上的机器遗忘问题,提出了一种使用外部模块的遗忘方式
    2. 攻击场景应当是白盒场景,操作方能够接触训练数据和模型结构,能够添加模块对模型进行继续训练
    3. 提出了GNN上遗忘的两个目标:删除边的一致性(需要删除的边的两个端点在遗忘后被判断相连的概率应当与随机两个节点预判相连的概率接近)和邻域影响性(遗忘前后节点的表征应当足够接近)
    4. 设计了一个外部模块用于进行遗忘,在每个GNN层都添加一个同样的模块,如果节点属于遗忘边的端点则使用可学习参数进行遗忘,如果不属于则不进行操作
    5. 这个方法需要能够接触数据和模型,感觉相较已有方法对于场景的要求是最严格的,但好处是不吃模型,什么GNN模型都可以添加
    6. 以及可以看看它这个随机挑两个节点的是怎么实现的
  12. 【WWW】Heterogeneous Federated Knowledge Graph Embedding Learning and Unlearning
    1. 联邦学习场景下的知识图谱嵌入遗忘问题,提出了联邦学习场景下的知识图谱学习和遗忘方法
    2. 服务器-客户端框架,每轮训练服务器向目标客户端发送上轮节点嵌入,客户端在本地完成新一轮训练,最后将结果发送至服务器进行加权更新
    3. 提出了一个两步的遗忘操作,第一步追溯干扰使用要遗忘数据节点的负样本构建硬混淆(优化局部嵌入),利用负样本和遗忘数据之间的距离构建软混淆(反映了遗忘集合中三元组的分数与其负样本之间的距离),并结合本地和全局模型的距离进行遗忘训练。全局嵌入的干扰损失也以类似的方式计算,通过将局部得分改为全局得分并反向蒸馏。第二步被动衰减使用全局模型和局部模型互相作为老师进行知识蒸馏,抑制遗忘三元组的激活。

2024

  1. 【AAAI】Towards Effective and General Graph Unlearning via Mutual Evolution
    1. 使用mutual evolution的图上遗忘方法,老实说没太看懂
    2. 使用了一个预测模型和一个遗忘模型来共同进化,预测模型用于调整目标GNN模型同时保留推理能力,遗忘模型用于基于预测模块为非遗忘实体生成预测,同时为模型调整提供遗忘能力
    3. 最终是将两个模块进行了联合训练,但文中一部分公式没有看懂,一些符号的含义没有明确说明