前言

上篇谈到可以用 Dalex 来探索和解释模型的具体推理过程,这篇我们继续聊聊 AI 模型在面对歧义性偏差、对抗样本攻击和隐私泄露这些安全性方面遭遇到的挑战。

安全性

人工智能的安全性是个全新的领域,有着歧义性偏差、对抗样本攻击和隐私泄露等等风险。

歧义性偏差

在强化学习中,我们一般需要设立一个激励目标作为奖励,然后对错误的行为作出惩罚。但对奖励函数的定义偏差,在特定的场景会发生意想不到的错误,由于这种特定场景很难被发现,就会在实际生产环境下给人以"很可靠"的假象。

举个例子,在 OpenAI 的研究项目中,曾经对 CoastRunners 这个游戏做过强化学习训练。游戏的目标是快速完成赛艇比赛,最好领先于其他对手,胜负则以击中沿途的目标获得最高分数为准。

直觉上来说,小船有上下左右 4 个动作,显然击中目标的分数为第一激励,领先对手为第二激励,惩罚则是行驶到其余空白位置。

在训练 RL 模型的过程中,则发生了一些意想不到的行为。智能体找到一个孤立的循环,在那里它可以转一个大圆圈并反复击倒三个目标,并对控制运动的间隙,以便在奖励目标刷新的时候总能击中。尽管反复着火、撞到其他船、在赛道上走错路,但这个模型使用这种策略设法获得了比以正常方式完成赛程所能获得的更高的分数,这个强化学习模型获得的分数比人类玩家平均分数高出 20%。

显然,这个游戏有一个隐含的胜利目标,即必须完成比赛。虽然这只是一个小游戏,无伤大雅还很有趣,但这暴露了强化学习的一个重大隐患。我们很难,甚至完全不清楚训练的智能体想要做什么。如果沿途的奖励刷新时间或是地图稍有改变,说不定这个模型可以完美的通过测试集的验证,进而在生产环境中造成破坏。

要预防这种对激励函数设计时的歧义性偏差,OpenAI提出过几个研究方向:

  • 避免直接从场景中直接获得奖励,而是学习模仿人类的行为。

  • 除了模仿人类,还可以通过关键位置增加评估节点(比如一段时间没有前进,给予惩罚),甚至以交互方式加入人为干预。

  • 或许可以采用迁移学习的方式来训练许多类似的游戏,并推断出这个游戏的“常识”奖励函数。这种奖励可能会根据经典游戏优先完成比赛,而不是专注于该特定游戏奖励,从而更符合人类玩游戏的方式。

这些方法可能有其自身的缺点,迁移学习的方法本身可能是错误的。例如,一个受过许多赛车视频游戏训练的代理可能会错误地得出结论,认为在一个新的、更高风险的环境中开车不是什么大问题。

解决这些问题将是复杂的,现实中的目标函数同时逼近多个价值目标,这就让设计可靠激励函数的难度呈指数级增加。现阶段而言,强化学习的落地很难,更多的仿真测试,将使我们能够快速发现和解决新的故障模式,并最终开发出可以真正相信其行为的系统。

对抗样本攻击

在计算机视觉领域,用深度神经网络来做分类几乎是最佳的选择,当前大多数主流的模型识别率都能轻松超过人眼。但如果你能知道分类模型的 loss函数,则可以有针对性的设计梯度函数,从而反推生成对应的攻击样本。

这些对抗样本仅有很轻微的扰动,以至于人类视觉系统无法察觉这种扰动(图片看起来几乎一样)。这样的攻击会导致神经网络完全改变它对图片的分类。

这种对抗攻击网络无论在人脸识别领域,还是自动驾驶场景里,都有着很大的风险。你敢想象前面一辆车能被一张贴纸,就完美融于背景;或是交通转向标示被蒙上一层人类无法察觉的半透明薄膜,就能酿成车祸么...

要防御这种攻击,一般先要生成攻击样本,然后混合到训练集中,通过用识别概率最小的类别(目标类别)代替对抗扰动中的类别变量,再将原始图像减去该扰动,原始图像就变成了对抗样本,并能输出目标类别。

前几年曾参加过一个用于迷惑军事卫星识别本方军事目标的对抗赛,这里分享一下当时最终拿下亚军的竞赛方案。

  • 先用原始图像训练基准模型Xception

  • 然后用对抗神经网络扩增数据集;

  • 再用真实对战的数据集去燥后打上伪标签;

  • 接下来用迁移学习将基准模型训练为 EfficientNet模型;

  • 最后用两种尺度的图片做多模型融合输出。

