【自然语言处理】【Pytorch】从头实现SimCSE_pytorch simcse-程序员宅基地

技术标签: pytorch  自然语言处理  对比学习  BERT  文本表示  

相关博客:
【自然语言处理】【对比学习】SimCSE:基于对比学习的句向量表示
【自然语言处理】BERT-Whitening
【自然语言处理】【Pytorch】从头实现SimCSE
【自然语言处理】【向量检索】面向开发域稠密检索的多视角文档表示学习
【自然语言处理】【向量表示】AugSBERT:改善用于成对句子评分任务的Bi-Encoders的数据增强方法
【自然语言处理】【向量表示】PairSupCon:用于句子表示的成对监督对比学习

github:

Concise_SimCSE

博客内容:基于pytorch和transformers,从头开始实现SimCSE
import torch
import torch.nn as nn

from abc import ABC
from tqdm.notebook import tqdm
from dataclasses import dataclass, field
from typing import List, Union, Optional, Dict
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, TrainingArguments, Trainer
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions

一、定义参数

@dataclass
class DataArguments:
    train_file: str = field(default="./data/simcse/wiki1m_for_simcse.txt",
                            metadata={
    "help": "The path of train file"})
    model_name_or_path: str = field(default="E:/pretrained/bert-base-uncased",
                                    metadata={
    "help": "The name or path of pre-trained language model"})
    max_seq_length: int = field(default=32,
                                metadata={
    "help": "The maximum total input sequence length after tokenization."})


training_args = TrainingArguments(
        output_dir="./checkpoints",
        num_train_epochs=1,
        per_device_train_batch_size=64,
        learning_rate=3e-5,
        load_best_model_at_end=True,
        overwrite_output_dir=True,
        do_train=True,
        do_eval=False,
        logging_steps=10)


data_args = DataArguments()

二、读取数据

# 初始化tokenizer
tokenizer = BertTokenizer.from_pretrained(data_args.model_name_or_path)
# 读取训练数据
with open(data_args.train_file, encoding="utf8") as file:
    texts = [line.strip() for line in tqdm(file.readlines())]
print(type(texts))
print(texts[0])
<class 'list'>
YMCA in South Australia

三、构建Dataset和collate_fn

3.1 构建Dataset

class PairDataset(Dataset):
    def __init__(self, examples: List[str]):
        total = len(examples)
        # 将所有样本复制一份用于对比学习
        sentences_pair = examples + examples
        sent_features = tokenizer(sentences_pair,
                                  max_length=data_args.max_seq_length,
                                  truncation=True,
                                  padding=False)
        features = {
    }
        # 将相同的样本放在同一个列表中
        for key in sent_features:
            features[key] = [[sent_features[key][i], sent_features[key][i + total]] for i in tqdm(range(total))]
        self.input_ids = features["input_ids"]
        self.attention_mask = features["attention_mask"]
        self.token_type_ids = features["token_type_ids"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, item):
        return {
    
            "input_ids": self.input_ids[item],
            "attention_mask": self.attention_mask[item],
            "token_type_ids": self.token_type_ids[item]
        }

