pytorch数据读取

pytorch数据读取机制:

sampler生成索引index,根据索引从DataSet中获取图片和标签

1.torch.utils.data.DataLoader

功能:构建可迭代的数据装在器

dataset:Dataset类,决定数据从哪读取及如何读取

batchsize:批大小

num_works:是否多进程读取数据,当条件允许时,多进程读取数据会加快数据读取速度。

shuffle:每个epoch是否乱序

drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

DataLoader(dataset, batchsize=1, shuffle=False, batch_sampler=None, num_workers=0, collate_fn=None, pin_memeory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

epoch:所有训练样本都已输入到模型中,称为一个epoch

iteration:一批样本输入到模型中,称为一个iteration

batchsize:批大小,决定一个epoch有多少个iteration

例如:

样本总数:80, batchsize:8

1epoch = 10 iteraion

样本总数:87, batchsize:8

1 epoch = 10 iteration drop_last=True

1 epoch = 11 iteration drop_last=False

2.torch.utils.data.Dataset

功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写

__getitem__()

getitem:接收一个索引,返回一个样本

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

人命币分类实例:

数据分割:

import os
import random
import shutil

def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)

if __name__ == ‘__main__‘:

    random.seed(1)

    dataset_dir = os.path.join("..", "..", "data", "RMB_data")
    split_dir = os.path.join("..", "..", "data", "rmb_split")
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")

    train_pct = 0.8
    valid_pct = 0.1
    test_pct = 0.1

    for root, dirs, files in os.walk(dataset_dir):
        for sub_dir in dirs:

            imgs = os.listdir(os.path.join(root, sub_dir))
            imgs = list(filter(lambda x: x.endswith(‘.jpg‘), imgs))
            random.shuffle(imgs)
            img_count = len(imgs)

            train_point = int(img_count * train_pct)
            valid_point = int(img_count * (train_pct + valid_pct))

            for i in range(img_count):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sub_dir)
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sub_dir)
                else:
                    out_dir = os.path.join(test_dir, sub_dir)

                makedir(out_dir)

                target_path = os.path.join(out_dir, imgs[i])
                src_path = os.path.join(dataset_dir, sub_dir, imgs[i])

                shutil.copy(src_path, target_path)

            print(‘Class:{}, train:{}, valid:{}, test:{}‘.format(sub_dir, train_point, valid_point-train_point,
                                                                 img_count-valid_point))

创建Dataset

import os
import random
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)
rmb_label = {"1": 0, "100": 1}

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert(‘RGB‘)     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith(‘.jpg‘), img_names))

                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info

原文地址:https://www.cnblogs.com/haiboxiaobai/p/11749379.html

时间: 2024-10-29 06:51:40

pytorch数据读取的相关文章

DataReader对象(数据读取)

DataReader对象提供了一个只进只读的数据读取器,用于从查询结果中读取数据,它每次仅能读取一行数据. [常用属性]: FieldCount:获取当前行的列数: HasRows:表明查询结果中是否还存在未被读取的数据. [常用方法]: Close:关闭SqlDataReader对象: GetName:获取指定列的名称; Read:使SqlDataReader前进到下一条记录. [使用DataReader对象对数据库进行查询操作步骤]: 1.创建Connection对象: 2.打开数据库连接:

sas数据读取详解 四种读取数据方式以及数据指针的位置 、读取mess data的两个小工具、特殊的读取技巧、infile语句及其选项(dsd dlm missover truncover obs firstobs)、proc import、自定义缺失值

(The record length is the number of characters, including spaces, in a data line.) If your data lines are long, and it looks like SAS is not reading all your data, then use the LRECL= option in the INFILE statement to specify a record length at least

转载---CGImageSource对图像数据读取任务的抽象

转载地址:http://www.tanhao.me/pieces/1019.html CGImageSource是对图像数据读取任务的抽象,通过它可以获得图像对象.缩略图.图像的属性(包括Exif信息). 1.创建CGImageSourceRef 1 2 NSString *imagePath = [[NSBundle bundleForClass:self.class] pathForImageResource:@"test.png"]; CGImageSourceRef image

T31P电子秤数据读取

连接串口后先发送"CP\r\n"激活电子秤数据发送,收到的数据包是17字节的 using System; using System.Collections.Generic; using System.Linq; using System.Text; namespace DotNet.ElecScales { using System.IO.Ports; using System.Text; using System.Threading; /// <summary> ///

工大助手--数据读取

工大助手--数据读取 实现功能 1)用户可选择获取入学以来所有已修课程的相关信息:课程代号.课程名.课程属性.学分.成绩等信息. 2)用户可选择获取特定已修课程的相关信息:课程代号.课程名.课程属性.学分.成绩等信息. 3)用户可获得特定时间段内的加权平均分(1学期.1学年.全部). 团队成员 13070003 张   帆 13070046 孙宇辰 13070004 崔   巍 13070006 王   奈 13070002 张雨帆 13070045 汪天米 数据读入 在上次博客中,我讲到了我所

android SharedPreferences简单应用 插入数据 读取数据

package com.sharedpreference; import java.text.SimpleDateFormat; import java.util.Date; import android.os.Bundle; import android.app.Activity; import android.content.SharedPreferences; import android.view.Menu; import android.view.View; import androi

10 张图帮你搞定 TensorFlow 数据读取机制

导读 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下tensorflow的数据读取机制,文章的最后还会给出实战代码以供参考. 一.tensorflow读取机制图解 首先需要思考的一个问题是,什么是数据读取?以图像数据为例,读取数据的过程可以用下图来表示: 假设我们的硬盘中有一个图片数据集0001.jpg,0002.jpg,0003.jpg--我

ajax数据读取和绑定

如何进行ajax数据读取和绑定呢? 首先创建一个AJAX对象 实现数据绑定 实现隔行变色 编写表格排序的方法(实现按照年龄这一列进行排序) 通过文档碎片,把排序后的最新顺序,重新添加到tBody中(通过文档碎片,有效的避免了回流耗性能的问题,浏览器不用每当HTML结果发生改变,就重新对当前的页面进行渲染) 1 <!DOCTYPE html> 2 <html lang="en"> 3 <head> 4 <meta charset="UT

asp.net mvc4 razor视图 (之) 数据读取

@Html.Raw 或者直接访问,使用 Model属性. 参考这里:http://techo.luefher.com/coding/dot-net/mvc/how-to-access-your-model-data-in-net-mvc-with-razor-engine-for-beginners/ lambda表达式,类似这样: 数据从controller到view,因此如果要初始化,也是在controller里面进行,如下: asp.net mvc4 razor视图 (之) 数据读取