除了用于Google地图的ETA预测,GNN还可以做什么?

发布时间:2021-09-20 07:00

摘要

 

近年来,新兴起的图神经网络在很多应用领域都取得了非常出色的表现,如今年用于Google 地图的到达时间估计(Estimated Time of Arrival,ETA),在纽约、洛杉矶、东京、新加坡等国际大都市都获得了很大的提升,该结果对其他地区也具有通用性[1]。图神经网络以图结构为核心组成部分,这与结构因果模型有着相似的结构形式。鉴于此,DeepMind最新的研究工作[2]以图神经网络为网络结构,设计了一种基于图神经网络的变分图自编码器,用于近似Pearl因果层次结构中的因果计算问题。与以往用因果推断思想提升深度学习性能不同的是,该研究工作在图神经结构与结构因果模型之间建立了转换机制,为Pearl因果层次结构中的因果计算提供了一种新型计算方法和思路,是深度学习在因果推断领域应用的一项开创性的尝试性工作。

 

结构因果模型(SCM)

 

结构因果模型(Structural Causal Model,SCM)用于描述现实世界关联特征及其相互作用,是一种能够形式化表述数据背后因果假设的方法。结构因果模型含有两个变量集以及一组函数

 

 

即函数根据模型中其他变量的值给变量赋值。若存在于的定义域中,则变量是变量的直接原因。中的变量称为外生变量,属于模型的外部,不必解释其变化的原因。中的变量为内生变量,模型中的每一个内生变量都至少有一个外生变量作为直接原因。

 

图1:SCM实例和Pearl Causal Hierarchy

 

每一个结构因果模型都对应一个图模型中的每个变量都表示为一个节点,对于变量,如果的定义域中含有变量,那么在中,会有一条从的有向边。同时,通过实例化外生变量,也可以生成包含所有内生变量的数据,因此SCM也是一种数据生成模型。在SCM中,可以回答Pearl因果层次结构(Pearl Causal Hierarchy,PCH)关于基于观察数据(L1)、干预(L2)以及反事实(L3)三个层次的因果计算, 如图1所示。在PCH因果计算中,其核心思想是需要解决干预(L2)的问题,这是回答反事实问题(L3)的基础。在SCM中,干预主要体现在do-演算操作,即将干预变量强制设置为某一固定值,使得干预变量不会随其他变量影响。下面以一个例子来说明do-演算在SCM中的计算过程。

Example:考虑饮食(D)对血压(B)的影响,表示高纤维的饮食习惯,表示高血压,假设给定的SCM为M,外生变量,内生变量以及函数

 

 

其中⊕是XOR逻辑操作,该SCM相对应的图结构如图2所示。

 

图2:Example的图结构

 

此外,因为所有变量都取二元值,因此我们也可以枚举推演出其真值表,如表1所示。 

 

表1:Example的真值表

 

现在我们对变量D进行干预,设置为,那么根据表1我们可以计算的概率,就是将表1中所有蓝色行的数字求和,即

 

 

图结构是SCM模型的一个重要组成部分,能直观地表达变量之间的交互信息。但是在实际问题中,这种交互关系的确定往往严重依赖于领域专家知识,无可避免地引入了人为误差。个人认为,GNN作为深度学习方法的衍生体,能有效地近似任何函数,是模拟SCM中概率分布的一个可行方法。

 

基于SCM的图神经网络干预

 

Do-演算也可在图上以更直观的方式呈现。当对变量进行干预时,意味着削弱了该变量响应其他变量而变化的自然趋势。在SCM对应的图结构表示上,就需要删除指向该变量的所有边,如图3中右边红色锤子对应的边。

 

 

图3:干预变量Y的图结构表示

 

按照上述思想,文[2]对GNN也定义了类似的干预操作,主要体现在GNN的消息传递(message-passing)中。在标准的GNN信息聚合操作中,图中节点通过聚合其父节点的信息完成当前节点的信息更新,如式(2)所示。但是在干预的GNN中,如果当前节点为干预变量,则忽略其父节点的信息(将替换为),如式(3)所示。

 

 

干预变分图自编码器

 

在上述图神经网络(Graph Neural Networks, GNN)的do-演算基础上,文[2]定义了用于近似PCH因果推断的干预变分图自编码器(Interventional Variational Graph Auto-Encoder, iVGAE),如图4所示。

 

图4:iVGAE结构图

 