train_dataset = PairDataset(texts)
print(train_dataset[0])
{'input_ids': [[101, 26866, 1999, 2148, 2660, 102], [101, 26866, 1999, 2148, 2660, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]}

3.2 构建collate_fn

@dataclass
class DataCollator:
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        special_keys = ['input_ids', 'attention_mask', 'token_type_ids']
        batch_size = len(features)
        if batch_size == 0:
            return
        # flat_features: [sen1, sen1, sen2, sen2, ...]
        flat_features = []
        for feature in features:
            for i in range(2):
                flat_features.append({
    k: feature[k][i] for k in feature.keys() if k in special_keys})
        # padding
        batch = self.tokenizer.pad(
            flat_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # batch_size, 2, seq_len
        batch = {
    k: batch[k].view(batch_size, 2, -1) for k in batch if k in special_keys}
        return batch

collate_fn = DataCollator(tokenizer)
dataloader = DataLoader(train_dataset, batch_size=4, collate_fn=collate_fn)
batch = next(iter(dataloader))
print(batch.keys())
print(batch["input_ids"].shape)
dict_keys(['input_ids', 'attention_mask', 'token_type_ids'])
torch.Size([4, 2, 32])

四、构建模型

# 全连接层,用于投影CLS的向量表示
class MLPLayer(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.dense = nn.Linear(input_size, output_size)
        self.activation = nn.Tanh()

    def forward(self, features, **kwargs):
        x = self.dense(features)
        x = self.activation(x)
        return x

# 相似度层,计算向量间相似度
class Similarity(nn.Module):
    def __init__(self, temp):
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x, y):
        return self.cos(x, y) / self.temp

    
# SimCSE的完整模型结构
class BertForCL(BertPreTrainedModel, ABC):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.mlp = MLPLayer(config.hidden_size, config.hidden_size)
        self.sim = Similarity(temp=0.05)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                sent_emb=False):
        if sent_emb:
            # 模型推断时使用的forward
            return self.sentemb_forward(input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask,
                                        inputs_embeds=inputs_embeds,
                                        labels=labels,
                                        output_attentions=output_attentions,
                                        output_hidden_states=output_hidden_states,
                                        return_dict=return_dict)
        else:
            # 模型训练时使用的forward
            return self.cl_forward(input_ids=input_ids,
                                   attention_mask=attention_mask,
                                   token_type_ids=token_type_ids,
                                   position_ids=position_ids,
                                   head_mask=head_mask,
                                   inputs_embeds=inputs_embeds,
                                   labels=labels,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states,
                                   return_dict=return_dict)

    def sentemb_forward(self,
                        input_ids=None,
                        attention_mask=None,
                        token_type_ids=None,
                        position_ids=None,
                        head_mask=None,
                        inputs_embeds=None,
                        labels=None,
                        output_attentions=None,
                        output_hidden_states=None,
                        return_dict=None):
        # 1.使用bert进行编码
        outputs = self.bert(input_ids, attention_mask=attention_mask, return_dict=True)
        # 2.取cls的表示
        cls_output = outputs.last_hidden_state[:, 0]
        # 3.使用MLP进行投影
        cls_output = self.mlp(cls_output)
        # 返回
        if not return_dict:
            return (outputs[0], cls_output) + outputs[2:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            pooler_output=cls_output,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
        )

    def cl_forward(self,
                   input_ids=None,
                   attention_mask=None,
                   token_type_ids=None,
                   position_ids=None,
                   head_mask=None,
                   inputs_embeds=None,
                   labels=None,
                   output_attentions=None,
                   output_hidden_states=None,
                   return_dict=None):
        # input_ids: batch_size, num_sent, len
        batch_size = input_ids.size(0)
        num_sent = input_ids.size(1)  # 2
        # 1. 重塑输入张量的形状,使其满足bert对输入的要求
        # input_ids: batch_size * num_sent, len
        input_ids = input_ids.view((-1, input_ids.size(-1)))
        attention_mask = attention_mask.view((-1, attention_mask.size(-1)))
        # 2. 使用bert进行编码
        outputs = self.bert(input_ids, attention_mask=attention_mask, return_dict=True)
        # 3. 取cls的向量表示
        cls_output = outputs.last_hidden_state[:, 0]
        # 4. 重塑形状
        cls_output = cls_output.view((batch_size, num_sent, cls_output.size(-1)))
        # 5. 全连接层投影
        # batch_size, num_sent, 768
        cls_output = self.mlp(cls_output)
        # 6. 将同一批样本的两次向量表示分开
        z1, z2 = cls_output[:, 0], cls_output[:, 1]
        # 7. 计算两两相似度,得到相似度矩阵cos_sim
        cos_sim = self.sim(z1.unsqueeze(1), z2.unsqueeze(0))
        # 8. 生成标签[0,1,...,batch_size-1],该标签用于提高相似度句子cos_sim对角线,并降低非对角线
        labels = torch.arange(cos_sim.size(0)).long().to(self.device)
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(cos_sim, labels)

        if not return_dict:
            output = (cos_sim,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=cos_sim,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

model = BertForCL.from_pretrained(data_args.model_name_or_path)
cl_out = model(**batch, return_dict=True)
print(cl_out.keys())
odict_keys(['loss', 'logits'])

五、模型训练

model.resize_token_embeddings(len(tokenizer))
trainer = Trainer(model=model,
                  train_dataset=train_dataset,
                  args=training_args,
                  tokenizer=tokenizer,
                  data_collator=collate_fn)
trainer.train()
trainer.save_model("models/test")
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/bqw18744018044/article/details/119336466

智能推荐

c# 调用c++ lib静态库_c#调用lib-程序员宅基地

文章浏览阅读2w次,点赞7次,收藏51次。四个步骤1.创建C++ Win32项目动态库dll 2.在Win32项目动态库中添加 外部依赖项 lib头文件和lib库3.导出C接口4.c#调用c++动态库开始你的表演...①创建一个空白的解决方案,在解决方案中添加 Visual C++ , Win32 项目空白解决方案的创建:添加Visual C++ , Win32 项目这......_c#调用lib

deepin/ubuntu安装苹方字体-程序员宅基地

文章浏览阅读4.6k次。苹方字体是苹果系统上的黑体,挺好看的。注重颜值的网站都会使用,例如知乎:font-family: -apple-system, BlinkMacSystemFont, Helvetica Neue, PingFang SC, Microsoft YaHei, Source Han Sans SC, Noto Sans CJK SC, W..._ubuntu pingfang

html表单常见操作汇总_html表单的处理程序有那些-程序员宅基地

文章浏览阅读159次。表单表单概述表单标签表单域按钮控件demo表单标签表单标签基本语法结构<form action="处理数据程序的url地址“ method=”get|post“ name="表单名称”></form><!--action,当提交表单时,向何处发送表单中的数据,地址可以是相对地址也可以是绝对地址--><!--method将表单中的数据传送给服务器处理,get方式直接显示在url地址中,数据可以被缓存,且长度有限制;而post方式数据隐藏传输,_html表单的处理程序有那些

PHP设置谷歌验证器(Google Authenticator)实现操作二步验证_php otp 验证器-程序员宅基地

文章浏览阅读1.2k次。使用说明:开启Google的登陆二步验证(即Google Authenticator服务)后用户登陆时需要输入额外由手机客户端生成的一次性密码。实现Google Authenticator功能需要服务器端和客户端的支持。服务器端负责密钥的生成、验证一次性密码是否正确。客户端记录密钥后生成一次性密码。下载谷歌验证类库文件放到项目合适位置(我这边放在项目Vender下面)https://github.com/PHPGangsta/GoogleAuthenticatorPHP代码示例://引入谷_php otp 验证器

【Python】matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距-程序员宅基地

文章浏览阅读4.3k次,点赞5次,收藏11次。matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距

docker — 容器存储_docker 保存容器-程序员宅基地

文章浏览阅读2.2k次。①Storage driver 处理各镜像层及容器层的处理细节,实现了多层数据的堆叠,为用户 提供了多层数据合并后的统一视图②所有 Storage driver 都使用可堆叠图像层和写时复制(CoW)策略③docker info 命令可查看当系统上的 storage driver主要用于测试目的,不建议用于生成环境。_docker 保存容器

随便推点

网络拓扑结构_网络拓扑csdn-程序员宅基地

文章浏览阅读834次,点赞27次,收藏13次。网络拓扑结构是指计算机网络中各组件(如计算机、服务器、打印机、路由器、交换机等设备)及其连接线路在物理布局或逻辑构型上的排列形式。这种布局不仅描述了设备间的实际物理连接方式,也决定了数据在网络中流动的路径和方式。不同的网络拓扑结构影响着网络的性能、可靠性、可扩展性及管理维护的难易程度。_网络拓扑csdn

JS重写Date函数,兼容IOS系统_date.prototype 将所有 ios-程序员宅基地

文章浏览阅读1.8k次,点赞5次,收藏8次。IOS系统Date的坑要创建一个指定时间的new Date对象时,通常的做法是:new Date("2020-09-21 11:11:00")这行代码在 PC 端和安卓端都是正常的,而在 iOS 端则会提示 Invalid Date 无效日期。在IOS年月日中间的横岗许换成斜杠,也就是new Date("2020/09/21 11:11:00")通常为了兼容IOS的这个坑,需要做一些额外的特殊处理,笔者在开发的时候经常会忘了兼容IOS系统。所以就想试着重写Date函数,一劳永逸,避免每次ne_date.prototype 将所有 ios

如何将EXCEL表导入plsql数据库中-程序员宅基地

文章浏览阅读5.3k次。方法一:用PLSQL Developer工具。 1 在PLSQL Developer的sql window里输入select * from test for update; 2 按F8执行 3 打开锁, 再按一下加号. 鼠标点到第一列的列头,使全列成选中状态,然后粘贴,最后commit提交即可。(前提..._excel导入pl/sql

Git常用命令速查手册-程序员宅基地

文章浏览阅读83次。Git常用命令速查手册1、初始化仓库git init2、将文件添加到仓库git add 文件名 # 将工作区的某个文件添加到暂存区 git add -u # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,不处理untracked的文件git add -A # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,包括untracked的文件...

分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120-程序员宅基地

文章浏览阅读202次。分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120

【C++缺省函数】 空类默认产生的6个类成员函数_空类默认产生哪些类成员函数-程序员宅基地

文章浏览阅读1.8k次。版权声明:转载请注明出处 http://blog.csdn.net/irean_lau。目录(?)[+]1、缺省构造函数。2、缺省拷贝构造函数。3、 缺省析构函数。4、缺省赋值运算符。5、缺省取址运算符。6、 缺省取址运算符 const。[cpp] view plain copy_空类默认产生哪些类成员函数

推荐文章

热门文章

相关标签