infersent :学习通用句子表示
该存储库包含用于监督NLI任务的Pytorch实现和实验界面,该任务具有不同的模型,用于学习通用句子表示。
结果
使用SNLI数据对NLI任务培训了基线模型MeanEmbedding和三个基于LSTM的模型LSTM , BiLSTM和BiLSTM-maxpool 。使用Senteval框架对8个转移任务进行评估句子嵌入。
按照infersent论文的第5节的定义,计算了Senteval任务的微标准和宏观度量[1]。结果在下面列出:
| 模型 | snli-dev | SNLI测试 | Senteval-Micro | Senteval-Macro |
|---|---|---|---|---|
| 含义 | 69.5 | 69.1 | 77.31 | 77.92 |
| LSTM | 80.5 | 80.2 | 70.467 | 70.282 |
| 比尔斯特 | 80.00 | 80.08 | 71.997 | 71.531 |
| Bilstm-Maxpool | 86.50 | 85.87 | 79.075 | 78.831 |
组织
该存储库被组织成以下主要组成部分:
-
models.py编码器和分类器模型的Pytorch模块。 -
data.pySNLIData类,用于准备用于培训和评估的数据。 -
train.py与不同编码器进行培训的Pytorch Lightning模型和培训CLI。 -
eval.pyCLI,使用模型检查点并在SNLI和SenteVal任务上运行评估。 -
demo.ipynb用于测试模型推理的Jupyter笔记本和分析结果。
设置
# Using pip pip install -r requirements.txt # Using conda conda env create -f environment.yml # Download english model for SpaCy tokenizer python -m spacy download en_core_web_sm
要使用Senteval进行评估,请按以下方式准备SenteVal安装:
git clone https://*github.**com/facebookresearch/SentEval.git cd SentEval/ && python setup.py install # Download datasets cd SentEval/data/downstream/ && ./get_transfer_data.bash
训练
运行train.py具有以下编码器类型之一: MeanEmbedding , LSTM , BiLSTM , BiLSTM-maxpool 。培训过程将创建./logs目录中的模型检查点,张板日志和HyperParams文件hparams.yaml 。
python train.py --encoder_type= \' BiLSTM \'
评估
使用模型检查点标志运行eval.py ,以在SNLI和SenteVal上运行评估任务。
python eval.py --checkpoint_path= \' ./logs/MeanEmbedding/version_0/checkpoints/epoch=2-step=12875.ckpt \'
预训练的模型
模型检查点和张板日志是公开的,可以在此处找到:https://drive.google.com/drive/folders/1ebjyf0wj31ezmpebig1nhw-1jomml1iy?usp = sharing
参考
[1] A. Conneau,D。Kiela,H。Schwenk,L。Barrault,A。Bordes,监督从自然语言推理数据中学习通用句子表示的