乍一看,图4上部分描述的是标准变分图自编码器模型结构。但为了能近似SCM在L2层次的因果推断,文[2]将编码器函数、解码器函数都设计为以给定SCM为图结构的GNN的聚合函数。在进行L2层次因果推理时,根据给定的查询变量和干预变量,动态地对图结构进行do-演算调整,即忽略/不计算来自干预变量的父节点的信息,而模型的输出可近似成该干预变量下的概率分布,即完成L2层次的因果计算。在训练iVGAE,主要采用了变分方法,其中目标函数也需要考虑干预变量。

 

思考

 

与以往用因果推断改进深度学习方法效果不同的是,文[2]侧重于用基于GNN的深度学习来完成SCM中的PCH因果计算,侧重于基于观察数据(L1)、干预的推断(L2)。由于图神经网络与SCM都是基于图结构,一种简单、直接的方法就是在给定SCM图结构上,设计一种合适参数转化机制,以确保SCM和深度学习模型表达同个分布,这也是文[2]的主要设计思路。同时,文[2]也指出,SCM需要对每个变量都定义各自相应的映射函数。相反的,在iVGAE中,可以找到单个共享聚合函数,用于聚合图中所有节点的消息。然而将单个聚合函数转换成多个结构方程的优化过程是异常困难,而这也是实现反事实推理需要解决的问题,这也是文[2]没有考虑L3层次推理的一个原因。

虽然文[2]、[3]都试图在SCM与深度学习之间建立联系,目前主要侧重于将深度学习看成一种近似方法来完成PCH中的因果计算。当然,对PCH因果计算的支持是实现因果推断的重要内容,也可以看成深度学习在因果表达上迈出了重要的一步。不同方法有不同程度的兼容性,如文[2]不支持L3层次的计算。这些研究也引出了更深层次的问题,如基于神经网络的因果计算优势体现在哪里,例如,推理计算是否更高效?文中尚未提供明确的答案。

除了近似分布,也有将深度学习用于因果发现中的研究工作,如文[4]中提出了连续优化(continuous optimization)的思想,重新定义了因果图发现的一种求解方式。与其在图空间进行搜索,转化为寻找一个包含图结构的邻接矩阵的函数,从而可以使用深度学习方法进行梯度下降求解。当然,这与文[2]有着不同的研究目标。不可否认,如何使得深度学习和因果推断相得益彰,是一个非常值得探索的方向,相信在不久的将来两者能碰触更多的火花。

最后,个人觉得文[2]的亮点在于采用现流行的GNN模型来模拟SCM的数据生成机制,虽然这种数据生成过程是一种黑盒子方法(这也是深度学习广为争议的一个特征),但如果仅从数据模拟效果的角度来看,未尝不可?

 

由于水平有限,文中存在不足的地方,请各位读者批评指正,也欢迎大家参与我们的讨论。

 

参考文献

 

[1] Austin Derrow-Pinion,  Jennifer She,  David Wong, et al. ETA Predictionwith Graph Neural Networks in Google Maps 2021

[2] Matej Zecevi,Devendra Singh Dhami, Petar Velickovi,Kristian Kersting. Relating Graph Neural Networks to Structural Causal Models 2021

[3] Kevin Xia, Kai-Zhan Lee, Yoshua Bengio, Elias Bareinboim. The Causal-Neural Connection: Expressiveness, Learnability, and inference 2021 

[4] Xun Zheng, Bryon Aragam, Pradeep Ravikumar, Eric P Xing. Dags with no tears: continuous optimization for structure learning 2018 

上一个: 神经渲染中的特色深度计算特征

下一个: 神经渲染最新进展与算法(二):NeRF及其演化

近期文章

通用AI模型的未来:深度强化学习(deep reinforcement learning)

近年来,AI模型开始涌现出超越人类的潜力,在传统的围棋游戏以及拥有复杂规则和系统的电竞游戏(星际争霸2,Dota 2等)中都有体现。随着ChatGPT的出现,人们开始意识到语言模型成为通用人工智能的可能性,而这些模型的核心都是深度强化学习。

2023-05-08

“为了全人类更好地交流”:通用语音识别

人机交互的第一步往往由人发起,从你说出第一句话开始,计算机如何可以应答并能进而和你自然畅谈?在之前文章里我们分享了当前利用AI进行语音识别的关键步骤和做法,这次我们将随着技术的进化发展,领略当前通用语音模型的进化高度。

2023-04-24

查看更多