时间:2021-05-23
Embedding
词嵌入在 pytorch 中非常简单,只需要调用 torch.nn.Embedding(m, n) 就可以了,m 表示单词的总数目,n 表示词嵌入的维度,其实词嵌入就相当于是一个大矩阵,矩阵的每一行表示一个单词。
emdedding初始化
默认是随机初始化的
import torchfrom torch import nnfrom torch.autograd import Variable# 定义词嵌入embeds = nn.Embedding(2, 5) # 2 个单词,维度 5# 得到词嵌入矩阵,开始是随机初始化的torch.manual_seed(1)embeds.weight# 输出结果:Parameter containing:-0.8923 -0.0583 -0.1955 -0.9656 0.4224 0.2673 -0.4212 -0.5107 -1.5727 -0.1232[torch.FloatTensor of size 2x5]如果从使用已经训练好的词向量,则采用
pretrained_weight = np.array(args.pretrained_weight) # 已有词向量的numpyself.embed.weight.data.copy_(torch.from_numpy(pretrained_weight))embed的读取
读取一个向量。
注意参数只能是LongTensor型的
# 访问第 50 个词的词向量embeds = nn.Embedding(100, 10)embeds(Variable(torch.LongTensor([50])))# 输出:Variable containing: 0.6353 1.0526 1.2452 -1.8745 -0.1069 0.1979 0.4298 -0.3652 -0.7078 0.2642[torch.FloatTensor of size 1x10]读取多个向量。
输入为两个维度(batch的大小,每个batch的单词个数),输出则在两个维度上加上词向量的大小。
Input: LongTensor (N, W), N = mini-batch, W = number of indices to extract per mini-batchOutput: (N, W, embedding_dim)见代码
# an Embedding module containing 10 tensors of size 3embedding = nn.Embedding(10, 3)# 每批取两组,每组四个单词input = Variable(torch.LongTensor([[1,2,4,5],[4,3,2,9]]))a = embedding(input) # 输出2*4*3a[0],a[1]输出为:
(Variable containing: -1.2603 0.4337 0.4181 0.4458 -0.1987 0.4971 -0.5783 1.3640 0.7588 0.4956 -0.2379 -0.7678 [torch.FloatTensor of size 4x3], Variable containing: -0.5783 1.3640 0.7588 -0.5313 -0.3886 -0.6110 0.4458 -0.1987 0.4971 -1.3768 1.7323 0.4816 [torch.FloatTensor of size 4x3])以上这篇pytorch中的embedding词向量的使用方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。
如何在pytorch中使用word2vec训练好的词向量torch.nn.Embedding()这个方法是在pytorch中将词向量和词对应起来的一个方法.一般
最近在工作中进行了NLP的内容,使用的还是Keras中embedding的词嵌入来做的。Keras中embedding层做一下介绍。中文文档地址:https:/
在读取https://github.com/Embedding/Chinese-Word-Vectors中的中文词向量时,选择了一个有3G多的txt文件,之前在
C++中list的使用方法及常用list操作总结一、List定义:List是stl实现的双向链表,与向量(vectors)相比,它允许快速的插入和删除,但是随机
Pytorch提取模型特征向量#-*-coding:utf-8-*-"""dj"""importtorchimporttorch.nnasnnimportosf