其中,FGSM、图片去燥、图片多尺度、迁移学习和模型融合都能很好的防御对抗样本攻击。

不过这些需要付出高额的算力成本,很多手段只能在竞赛中使用,实际商业环境下还是面临很多挑战。

隐私保护

隐私保护则是另一个越来越受到重视的安全领域,敏感数据主要包括个人隐私信息、密码、密钥、敏感图片等高价值数据。过去国人一般对隐私不太重视,这客观上其实助力了人工智能在国内的发展,毕竟没有数据作为燃料,就没法生成优质的基础模型。

国外隐私保护则有些过度,这催生了大量数据采集人员深入了我国基层购买数据集。长远来看,其实对国家安全来说有着很高的风险,那能否既保护隐私又能共享数据,提高模型质量呢?

数据脱敏

首先最直接的方法就是对采样数据集做数据脱敏,通过对敏感数据进行不可逆的变换,又不破坏数据特征的方式来达到保护隐私的目的。

常用的脱敏算法有:

  • 哈希脱敏:不可逆,适用于密码或需要通过对比进行敏感数据确认的场景。

  • 遮盖脱敏:不可逆,适用于前端展示或敏感数据分享的场景,使用特殊符号 * 或 # 对部分文字进行遮盖实现敏感数据的脱敏。

  • 替换脱敏:部分可逆算法,适用于证件号等构成规则固定的字段脱敏。使用替换码表进行映射替换(可逆),或使用随机区间进行随机替换(不可逆),实现字段整体或者部分内容的脱敏。

  • 加密脱敏:可逆算法,适用于对需要回源的字段进行加密的场景。支持常见的对称加密算法DES, AES等等。

  • 洗牌脱敏:不可逆,适用于结构化数据列级别的数据脱敏场景。在源数据表抽取数据并确认数值范围后,对该字段(在范围内)进行列级别的打散重排和随机选择,实现混淆脱敏。

可以在数据集清洗时会使用 sklearn中的 LabelEncoderOnehotEncoder对字符串类型的特征做预处理(不可逆的替换脱敏)。

也可以在采集的图片中,应用人脸识别,将检测到的人脸局部进行模糊处理。

差分隐私保护

虽然可以通过数据脱敏来保护个人身份信息,但有些隐私信息蕴含在数据本身之中。

比如当一个银行欺诈风险类的应用或是一个医疗癌症诊断类应用,由于样本的极度不平衡,在训练某一个或某一类数据后,准确率显著提高。那大概率标明该用户属于高风险用户或是疾病患者。

差分隐私是Dwork在2006年针对统计数据库的隐私泄露问题提出的一种新的隐私定义。在此定义下,对数据库的计算处理结果对于具体某个记录的变化是不敏感的,单个记录在数据集中或者不在数据集中,对计算结果的影响微乎其微。所以,一个记录因其加入到数据集中所产生的隐私泄露风险被控制在极小的、可接受的范围内,攻击者无法通过观察计算结果而获取准确的个体信息。

这类隐性的隐私泄露很难被发现,所以可以在梯度计算时,用类似加盐的方式来对训练的loss值做出修正,这种方法就被称为差分隐私保护(differential privacy)

Tensorflow Privacy

Tensorflow Privacy (TF Privacy)Google 研究团队开发的一个开源库。该库包含一些常用 TensorFlow 优化器的实现,可用于通过 DP 来训练机器学习模型。

pip install tensorflow_privacy
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import (VectorizedDPKerasSGDOptimizer,
)# Select your differentially private optimizer
optimizer = VectorizedDPKerasSGDOptimizer(l2_norm_clip=l2_norm_clip,noise_multiplier=noise_multiplier,num_microbatches=num_microbatches,learning_rate=learning_rate)# Select your loss function
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True, reduction=tf.losses.Reduction.NONE)# Compile your model
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])# Fit your model
model.fit(train_data, train_labels,epochs=epochs,validation_data=(test_data, test_labels),batch_size=batch_size)

联邦学习框架

以上这些隐私保护的措施,一般只对用户有效,而对开发者可见。联邦学习 Federated Learning 则是一种分布式机器学习框架,可以在多个分散的边缘设备上训练模型,仅交换模型参数和权重,而不需要交换数据集。由于数据一直留存在用户设备上,连开发者也无法触达,则能达到物理安全的最高标准了。

这里介绍一种非常好用的联邦学习框架 Flower ,它同时支持 TersonflowPyTorch,还能支持异构的边缘设备。

Flower

https://flower.dev/

