Pycharm(十七)生成器

发布于:2025-05-01 ⋅ 阅读:(35) ⋅ 点赞:(0)

一、生成器介绍

1.1 概述

生成器指的是Generator对象,它不再像以往一样,一次性生成所有的数据,而是用一个,再生成一个,基于用户写的规则(条件)来生成数据,如果条件不成立,则生成结束。

1.2 实现方式

        方式1:推导式;

        方式2:yield关键字实现。

1.3 名词解释:迭代

        迭代指的是 逐个的从容器类型中获取每一个元素的过程,称之为:迭代(遍历)

        例如:列表,集合,字典,生成器等,都是可以遍历(迭代)的,所以它们也称之为:可迭代对象。

1.4 目的/好处

        节约内存资源,减少内存占用。

1.5 如何从生成器中获取数据?

        方式1:next()函数

                next()函数是移动指针的,获取下一个元素。

        方式2:遍历

1.6 示例代码

#案例:演示生成器推导式写法,获取生成器对象。
if __name__ == '__main__':
    #1.生成器写法1:推导式写法
    my_generator=(i for i in range(1,10))
    print(f'my_generator:{my_generator}')
    print(f'my_generator的类型:{type(my_generator)}')
    print('-'*30)
    #2.生成器不是一下生成所有的数据,二十用一个再生成1个
    #如何从生成器中获取数据?
    #1.next()函数   2.for循环遍历
    #方式1.next()函数
    #next()是移动指针的,获取下一个元素
    print(next(my_generator))#1     range(1,10)  1->2
    print(next(my_generator))#2     2->3
    print('-'*30)
    for i in my_generator:#
        print(i)
运行结果:

二、yield关键字介绍

概述:

        yield可以创建生成器对象,逐个地把每个元素放到生成器对象中,函数结束时,返回生成器对象。

示例代码:
 

'''
yield写法示例
'''
#需求:获取1~10之间的整数,生成器写法。
#1.定义函数,获取:生成器对象
def get_generator():
    #yield写法,返回的是生成器对象
    for i in range(1,11):
        yield i
#2.测试上述代码
if __name__ == '__main__':
    #3.调用函数,获取生成器对象
    my_generator=get_generator()
    print(type(my_generator))
    #4.从生成器对象中,获取数据
    #方式1.next()函数
    print(next(my_generator))# 1
    print(next(my_generator))# 2
    print('-'*30)
    # 方式2.遍历
    for i in my_generator:
        print(i)
运行结果:

三、生成器案例 

案例

        自定义数据迭代器,按照指定的条数生成批次数据。在AI模型的训练过程中,是把数据分批次喂给模型的,而不是一次性投喂。

需求:
自定义数据迭代器(dataloader),实现根据指定的数据条数,获取每批次的数据。

示例代码:

import math
#math.ceil()函数,获取天花板数,即:比这个数字大的所有整数中,最小的那个整数
print(math.ceil(5)) #5
print(math.ceil(5.1))#6
print(math.ceil(100/8))#13
def dataset_loader(batch_size):
    '''
    自定义函数,获取批次数据
    :param batch_size: 每批次数据的条数
    :return: 生成器对象,每个数据=1批的数据
    '''
    #1.读取源文件,获取到所有的数据
    with open('./jaychou_lyrics.txt','r',encoding='utf-8') as src_f:
        #一次性读取所有的行,并放到列表里.
        list_data=src_f.readlines()
    #2.获取数据的总条数.
    line_count=len(list_data)
    #3.根据数据的总条数,结合每批次的数据条数,计算:总批次数.
    batch_count=math.ceil(line_count/batch_size)
    #4,遍历 总批次数,获取到:每个批次的 编号,然后生成:该批次的数据.
    for batch_idx in range(batch_count):
        '''
        推理过程:
            假设batch_size=8,batch_count=13,即:13批,8条/批,则:
            batch_idx=0,代表第1批数据,数据为:第1条~第8条,[0:8]
            batch_idx=1,代表第2批数据,数据为:第9条~第16条,[8:16]
            batch_idx=2,代表第3批数据,数据为:第17~24条,[16:24]
            ...
        '''
        yield list_data[batch_idx*batch_size:(batch_idx+1)*batch_size]
if __name__ == '__main__':
    #5.获取生成器对象.
    data_loader=dataset_loader(batch_size=8)
    #6.获取第1批次的数据.
    batch_data1=next(data_loader)
    for line in batch_data1:
        print(line,end='')
    print('-'*30)
    print(next(data_loader))
运行结果:

下期继续分享~