时间:2021-05-22
pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据。如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口。幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口。
torch.utils.data
torch的这个文件包含了一些关于数据集处理的类。
class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。
class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。
class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。
class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。
torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。
class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 __iter__ 方法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。
class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。
class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。
class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。
自定义数据集
自己定义的数据集需要继承抽象类class torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__。
整个代码仅供参考。在__init__中是初始化了该类的一些基本参数;__getitem__中是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可;__len__给出了整个数据集的尺寸大小,迭代器的索引范围是根据这个函数得来的。
import torchclass myDataset(torch.nn.data.Dataset): def __init__(self, dataSource) self.dataSource = dataSource def __getitem__(self, index): element = self.dataSource[index] return element def __len__(self): return len(self.dataSource)train_data = myDataset(dataSource)自定义数据集加载器
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。
dataset (Dataset) – 需要加载的数据集(可以是自定义或者自带的数据集)。
batch_size – batch的大小(可选项,默认值为1)。
shuffle – 是否在每个epoch中shuffle整个数据集, 默认值为False。
sampler – 定义从数据中抽取样本的策略. 如果指定了, shuffle参数必须为False。
num_workers – 表示读取样本的线程数, 0表示只有主线程。
collate_fn – 合并一个样本列表称为一个batch。
pin_memory – 是否在返回数据之前将张量拷贝到CUDA。
drop_last (bool, optional) – 设置是否丢弃最后一个不完整的batch,默认为False。
timeout – 用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。应该为非负整数。
train_loader=torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)以上这篇pytorch 自定义数据集加载方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。
关于Pytorch中怎么自定义Dataset数据集类、怎样使用DataLoader迭代加载数据,这篇官方文档已经说得很清楚了,这里就不在赘述。现在的问题:有的时
自定义数据集在训练深度学习模型之前,样本集的制作非常重要。在pytorch中,提供了一些接口和类,方便我们定义自己的数据集合,下面完整的试验自定义样本集的整个流
思路自定义方法,使用Wrapper,自定义映射结果集Mapper接口packagecom.mozq.boot.mpsand01.dao;importcom.ba
手写一个通用加载中、显示数据、加载失败、空数据的LoadingView框架。定义3个布局:加载中,加载失败,空数据加载中:加载失败:空数据:自定义一个Loadi
自定义数据类型表.版本2.数据类型消息类型.成员键盘消息,文本型自定义数据类型使用代码.版本2.程序集窗口程序集1.子程序_按钮1_被单击.局部变量接收返回,消