一、生成器介绍
1.1 概述
生成器指的是Generator对象,它不再像以往一样,一次性生成所有的数据,而是用一个,再生成一个,基于用户写的规则(条件)来生成数据,如果条件不成立,则生成结束。
1.2 实现方式
方式1:推导式;
方式2:yield关键字实现。
1.3 名词解释:迭代
迭代指的是 逐个的从容器类型中获取每一个元素的过程,称之为:迭代(遍历)
例如:列表,集合,字典,生成器等,都是可以遍历(迭代)的,所以它们也称之为:可迭代对象。
1.4 目的/好处
节约内存资源,减少内存占用。
1.5 如何从生成器中获取数据?
方式1:next()函数
next()函数是移动指针的,获取下一个元素。
方式2:遍历
1.6 示例代码
#案例:演示生成器推导式写法,获取生成器对象。
if __name__ == '__main__':
#1.生成器写法1:推导式写法
my_generator=(i for i in range(1,10))
print(f'my_generator:{my_generator}')
print(f'my_generator的类型:{type(my_generator)}')
print('-'*30)
#2.生成器不是一下生成所有的数据,二十用一个再生成1个
#如何从生成器中获取数据?
#1.next()函数 2.for循环遍历
#方式1.next()函数
#next()是移动指针的,获取下一个元素
print(next(my_generator))#1 range(1,10) 1->2
print(next(my_generator))#2 2->3
print('-'*30)
for i in my_generator:#
print(i)
运行结果:
二、yield关键字介绍
概述:
yield可以创建生成器对象,逐个地把每个元素放到生成器对象中,函数结束时,返回生成器对象。
示例代码:
'''
yield写法示例
'''
#需求:获取1~10之间的整数,生成器写法。
#1.定义函数,获取:生成器对象
def get_generator():
#yield写法,返回的是生成器对象
for i in range(1,11):
yield i
#2.测试上述代码
if __name__ == '__main__':
#3.调用函数,获取生成器对象
my_generator=get_generator()
print(type(my_generator))
#4.从生成器对象中,获取数据
#方式1.next()函数
print(next(my_generator))# 1
print(next(my_generator))# 2
print('-'*30)
# 方式2.遍历
for i in my_generator:
print(i)
运行结果:
三、生成器案例
案例
自定义数据迭代器,按照指定的条数生成批次数据。在AI模型的训练过程中,是把数据分批次喂给模型的,而不是一次性投喂。
需求:
自定义数据迭代器(dataloader),实现根据指定的数据条数,获取每批次的数据。
示例代码:
import math
#math.ceil()函数,获取天花板数,即:比这个数字大的所有整数中,最小的那个整数
print(math.ceil(5)) #5
print(math.ceil(5.1))#6
print(math.ceil(100/8))#13
def dataset_loader(batch_size):
'''
自定义函数,获取批次数据
:param batch_size: 每批次数据的条数
:return: 生成器对象,每个数据=1批的数据
'''
#1.读取源文件,获取到所有的数据
with open('./jaychou_lyrics.txt','r',encoding='utf-8') as src_f:
#一次性读取所有的行,并放到列表里.
list_data=src_f.readlines()
#2.获取数据的总条数.
line_count=len(list_data)
#3.根据数据的总条数,结合每批次的数据条数,计算:总批次数.
batch_count=math.ceil(line_count/batch_size)
#4,遍历 总批次数,获取到:每个批次的 编号,然后生成:该批次的数据.
for batch_idx in range(batch_count):
'''
推理过程:
假设batch_size=8,batch_count=13,即:13批,8条/批,则:
batch_idx=0,代表第1批数据,数据为:第1条~第8条,[0:8]
batch_idx=1,代表第2批数据,数据为:第9条~第16条,[8:16]
batch_idx=2,代表第3批数据,数据为:第17~24条,[16:24]
...
'''
yield list_data[batch_idx*batch_size:(batch_idx+1)*batch_size]
if __name__ == '__main__':
#5.获取生成器对象.
data_loader=dataset_loader(batch_size=8)
#6.获取第1批次的数据.
batch_data1=next(data_loader)
for line in batch_data1:
print(line,end='')
print('-'*30)
print(next(data_loader))
运行结果:
下期继续分享~