资讯详情

Contrastive Search Decoding——一种对比搜索解码文本生成算法

二、代码实现理解和实验

1、代码走读

2.生成效果显示

3.方案缺陷


最近在做文本生成相关任务的时候,刷了一篇文本生成的论文:

《A Contrastive Framework for Neural Text Generation》

它认为GPT2生成模型再生成token它有不同的异向性,所以token它们之间的相似性非常接近,没有很好的区别。最终解码导致文本重复——text degeneration;因此,论文提出了一种新的训练策略(SimCTG) 解码算法(contrastive search),在多语言任务和实际工业场景中进行人工评显著提高了文本生成的质量。论文提出的text degeneration知乎上有很多大佬和论文作者讨论分析,最后得出结论text degeneration原因不是SIMCTG提出的Contrastive Training,它不能保证表征的同质性,文本生成质量(少无意义重复)之所以有真正的提高,完全来自于新提出的解码策略——contrastive search decoding。既然解码策略这么有效,就要好好学习。

一、contrastive search decoding

这是一种非topK、topP以及BeamSearch解码策略,感觉很有意思。它的核心思想是比较-比较当前要生成的token和已生成的一切token计算相似度,获得最大的相似度值;然后使它token概率与最大相似度值的差值最大化token这就是我们想要的token;具体公式如下:

V(k)是指token在模型输出的分布中top_k最有可能的结果是,通常在论文中设置K值3~10.看完公式,我觉得思想很简单,突然明白了公式要表达的思想,但还是有几个值得注意的地方:

1.如何有效地获得当前?token的embedding,也就是hv;以及如何获得h1,...ht-(已生成token的embedding)

2.如何高效计算当前?token的embedding以前所有文本embedding最大相似度值

3.如何计算整体最大值?V(k)最佳的v

当问题1已经解决时,2和3的问题更容易解决,直接使用矩阵计算GPU并行计算可以很好地解决计算效率问题;第一个问题有点难理解,不熟悉GPT对于2模型的人来说,真的不容易理解。读完源代码,和作者沟通,再加上对GPT在对生成过程的理解之后,我们完全理解如何寻求它hv的。

contrastive search decoding一般来说,解码过程如上图所示,当前轮文本输入gpt2模型,使用hm生成新的k个候选人tokens;然后把这些tokens将之前的文本拼接到下一轮模型中,以获得hm 1。这里的hm 1是上一轮应该生成的token的embedding,选择最好的解码公式hm 1也就得到了tm 1-当前轮最好的token。按照上述流程,可以生成句子。

二、二码实现理解和实验

1、代码走读

简单分析了以上核心思想。让我们来看看如何实现具体的代码使用。首先实现整体代码,然后慢慢分析:

def contrastive_search_decode(curr_input_tensor,attention_mask,tokenizer):     """     比较搜索解码策略     """     alpha = 0.5     beam_width = 5     generated = [item for item in curr_input_tensor.tolist()]     past_key_values = None      max_length = 64   curr_input_tensor.shape[1]     stop = False      with torch.no_grad():         for index in range(max_length):             if index == 0:                 inputs = prepare_inputs_for_generation(curr_input_tensor, attention_mask, past=past_key_values)                 output = model(**inputs,return_dict = True,use_cache=True,output_hidden_states=True)                 past_key_values = output.past_key_values                 last_hidden_states = output.hidden_states[-1]  # [B, S, E]                 logit_for_next_step = output.logits[:, -1, :]  # [B, V]              bsz, seqlen, embed_dim = last_hidden_states.size()              next_probs = F.softmax(logit_for_next_step, dim=-1)             _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) # [B, K]             top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)  # [B, K]              # compute new hidden             past_key_values = enlarge_past_key_values(past_key_values, beam_width)             output = model(                 input_ids=top_k_ids.view(-1, 1),                 attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),                 past_key_values=past_key_values,                 output_hidden_states=True,                 use_cache=True,             )             # past_key_values是一个二维list;里层list元素是tensor             past_key_values = output.past_key_values             logits = output.logits[:, -1, :]  # [B*K, V]             next_hidden = output.hidden_states[-1]  # [B*K, 1, E]             context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz * beam_width,seqlen,embed_dim)  # [B*K, S, E]              selected_idx = ranking_fast(                 context_hidden,                 next_hidden,                 top_k_probs,  # [B, K]                 alpha,                 beam_width,             )  # [B]              # prepare for the next step             next_id= top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)  # [B, 1]
            temp = torch.split(next_hidden.squeeze(dim=1), beam_width)
            next_hidden = torch.stack(temp)  # [B, K, E]
            next_hidden = next_hidden[range(bsz), selected_idx, :]  # [B, E]
            last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)  # [B, S+1, E]
            past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
            temp = torch.split(logits, beam_width)
            logit_for_next_step = torch.stack(temp)[range(bsz), selected_idx, :]  # [B, V]

            tokens = next_id.squeeze(dim=-1).tolist()
            for idx, t in enumerate(tokens):
                generated[idx].append(t)

            for token in tokens:
                if token == 102:
                    stop = True
                    break
            if stop:
                break

    res = tokenizer.batch_decode(generated, skip_special_tokens=True)