使用grpc来构成主从分布式网络,支持对包括移动设备在内的各种服务器和设备进行研究。AWSGCPAzureAndroidiOSRaspberry PiNvidia Jetson,都与 Flower 兼容。

conda create -n fed python=3.7
conda activate fed
pip install flwr
pip install torch torchvision

建立客户端

客户端的代码和平常训练模型代码差异很小,导入库文件

from collections import OrderedDict
import warningsimport flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

构建网络,读入数据,训练模型,唯一的差别在于要封装在一个Client类中

def main():"""Create model, load data, define Flower client, start Flower client."""# Load modelnet = Net().to(DEVICE)# Load data (CIFAR-10)trainloader, testloader, num_examples = load_data()# Flower clientclass CifarClient(fl.client.NumPyClient):def get_parameters(self):return [val.cpu().numpy() for _, val in net.state_dict().items()]def set_parameters(self, parameters):params_dict = zip(net.state_dict().keys(), parameters)state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})net.load_state_dict(state_dict, strict=True)def fit(self, parameters, config):self.set_parameters(parameters)train(net, trainloader, epochs=1)return self.get_parameters(), num_examples["trainset"], {}def evaluate(self, parameters, config):self.set_parameters(parameters)loss, accuracy = test(net, testloader)return float(loss), num_examples["testset"], {"accuracy": float(accuracy)}# Start client# fl.client.start_numpy_client("[::]:8080", client=CifarClient())fl.client.start_numpy_client("localhost:8080", client=CifarClient())

这里要注意,如果是Windows下,grpc不支持全局地址,需要显式的指定连接的地址,比如 localhost 或是 对应的 IP 地址。

建立服务器端

导入库文件

from typing import List, Tuple, Optional
import numpy as npimport flwr as fl
from collections import OrderedDict
import torch
from client import Net

服务器端程序,主要由 aggregate_fit 这个钩子函数来收集来自于各个训练模型的客户端消息。这里每次将汇总所有的参数,转换为对应的权重,并生成 checkpoint 模型文件,统一保存到服务器端。

DEVICE = "cuda:0"class SaveModelStrategy(fl.server.strategy.FedAvg):def aggregate_fit(self,rnd: int,results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],failures: List[BaseException],) -> Optional[fl.common.Weights]:aggregated_parameters = super().aggregate_fit(rnd, results, failures)if aggregated_parameters is not None:# Convert `Parameters` to `List[np.ndarray]`aggregated_weights: List[np.ndarray] = fl.common.parameters_to_weights(aggregated_parameters[0])# Load PyTorch modelnet = Net().to(DEVICE)# Convert `List[np.ndarray]` to PyTorch`state_dict`params_dict = zip(net.state_dict().keys(), aggregated_weights)state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})net.load_state_dict(state_dict, strict=True)# Save PyTorch `model` as `.pt`print(f"Saving round {rnd} aggregated_model...")torch.save(net, f"round-{rnd}-aggregated_model.pt")return aggregated_parameters

默认最少 2 个客户端,训练 3 轮,客户端每次训练各更新 50% 的梯度值。

if __name__ == "__main__":# Define strategystrategy = SaveModelStrategy(fraction_fit=0.5,fraction_eval=0.5,)# Start serverfl.server.start_server(# server_address="[::]:8080",server_address="localhost:8080",config={"num_rounds": 3},strategy=strategy,)

分布式训练

启动服务器程序,等待客户端的连接...

python server.py

再分别启动 2 个客户端程序

python client.py

训练结束后,就能在服务器上找到保存的 round-1-aggregated_model.pt 模型文件,完美!

源码下载

本期相关文件资料,可在公众号“深度觉醒”,后台回复:“explore02”,获取下载链接。

下一篇预告

这一篇主要介绍模型的安全性方面,并尝试避免歧义性偏差的设计,防范对抗样本攻击和用联邦学习来进行保护隐私,下一篇我们继续探讨一下人工智能的正义性问题。

