Code link: PPLM_code

2. Example command for discriminator-based sentiment control

python run_pplm.py -D sentiment --class_label 2 --cond_text "My dog died" --length 50 --gamma 1.0 --num_iterations 10 --num_samples 10 --stepsize 0.04 --kl_scale 0.01 --gm_scale 0.95 --sample

Code logic (PPLM_Discrim):

1. run_pplm_example():

2. Load the model and the tokenizer; freeze the model parameters.

3. full_text_generation(): returns the generated results, i.e. the unperturbed sentence and the perturbed sentences.

3.1 generate_text_pplm() (no perturbation): generates the baseline text without any perturbation.
3.1.1 for i in range_func: one token is generated per step (the sentence length is set to 30 here)
  1. If the prompt has 3 tokens, this step feeds only the first 2 and returns their cached hidden states (the past).

  2. With no perturbation, pert_past = past.

  3. Feed last and pert_past into the model; get back pert_logits, pert_past, pert_all_hidden, with past updated to include the information of the current last token.

  4. Keep the top-k tokens according to pert_logits.

  5. Pick the highest-probability token (or a sample, with --sample) as the next last and append it to output_so_far.

3.2 In the same way, sample the perturbed sentences (5 in this run):

3.2.1 generate_text_pplm() (with perturbation)
3.2.1.1 for i in range_func: one token is generated per step
  1. If the prompt has 3 tokens, this step feeds only the first 2 and returns their cached k/v (the past).
  2. Feed all three tokens to get the unperturbed outputs: unpert_logits, unpert_past, unpert_all_hidden.
  3. accumulated_hidden = unpert_last_hidden[:, :-1, :]: keep only the first k-1 tokens, then sum accumulated_hidden over the token dimension.
  4. perturb_past(): returns pert_past.
4.1. Initialize grad_accumulator to zeros, with the same shape as the stacked k/v of each of the 24 blocks: (2, 1, 16, 2, 64).
4.2. Initialize window_mask with the same shape, (2, 1, 16, 2, 64).
4.3. Run three rounds of gradient iteration (a distilled sketch of this update step appears after this list):
  4.3.1. Add the gradients accumulated in earlier rounds (grad_accumulator) to past, giving perturbed_past; past itself has been detached, so it receives no gradient updates.
  4.3.2. Feed the most recently generated token last, together with perturbed_past, into the model to obtain all_logits and all_hidden.
  4.3.3. From all_hidden, take the hidden state produced for last and add it to accumulated_hidden, without tracking gradients.
  4.3.4. Take the logits and probs the model produced for last, shape (1, 50257).
  4.3.5. Compute the PPLM_DISCRIM loss:
    4.3.5.1. Initialize a cross-entropy loss and fetch unpert_past, produced earlier by feeding the full initial prompt without perturbation.
    4.3.5.2. for _ in range(horizon_length):  # horizon_length is 1
    4.3.5.3. Divide new_accumulated_hidden by (curr_length + 1 + horizon_length) and pass it through the classifier, an MLP mapping (1, 1024) to (1, 5); see the ClassificationHead sketch after this list.
    4.3.5.4. Build the cross-entropy label (the authors trained an MLP head as the classifier).
    4.3.5.5. Compute discrim_loss().
  4.3.6. Compute the KL loss from unpert_probs and probs.
  4.3.7. Add the two losses and backpropagate.
  4.3.8. Apply window_mask to the gradient of curr_perturbation (in effect, grad_accumulator), compute the gradient norms, then scale by the step size to obtain grad.
  4.3.9. Add grad to grad_accumulator.
  4.3.10. Rebuild past from detached copies (new_past) so it carries no graph history into the next round.
  5. Feed last and pert_past into the model; get back pert_logits, pert_past, pert_all_hidden, with past updated to include the current last token.
  6. Fuse the perturbed and unperturbed probabilities, keep the top_k tokens, and rescale to get the new distribution over the vocabulary (see the fusion sketch after this list).
  7. Pick the next token (sampled, or greedy top-1) as the next last.

4. Print the unperturbed sentence and the perturbed sentences.
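Before reading the full listing, it may help to see the update rule of step 4.3 in isolation. The sketch below is not the repo's perturb_past: it collapses the per-block list of k/v tensors into a single past tensor and hides the two forward passes behind two callbacks, attr_loss_fn and kl_loss_fn (both hypothetical names standing in for the discriminator loss of 4.3.5 and the KL loss of 4.3.6). What remains is the actual gradient logic: backpropagate into the perturbation only, normalize the gradient by its norm raised to gamma, scale by the step size, and accumulate.

import torch

def perturb_step(past, grad_accumulator, attr_loss_fn, kl_loss_fn,
                 stepsize=0.04, gamma=1.0):
    """One round of step 4.3, on a single past tensor for clarity."""
    # 4.3.1: only the perturbation requires grad; past stays detached
    delta = grad_accumulator.clone().requires_grad_(True)
    perturbed_past = past.detach() + delta

    # 4.3.5-4.3.7: attribute loss plus KL loss, then backprop
    loss = attr_loss_fn(perturbed_past) + kl_loss_fn(perturbed_past)
    loss.backward()

    # 4.3.8: normalize the gradient by its norm ** gamma, step against it
    grad_norm = torch.norm(delta.grad) + 1e-15  # SMALL_CONST guards /0
    grad = -stepsize * delta.grad / grad_norm ** gamma

    # 4.3.9: fold the step into the running perturbation
    return grad_accumulator + grad.detach()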
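The (1, 1024) to (1, 5) mapping in step 4.3.5.3 is done by the imported ClassificationHead. In the PPLM repo this head is essentially a single linear layer; the version below is a sketch from memory of pplm_classification_head.py, not a verbatim copy.

import torch

class ClassificationHead(torch.nn.Module):
    """Discriminator head: one linear map from GPT-2-medium's hidden size
    (embed_size=1024) to the number of classes (class_size=5 for SST)."""
    def __init__(self, class_size, embed_size):
        super().__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        return self.mlp(hidden_state)  # (1, 1024) -> (1, 5)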
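Step 6 of the perturbed loop, the fusion of the two distributions, is also worth isolating: it is a weighted geometric mean controlled by gm_scale, followed by top-k filtering and renormalization. Below is a minimal self-contained sketch; fuse_and_sample is a hypothetical helper, and gm_scale=0.95 and top_k=10 mirror the example command and the script default.

import torch
import torch.nn.functional as F

def fuse_and_sample(pert_logits, unpert_logits, gm_scale=0.95, top_k=10):
    pert_probs = F.softmax(pert_logits, dim=-1)
    unpert_probs = F.softmax(unpert_logits, dim=-1)

    # p_fused is proportional to p_pert**gm_scale * p_unpert**(1 - gm_scale)
    fused = pert_probs ** gm_scale * unpert_probs ** (1 - gm_scale)

    # zero everything below the k-th largest value (top_k_filter with
    # probs=True), then rescale so the distribution sums to 1
    kth = torch.topk(fused, top_k)[0][..., -1, None]
    fused = torch.where(fused < kth, torch.zeros_like(fused), fused)
    fused = fused / fused.sum(dim=-1, keepdim=True)

    # sample the next token id, as generate_text_pplm does with --sample
    return torch.multinomial(fused, num_samples=1)

# toy usage with random logits over GPT-2's 50257-token vocabulary
next_id = fuse_and_sample(torch.randn(1, 50257), torch.randn(1, 50257))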

The full code follows:

#! /usr/bin/env python3
# coding=utf-8
# Copyright 2018 The Uber AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example command with bag of words:
python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95

Example command with discriminator:
python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
"""

import argparse
import json
from operator import add
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import trange
from transformers import GPT2Tokenizer
from transformers.file_utils import cached_path
from transformers.modeling_gpt2 import GPT2LMHeadModel

from pplm_classification_head import ClassificationHead

PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
SMALL_CONST = 1e-15
BIG_CONST = 1e10

QUIET = 0
REGULAR = 1
VERBOSE = 2
VERY_VERBOSE = 3
VERBOSITY_LEVELS = {
    'quiet': QUIET,
    'regular': REGULAR,
    'verbose': VERBOSE,
    'very_verbose': VERY_VERBOSE,
}

BAG_OF_WORDS_ARCHIVE_MAP = {
    'legal': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
    'military': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
    'monsters': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/monsters.txt",
    'politics': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
    'positive_words': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/positive_words.txt",
    'religion': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
    'science': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
    'space': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
    'technology': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
}

DISCRIMINATOR_MODELS_PARAMS = {
    "clickbait": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifier_head.pt",
        "class_size": 2,
        "embed_size": 1024,
        "class_vocab": {"non_clickbait": 0, "clickbait": 1},
        "default_class": 1,
        "pretrained_model": "gpt2-medium",
    },
    "sentiment": {
        # "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt",
        'path': '/home/xps/huanghong/workdir/PPLM-master/cache/SST_classifier_head.pt',
        "class_size": 5,
        "embed_size": 1024,
        "class_vocab": {"very_positive": 2, "very_negative": 3},
        "default_class": 3,
        "pretrained_model": "gpt2-medium",
    },
}


def to_var(x, requires_grad=False, volatile=False, device='cuda'):
    if torch.cuda.is_available() and device == 'cuda':
        x = x.cuda()
    elif device != 'cuda':
        x = x.to(device)
    return Variable(x, requires_grad=requires_grad, volatile=volatile)


def top_k_filter(logits, k, probs=False):
    """
    Masks everything but the k top entries as -infinity (1e10).
    Used to mask logits such that e^-infinity -> 0 won't contribute to the
    sum of the denominator.
    """
    if k == 0:
        return logits
    else:
        values = torch.topk(logits, k)[0]
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        if probs:
            return torch.where(logits < batch_mins,
                               torch.ones_like(logits) * 0.0, logits)
        return torch.where(logits < batch_mins,
                           torch.ones_like(logits) * -BIG_CONST, logits)


def perturb_past(past, model, last, unpert_past=None, unpert_logits=None,
                 accumulated_hidden=None, grad_norms=None, stepsize=0.01,
                 one_hot_bows_vectors=None, classifier=None, class_label=None,
                 loss_type=0, num_iterations=3, horizon_length=1,
                 window_length=0, decay=False, gamma=1.5, kl_scale=0.01,
                 device='cuda', verbosity_level=REGULAR):
    # Generate initial perturbed past: same shape as the stacked k/v of the
    # 24 blocks, e.g. (2, 1, 16, 2, 64) each
    grad_accumulator = [(np.zeros(p.shape).astype("float32")) for p in past]

    if accumulated_hidden is None:
        accumulated_hidden = 0

    if decay:
        decay_mask = torch.arange(0., 1.0 + SMALL_CONST,
                                  1.0 / (window_length))[1:]
    else:
        decay_mask = 1.0

    # TODO fix this comment (SUMANTH)
    # Generate a mask is gradient perturbated is based on a past window
    _, _, _, curr_length, _ = past[0].shape

    if curr_length > window_length and window_length > 0:
        ones_key_val_shape = (tuple(past[0].shape[:-2])
                              + tuple([window_length])
                              + tuple(past[0].shape[-1:]))
        zeros_key_val_shape = (tuple(past[0].shape[:-2])
                               + tuple([curr_length - window_length])
                               + tuple(past[0].shape[-1:]))

        ones_mask = torch.ones(ones_key_val_shape)
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)),
                                dim=-2).to(device)
    else:
        window_mask = torch.ones_like(past[0]).to(device)  # (2, 1, 16, 2, 64)

    # accumulate perturbations over num_iterations (three by default)
    loss_per_iter = []
    new_accumulated_hidden = None
    for i in range(num_iterations):
        if verbosity_level >= VERBOSE:
            print("Iteration ", i + 1)
        curr_perturbation = [
            to_var(torch.from_numpy(p_), requires_grad=True, device=device)
            for p_ in grad_accumulator
        ]

        # Compute hidden using perturbed past
        perturbed_past = list(map(add, past, curr_perturbation))
        _, _, _, curr_length, _ = curr_perturbation[0].shape
        # with the past supplied, k becomes (1, 16, 64, 3) while q for the
        # new token stays (1, 16, 1, 64)
        all_logits, _, all_hidden = model(last, past=perturbed_past)
        hidden = all_hidden[-1]  # hidden states output by the last block
        # running sum of each token's last-layer hidden state
        new_accumulated_hidden = accumulated_hidden + torch.sum(
            hidden, dim=1).detach()
        # TODO: Check the layer-norm consistency of this with trained
        # discriminator (Sumanth)
        logits = all_logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)

        loss = 0.0
        loss_list = []
        if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
            for one_hot_bow in one_hot_bows_vectors:
                # (1, vocab_size) x (vocab_size, 149), like a linear layer
                bow_logits = torch.mm(probs, torch.t(one_hot_bow))
                # the 149 probabilities sum to less than 1, so the log is
                # negative; -log makes the loss positive
                bow_loss = -torch.log(torch.sum(bow_logits))
                loss += bow_loss
                loss_list.append(bow_loss)
            if verbosity_level >= VERY_VERBOSE:
                print(" pplm_bow_loss:", loss.data.cpu().numpy())

        if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM:
            ce_loss = torch.nn.CrossEntropyLoss()
            # TODO why we need to do this assignment and not just using
            # unpert_past? (Sumanth)
            # unpert_past was produced earlier from the full, unperturbed prompt
            curr_unpert_past = unpert_past
            curr_probs = torch.unsqueeze(probs, dim=1)
            wte = model.resize_token_embeddings()
            for _ in range(horizon_length):
                # (1, num_vocab) x (num_vocab, 1024): the current vocab
                # distribution, multiplied into the embedding table, yields a
                # probability-weighted "soft" token embedding
                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                _, curr_unpert_past, curr_all_hidden = model(
                    past=curr_unpert_past, inputs_embeds=inputs_embeds)
                curr_hidden = curr_all_hidden[-1]
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(
                    curr_hidden, dim=1)

            # TODO: why divide new_accumulated_hidden by
            # (curr_length + 1 + horizon_length)?
            prediction = classifier(new_accumulated_hidden /
                                    (curr_length + 1 + horizon_length))

            # labels for the separately trained classifier head
            label = torch.tensor(prediction.shape[0] * [class_label],
                                 device=device, dtype=torch.long)
            discrim_loss = ce_loss(prediction, label)
            if verbosity_level >= VERY_VERBOSE:
                print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
            loss += discrim_loss
            loss_list.append(discrim_loss)

        kl_loss = 0.0
        if kl_scale > 0.0:
            # unperturbed distribution over the vocab at the last position,
            # shape (1, 50257)
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
            # (unpert_probs <= SMALL_CONST) is a 0/1 mask; add SMALL_CONST
            # wherever a probability falls below 1e-15 so no vocab entry is
            # effectively zero (detach keeps requires_grad False)
            unpert_probs = (
                unpert_probs + SMALL_CONST *
                (unpert_probs <= SMALL_CONST).float().to(device).detach()
            )
            # probs is the distribution obtained from feeding the newest token
            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
                device).detach()
            corrected_probs = probs + correction.detach()
            # kl_scale * sum(b * log(b / a))
            kl_loss = kl_scale * (
                (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
            )
            if verbosity_level >= VERY_VERBOSE:
                print(' kl_loss', kl_loss.data.cpu().numpy())
            loss += kl_loss

        loss_per_iter.append(loss.data.cpu().numpy())
        if verbosity_level >= VERBOSE:
            print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())

        # compute gradients
        loss.backward()

        # calculate gradient norms (torch.norm defaults to the L2 norm)
        if grad_norms is not None and loss_type == PPLM_BOW:
            # keep the running max of each norm for normalization
            grad_norms = [
                torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                for index, p_ in enumerate(curr_perturbation)
            ]
        else:
            grad_norms = [
                (torch.norm(p_.grad * window_mask) + SMALL_CONST)
                for index, p_ in enumerate(curr_perturbation)
            ]

        # normalize gradients
        grad = [
            -stepsize *
            (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
            for index, p_ in enumerate(curr_perturbation)
        ]

        # accumulate gradient
        grad_accumulator = list(map(add, grad, grad_accumulator))

        # reset gradients, just to make sure
        for p_ in curr_perturbation:
            p_.grad.data.zero_()

        # removing past from the graph
        new_past = []
        for p_ in past:
            new_past.append(p_.detach())
        past = new_past

    # apply the accumulated perturbations to the past
    grad_accumulator = [
        to_var(torch.from_numpy(p_), requires_grad=True, device=device)
        for p_ in grad_accumulator
    ]
    pert_past = list(map(add, past, grad_accumulator))

    return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter


def get_classifier(name: Optional[str], class_label: Union[str, int],
                   device: str, verbosity_level: int = REGULAR
                   ) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    if name is None:
        return None, None

    params = DISCRIMINATOR_MODELS_PARAMS[name]
    classifier = ClassificationHead(class_size=params['class_size'],
                                    embed_size=params['embed_size']).to(device)
    if "url" in params:
        resolved_archive_file = cached_path(params["url"])
    elif "path" in params:
        resolved_archive_file = params["path"]
    else:
        raise ValueError("Either url or path have to be specified "
                         "in the discriminator model parameters")
    classifier.load_state_dict(
        torch.load(resolved_archive_file, map_location=device))
    classifier.eval()

    if isinstance(class_label, str):
        if class_label in params["class_vocab"]:
            label_id = params["class_vocab"][class_label]
        else:
            label_id = params["default_class"]
            if verbosity_level >= REGULAR:
                print("class_label {} not in class_vocab".format(class_label))
                print("available values are: {}".format(params["class_vocab"]))
                print("using default class {}".format(label_id))
    elif isinstance(class_label, int):
        if class_label in set(params["class_vocab"].values()):
            label_id = class_label
        else:
            label_id = params["default_class"]
            if verbosity_level >= REGULAR:
                print("class_label {} not in class_vocab".format(class_label))
                print("available values are: {}".format(params["class_vocab"]))
                print("using default class {}".format(label_id))
    else:
        label_id = params["default_class"]

    return classifier, label_id


def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str],
                             tokenizer) -> List[List[List[int]]]:
    bow_indices = []
    for id_or_path in bag_of_words_ids_or_paths:
        if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
            filepath = cached_path(BAG_OF_WORDS_ARCHIVE_MAP[id_or_path])
        else:
            filepath = id_or_path
        with open(filepath, "r") as f:
            words = f.read().strip().split("\n")
        # word -> token ids
        bow_indices.append(
            [tokenizer.encode(word.strip(), add_prefix_space=True,
                              add_special_tokens=False)
             for word in words])
    return bow_indices


def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'):
    if bow_indices is None:
        return None

    one_hot_bows_vectors = []
    for single_bow in bow_indices:
        # keep only BoW words that encode to a single token
        single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
        single_bow = torch.tensor(single_bow).to(device)
        num_words = single_bow.shape[0]
        # e.g. (149, 50257): one row per BoW word
        one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
        # scatter along dim=1: if word i has token id k, set row i, column k to 1
        one_hot_bow.scatter_(1, single_bow, 1)
        one_hot_bows_vectors.append(one_hot_bow)
    return one_hot_bows_vectors


def full_text_generation(model, tokenizer, context=None, num_samples=1,
                         device="cuda", bag_of_words=None, discrim=None,
                         class_label=None, length=100, stepsize=0.02,
                         temperature=1.0, top_k=10, sample=True,
                         num_iterations=3, grad_length=10000, horizon_length=1,
                         window_length=0, decay=False, gamma=1.5, gm_scale=0.9,
                         kl_scale=0.01, verbosity_level=REGULAR, **kwargs):
    classifier, class_id = get_classifier(discrim, class_label, device)

    bow_indices = []
    if bag_of_words:  # load bag of words
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
                                               tokenizer)

    if bag_of_words and classifier:
        loss_type = PPLM_BOW_DISCRIM
        if verbosity_level >= REGULAR:
            print("Both PPLM-BoW and PPLM-Discrim are on. "
                  "This is not optimized.")
    elif bag_of_words:
        loss_type = PPLM_BOW
        if verbosity_level >= REGULAR:
            print("Using PPLM-BoW")
    elif classifier is not None:
        loss_type = PPLM_DISCRIM
        if verbosity_level >= REGULAR:
            print("Using PPLM-Discrim")
    else:
        raise Exception("Specify either a bag of words or a discriminator")

    # baseline pass: unpert_gen_tok_text is the original prompt plus the
    # unperturbed continuation, kept for comparison
    unpert_gen_tok_text, _, _ = generate_text_pplm(
        model=model, tokenizer=tokenizer, context=context, device=device,
        length=length, sample=sample, perturb=False,
        verbosity_level=verbosity_level)
    if device == 'cuda':
        torch.cuda.empty_cache()

    pert_gen_tok_texts = []
    discrim_losses = []
    losses_in_time = []

    for i in range(num_samples):
        pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
            model=model, tokenizer=tokenizer, context=context, device=device,
            perturb=True, bow_indices=bow_indices, classifier=classifier,
            class_label=class_id, loss_type=loss_type, length=length,
            stepsize=stepsize, temperature=temperature, top_k=top_k,
            sample=sample, num_iterations=num_iterations,
            grad_length=grad_length, horizon_length=horizon_length,
            window_length=window_length, decay=decay, gamma=gamma,
            gm_scale=gm_scale, kl_scale=kl_scale,
            verbosity_level=verbosity_level)
        pert_gen_tok_texts.append(pert_gen_tok_text)
        if classifier is not None:
            discrim_losses.append(discrim_loss.data.cpu().numpy())
        losses_in_time.append(loss_in_time)

    if device == 'cuda':
        torch.cuda.empty_cache()

    return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time


def generate_text_pplm(model, tokenizer, context=None, past=None,
                       device="cuda", perturb=True, bow_indices=None,
                       classifier=None, class_label=None, loss_type=0,
                       length=100, stepsize=0.02, temperature=1.0, top_k=10,
                       sample=True, num_iterations=3, grad_length=10000,
                       horizon_length=1, window_length=0, decay=False,
                       gamma=1.5, gm_scale=0.9, kl_scale=0.01,
                       verbosity_level=REGULAR):
    output_so_far = None
    if context:
        context_t = torch.tensor(context, device=device, dtype=torch.long)
        while len(context_t.shape) < 2:
            context_t = context_t.unsqueeze(0)
        output_so_far = context_t

    # collect one hot vectors for bags of words, e.g. (149, 50257), one row
    # per BoW word
    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
                                                      device)

    grad_norms = None
    last = None
    unpert_discrim_loss = 0
    loss_in_time = []

    if verbosity_level >= VERBOSE:
        range_func = trange(length, ascii=True)
    else:
        range_func = range(length)

    for i in range_func:

        # Get past/probs for current output, except for last word
        # Note that GPT takes 2 inputs: past + current_token

        # run model forward to obtain unperturbed
        if past is None and output_so_far is not None:
            last = output_so_far[:, -1:]  # the last token in the context
            if output_so_far.shape[1] > 1:
                # feed all but the last token: the returned past is the
                # stacked k/v of the 24 blocks
                _, past, _ = model(output_so_far[:, :-1])

        # feed the full context: unpert_logits maps to the vocab, unpert_past
        # is the stacked k/v, unpert_all_hidden holds the per-block hidden states
        unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
        unpert_last_hidden = unpert_all_hidden[-1]  # output of the last block

        # check if we are above grad max length
        if i >= grad_length:
            current_stepsize = stepsize * 0  # past grad_length, stop perturbing
        else:
            current_stepsize = stepsize

        # modify the past if necessary
        if not perturb or num_iterations == 0:
            pert_past = past
        else:
            # only the first k-1 tokens, e.g. (1, 3, 1024) -> (1, 2, 1024)
            accumulated_hidden = unpert_last_hidden[:, :-1, :]
            # sum over the token dimension: (1, 2, 1024) -> (1, 1024)
            accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

            if past is not None:
                pert_past, _, grad_norms, loss_this_iter = perturb_past(
                    past,  # k/v built from all but the last token
                    model,
                    last,
                    unpert_past=unpert_past,  # k/v built from the full context
                    unpert_logits=unpert_logits,  # vocab distribution from the full context
                    accumulated_hidden=accumulated_hidden,  # summed hidden states of the first k-1 tokens
                    grad_norms=grad_norms,
                    stepsize=current_stepsize,
                    one_hot_bows_vectors=one_hot_bows_vectors,
                    classifier=classifier,
                    class_label=class_label,
                    loss_type=loss_type,
                    num_iterations=num_iterations,
                    horizon_length=horizon_length,
                    window_length=window_length,
                    decay=decay,
                    gamma=gamma,
                    kl_scale=kl_scale,
                    device=device,
                    verbosity_level=verbosity_level)
                loss_in_time.append(loss_this_iter)
            else:
                pert_past = past

        # pert_logits: `last` mapped through the 24 blocks to the vocab;
        # past: the updated k/v cache, now including `last`;
        # pert_all_hidden: per-layer hidden states
        pert_logits, past, pert_all_hidden = model(last, past=pert_past)
        pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST
        pert_probs = F.softmax(pert_logits, dim=-1)

        if classifier is not None:
            ce_loss = torch.nn.CrossEntropyLoss()
            prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
            label = torch.tensor([class_label], device=device,
                                 dtype=torch.long)
            unpert_discrim_loss = ce_loss(prediction, label)
            if verbosity_level >= VERBOSE:
                print("unperturbed discrim loss",
                      unpert_discrim_loss.data.cpu().numpy())
        else:
            unpert_discrim_loss = 0

        # Fuse the modified model and original model
        if perturb:
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)

            pert_probs = ((pert_probs ** gm_scale) *
                          (unpert_probs ** (1 - gm_scale)))  # + SMALL_CONST
            pert_probs = top_k_filter(pert_probs, k=top_k,
                                      probs=True)  # + SMALL_CONST

            # rescale
            if torch.sum(pert_probs) <= 1:
                pert_probs = pert_probs / torch.sum(pert_probs)
        else:
            # everything below the k-th largest logit becomes -inf
            pert_logits = top_k_filter(pert_logits, k=top_k)  # + SMALL_CONST
            pert_probs = F.softmax(pert_logits, dim=-1)

        # sample or greedy
        if sample:
            # sample the next token id from the fused distribution
            last = torch.multinomial(pert_probs, num_samples=1)
        else:
            _, last = torch.topk(pert_probs, k=1, dim=-1)

        # update context/output_so_far appending the new token
        output_so_far = (
            last if output_so_far is None
            else torch.cat((output_so_far, last), dim=1)
        )
        print(tokenizer.decode(last))
        if tokenizer.decode(last) == '.':
            break
        if verbosity_level >= REGULAR:
            print(tokenizer.decode(output_so_far.tolist()[0]))

    return output_so_far, unpert_discrim_loss, loss_in_time


def set_generic_model_params(discrim_weights, discrim_meta):
    if discrim_weights is None:
        raise ValueError('When using a generic discriminator, '
                         'discrim_weights need to be specified')
    if discrim_meta is None:
        raise ValueError('When using a generic discriminator, '
                         'discrim_meta need to be specified')

    with open(discrim_meta, 'r') as discrim_meta_file:
        meta = json.load(discrim_meta_file)
    meta['path'] = discrim_weights
    DISCRIMINATOR_MODELS_PARAMS['generic'] = meta


def run_pplm_example(pretrained_model="gpt2-medium", cond_text="",
                     uncond=False, num_samples=1, bag_of_words=None,
                     discrim=None, discrim_weights=None, discrim_meta=None,
                     class_label=-1, length=100, stepsize=0.02,
                     temperature=1.0, top_k=10, sample=True, num_iterations=3,
                     grad_length=10000, horizon_length=1, window_length=0,
                     decay=False, gamma=1.5, gm_scale=0.9, kl_scale=0.01,
                     seed=0, no_cuda=False, colorama=False,
                     verbosity='regular'):
    # set Random seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set verbosity (how much gets printed)
    verbosity_level = VERBOSITY_LEVELS.get(verbosity.lower(), REGULAR)

    # set the device
    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"

    if discrim == 'generic':
        set_generic_model_params(discrim_weights, discrim_meta)

    if discrim is not None:
        discriminator_pretrained_model = \
            DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
        if pretrained_model != discriminator_pretrained_model:
            pretrained_model = discriminator_pretrained_model
            if verbosity_level >= REGULAR:
                print("discrim = {}, pretrained_model set "
                      "to discriminator's = {}".format(discrim,
                                                       pretrained_model))

    # load pretrained model
    model = GPT2LMHeadModel.from_pretrained(
        '/home/xps/huanghong/workdir/PPLM-master/cache/',  # pretrained_model,
        output_hidden_states=True)
    model.to(device)
    model.eval()  # inference only

    # load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(
        '/home/xps/huanghong/workdir/PPLM-master/cache/')

    # Freeze GPT-2 weights
    for param in model.parameters():
        param.requires_grad = False

    # figure out conditioning text
    if uncond:
        tokenized_cond_text = tokenizer.encode([tokenizer.bos_token],
                                               add_special_tokens=False)
    else:
        raw_text = cond_text
        while not raw_text:
            print("Did you forget to add `--cond_text`? ")
            raw_text = input("Model prompt >>> ")
        tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + raw_text,
                                               add_special_tokens=False)

    print("= Prefix of sentence =")
    print(tokenizer.decode(tokenized_cond_text))
    print()

    # generate unperturbed and perturbed texts

    # full_text_generation returns:
    # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
    unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
        model=model, tokenizer=tokenizer, context=tokenized_cond_text,
        device=device, num_samples=num_samples, bag_of_words=bag_of_words,
        discrim=discrim, class_label=class_label, length=length,
        stepsize=stepsize, temperature=temperature, top_k=top_k,
        sample=sample, num_iterations=num_iterations, grad_length=grad_length,
        horizon_length=horizon_length, window_length=window_length,
        decay=decay, gamma=gamma, gm_scale=gm_scale, kl_scale=kl_scale,
        verbosity_level=verbosity_level)

    # untokenize unperturbed text
    unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])

    if verbosity_level >= REGULAR:
        print("=" * 80)
        print("= Unperturbed generated text =")
        print(unpert_gen_text)
        print()

    generated_texts = []

    bow_word_ids = set()
    if bag_of_words and colorama:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
                                               tokenizer)
        for single_bow_list in bow_indices:
            # filter out words in the list composed of more than 1 token
            filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
            # w[0] because we are sure w has only 1 item due to the previous filter
            bow_word_ids.update(w[0] for w in filtered)

    # iterate through the perturbed texts
    for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
        try:
            # untokenize perturbed text
            if colorama:
                import colorama
                pert_gen_text = ''
                for word_id in pert_gen_tok_text.tolist()[0]:
                    if word_id in bow_word_ids:
                        pert_gen_text += '{}{}{}'.format(
                            colorama.Fore.RED,  # Fore sets the text color
                            tokenizer.decode([word_id]),
                            # Style resets formatting (Back would set the
                            # background color)
                            colorama.Style.RESET_ALL
                        )
                    else:
                        pert_gen_text += tokenizer.decode([word_id])
            else:
                pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])

            print("= Perturbed generated text {} =".format(i + 1))
            print(pert_gen_text)
            print()
        except:
            pass

        # keep the prefix, perturbed seq, original seq for each index
        generated_texts.append(
            (tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
        )

    return


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model", "-M", type=str,
                        default="gpt2-medium",
                        help="pretrained model name or path to local checkpoint")
    parser.add_argument("--cond_text", type=str, default="The lake",
                        help="Prefix texts to condition on")
    parser.add_argument("--uncond", action="store_true",
                        help="Generate from end-of-text as prefix")
    parser.add_argument("--num_samples", type=int, default=1,
                        help="Number of samples to generate from the modified latents")
    parser.add_argument("--bag_of_words", "-B", type=str, default=None,
                        help="Bags of words used for PPLM-BoW. "
                             "Either a BOW id (see list in code) or a filepath. "
                             "Multiple BoWs separated by ;")
    parser.add_argument("--discrim", "-D", type=str, default=None,
                        choices=("clickbait", "sentiment", "toxicity",
                                 "generic"),
                        help="Discriminator to use")
    parser.add_argument('--discrim_weights', type=str, default=None,
                        help='Weights for the generic discriminator')
    parser.add_argument('--discrim_meta', type=str, default=None,
                        help='Meta information for the generic discriminator')
    parser.add_argument("--class_label", type=int, default=-1,
                        help="Class label used for the discriminator")
    parser.add_argument("--length", type=int, default=100)
    parser.add_argument("--stepsize", type=float, default=0.02)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument("--sample", action="store_true",
                        help="Sample from the distribution instead of greedy decoding")
    parser.add_argument("--num_iterations", type=int, default=3)
    parser.add_argument("--grad_length", type=int, default=10000)
    parser.add_argument("--window_length", type=int, default=0,
                        help="Length of past which is being optimized; "
                             "0 corresponds to infinite window length")
    parser.add_argument("--horizon_length", type=int, default=1,
                        help="Length of future to optimize over")
    parser.add_argument("--decay", action="store_true",
                        help="whether to decay or not")
    parser.add_argument("--gamma", type=float, default=1.5)
    parser.add_argument("--gm_scale", type=float, default=0.9)
    parser.add_argument("--kl_scale", type=float, default=0.01)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
    parser.add_argument("--colorama", action="store_true",
                        help="colors keywords")
    parser.add_argument("--verbosity", type=str, default="very_verbose",
                        choices=("quiet", "regular", "verbose",
                                 "very_verbose"),
                        help="verbosity level")

    args = parser.parse_args()
    run_pplm_example(**vars(args))  # vars() turns the Namespace into a dict of kwargs