说说几个细节

a、past_key_values扩充和压缩

由于每次需要传入past_key_values加快模型的推理速度,并且要在top_k中得到最佳的那个token,因此需要把K个token都要纳入计算中,为了能够矩阵计算需要把每次输入都扩充K倍:

past_key_values扩充

def enlarge_past_key_values(past_key_values, beam_width):
    # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]
    new_key_values = []
    for layer in past_key_values:
        items = []
        for item in layer:
            # item is the key and value matrix
            bsz, num_head, seq_len, esz = item.size()
            item = item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz*beam_width, num_head, seq_len, esz)    # [bsz*beam, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values

past_key_values中每个tensor的维度变化[B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]

past_key_values压缩

def select_past_key_values(past_key_values, beam_width, selected_idx):
    '''select_idx: [B]'''
    new_key_values = []
    for layer in past_key_values:
        items = []
        for item in layer:
            bsz_and_beam, num_head, seq_len, esz = item.size()
            bsz = int(bsz_and_beam//beam_width)
            temp = torch.split(item, beam_width, dim=0)
            item = torch.stack(temp)    # [B, K, num_head, seq_len, esz]
            item = item[range(bsz), selected_idx, :, :, :]   # [B, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values

past_key_values中每个tensor的维度从[B*K, num_head, seq_len, esz]变回到[B, num_head, seq_len, esz]

b、当前token和之前所有token的相似度并行计算

def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width):
    '''
        context_hidden: bsz*beam x seqlen x embed_dim
        next_hidden: bsz*beam x 1 x embed_dim
        next_top_k_probs: bsz x beam
    '''
    _, context_len, embed_dim = context_hidden.size()
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)    # [B*K, S]
    scores, _ = torch.max(cosine_matrix, dim=-1)    # [B*K]
    next_top_k_probs = next_top_k_probs.view(-1)    # [B*K]
    scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
    temp = torch.split(scores, beam_width)
    scores = torch.stack(temp)    # [B, K]
    selected_idx = scores.max(dim=-1)[1]    # [B]
    return selected_idx

需要注意到这里的torch.matmul()的计算

context_hidden:[B*K,S,D]

next_hidden:[B*K,1,D]

需要计算batch中每一条数据(每个token的embedding)和之前所有token的embedding的cos相似度

torch.matmul([B*K,S,D],B*K,1,D].T(2,1))=torch.matmul([B*K,S,D],B*K,D,1])=[B*K,S,1]

然后再求最大的那个score的index即可

2、生成效果展示

 

 生成的语句还是比较流畅的,重复性得到改善,逻辑性这个是模型本身的问题;但是具体比之前采用beamsearch + sample效果具体能好多少,这边我没有做太多的验证,需要上线使用机器人聊一段时间才知道,不过beamsearch + sample在实际使用的时候就算加上了重复惩罚系数,生成的时候也会有部分重复的,生成例子:

现在财务下班了,财务下班了,明天下午到账

不是,我们不是一个公司的,不是一个公司的

好的,那我给您改一下。那我这边给您改一下

[让我看看][让我看看][让我看看][让我看看]

代理点:506经办200019经办200019经办200019经办

2000块钱,2000块钱,2000块,2000块钱,20002000块钱,2000200020

真实的contrastive search decoding效果,还有待观察,不过目前简单的测试几条来看生成还可以。

3、方案的缺陷

一般而言,我们都要求生成的句子具有多样性——有不同的生成,contrastive search decoding是一个确定性方案,每次只能生成固定的结果。这里作者有提出一个比较合适的方法:

就是先使用beamsearch + sample等方法生成部分句子,然后再使用contrastive search decoding对生成的句子进行补齐。

具体的实现不是特别困难,这里就不实现了。

还有一种方法,实现上比较麻烦,我也提一下思想:就是那个公式中选择v的时候,不选最大的那一个,多选择几个,但是要小于K值。

公式中的argmax 换成 top_n,n取2、3、4这种比K/2小的值感觉比较合适。

参考文章:

如何评价剑桥,腾讯, DeepMind以及港大团队新作 SimCTG ? - 王琰的回答 - 知乎

2022 - A Contrastive Framework for Neural Text Generation

标签: bsz808a振动传感器变送器

锐单商城拥有海量元器件数据手册IC替代型号,打造 电子元器件IC百科大全!

 锐单商城 - 一站式电子元器件采购平台  

 深圳锐单电子有限公司