浅谈AI模型的可解释性、安全性与正义性(中)相关推荐

  1. 多线程之旅之四——浅谈内存模型和用户态同步机制

     用户态下有两种同步结构的 volatile construct: 在简单数据类型上原子性的读或者写操作   interlocked construct:在简单数据类型上原子性的读和写操作 (在这里还 ...

  2. 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式 pth中的路径加载使用

    首先xxx.pth文件里面会书写一些路径,一行一个. 将xxx.pth文件放在特定位置,则可以让python在加载模块时,读取xxx.pth中指定的路径. Python客栈送红包.纸质书 有时,在用i ...

  3. rust怎么传送坐标_德国人怎么学电机——浅谈电机模型(十一):异步电机:绕线转子电机(一)...

    交流电机概述传送门: 善道:德国人怎么学电机--浅谈电机模型(七):交流电机概述​zhuanlan.zhihu.com 旋转磁场理论传送门: 善道:德国人怎么学电机--浅谈电机模型(八):三相交流电机 ...

  4. ad6怎么画电阻_德国人怎么学电机——浅谈电机模型(十七):同步电机(四)永磁电机(二)...

    上一章传送门: 善道:德国人怎么学电机--浅谈电机模型(十六):同步电机(三)永磁电机(一)​zhuanlan.zhihu.com 本章节如果未加说明,都以转子表面贴片的永磁电机为例. 7 电流和电压 ...

  5. pytorch保存模型pth_浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式

    我们经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,这几种模型文件在格式上有什么区别吗? 其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save ...

  6. c++怎么确定一个整数有几位_德国人怎么学电机——浅谈电机模型(十六):同步电机(三)永磁电机(一)...

    上一章传送门: 善道:德国人怎么学电机--浅谈电机模型(十五):同步电机(二)凸极电机​zhuanlan.zhihu.com 同步电机能够类似直流电机,除了电励磁也可以直接使用永磁体来励磁.这样转子上 ...

  7. 丁小平:浅谈科学模型及突变论等问题

    作者:北京二十一世纪药理科学研究院 丁小平 科学的根本任务在于揭示规律,进而使人们可以遵循利用规律服务生产和生活.从揭示方式看,规律可以分为完成型规律和逼近型规律.所谓完成型规律,是人脑通过逻辑从有限 ...

  8. 浅谈估值模型:PB指标与剩余收益估值

    摘要及声明 1:本文简单介绍PB指标的推导以及剩余收益的估值方式: 2:本文主要为理念的讲解,模型也是笔者自建,文中假设与观点是基于笔者对模型及数据的一孔之见,若有不同见解欢迎随时留言交流: 3:笔者 ...

  9. 浅谈软件开发工具CASE在软件项目开发中发挥的作用认识

    浅谈软件开发工具CASE在软件项目开发中发挥的作用认识 内容摘要:阐述了CASE工具作为 一种开发环境在软件项目开发中所起到的开发及管理作用.CASE工具实际上是把原先由手工完成的开发过程转变为以自动 ...

最新文章

  1. mysql的聚合函数综合案例_MySQL常用聚合函数详解
  2. linux硬件设备操作函数 open
  3. 低版本mysql utf8mb5_记住:永远不要在 MySQL 中使用 UTF-8
  4. devexpress实现单元格根据条件显示不同的样式(颜色、字体、对齐方式,大小等)...
  5. std::auto_ptr简单使用
  6. 高岭土吸附阳离子_水分子在高岭土中吸附特性的蒙特卡罗模拟研究
  7. JAVA音乐网站(JAVA毕业设计)
  8. 最强PostMan使用教程(7)postman做数字签名认证
  9. wincc 服务器授权型号,WINCC 授权详解
  10. mysql事务应该多复杂_可能是全网最好的MySQL重要知识点/面试题总结
  11. Android---性能优化方案分享,高级android开发强化实战pdf
  12. 手把手教你给 SSH 启用二次身份验证
  13. Gale-Shapley 算法 寻找稳定婚配java实现
  14. 韦东山嵌入式第一期学习笔记DAY_1——2_0_安装ubuntu16.04虚拟机
  15. python字典存储省份与城市_python实现城市和省份字典(根据城市判断属于哪个省份)...
  16. 全志H3-NanoPi开发板SDK之一总体概述
  17. 2.5D的ACT类型游戏碰撞检测
  18. 芝加哥大学计算机语言学,芝加哥大学cs专业值得申请么?
  19. MT4白标升级主标的方法
  20. QPBOC之GPO(一):CVM处理

热门文章

  1. 面试mysql之SQL优化总结一:索引的使用
  2. allgro pcb铜皮编辑_关于修割铜皮 - Cadence allegro PCB 教程
  3. hmmbuild结果文件解读:hmm文件
  4. 12自由度六足机器人实现步态规划功能
  5. coco训练集darknet_darknet-yolov3训练自己的数据集
  6. js、css 实现table表头固定
  7. 框架、架构和设计模式?!
  8. Vue + Element UI 表格分页记忆选中
  9. R语言|forest plot
  10. 物联网通信技术|课堂笔记|week8|网络安全学习|加密逻辑|加密算法