生成器¶

while 循环通常有这样的形式:

  1. <do setup>
  2. result = []
  3. while True:
  4. <generate value>
  5. result.append(value)
  6. if <done>:
  7. break

使用迭代器实现这样的循环:

  1. class GenericIterator(object):
  2. def __init__(self, ...):
  3. <do setup>
  4. # 需要额外储存状态
  5. <store state>
  6. def next(self):
  7. <load state>
  8. <generate value>
  9. if <done>:
  10. raise StopIteration()
  11. <store state>
  12. return value

更简单的,可以使用生成器:

  1. def generator(...):
  2. <do setup>
  3. while True:
  4. <generate value>
  5. # yield 说明这个函数可以返回多个值!
  6. yield value
  7. if <done>:
  8. break

生成器使用 yield 关键字将值输出,而迭代器则通过 nextreturn 将值返回;与迭代器不同的是,生成器会自动记录当前的状态,而迭代器则需要进行额外的操作来记录当前的状态。

对于之前的 collatz 猜想,简单循环的实现如下:

In [1]:

  1. def collatz(n):
  2. sequence = []
  3. while n != 1:
  4. if n % 2 == 0:
  5. n /= 2
  6. else:
  7. n = 3*n + 1
  8. sequence.append(n)
  9. return sequence
  10.  
  11. for x in collatz(7):
  12. print x,
  1. 22 11 34 17 52 26 13 40 20 10 5 16 8 4 2 1

迭代器的版本如下:

In [2]:

  1. class Collatz(object):
  2. def __init__(self, start):
  3. self.value = start
  4.  
  5. def __iter__(self):
  6. return self
  7.  
  8. def next(self):
  9. if self.value == 1:
  10. raise StopIteration()
  11. elif self.value % 2 == 0:
  12. self.value = self.value/2
  13. else:
  14. self.value = 3*self.value + 1
  15. return self.value
  16.  
  17. for x in Collatz(7):
  18. print x,
  1. 22 11 34 17 52 26 13 40 20 10 5 16 8 4 2 1

生成器的版本如下:

In [3]:

  1. def collatz(n):
  2. while n != 1:
  3. if n % 2 == 0:
  4. n /= 2
  5. else:
  6. n = 3*n + 1
  7. yield n
  8.  
  9. for x in collatz(7):
  10. print x,
  1. 22 11 34 17 52 26 13 40 20 10 5 16 8 4 2 1

事实上,生成器也是一种迭代器:

In [4]:

  1. x = collatz(7)
  2. print x
  1. <generator object collatz at 0x0000000003B63750>

它支持 next 方法,返回下一个 yield 的值:

In [5]:

  1. print x.next()
  2. print x.next()
  1. 22
  2. 11

iter 方法返回的是它本身:

In [6]:

  1. print x.__iter__()
  1. <generator object collatz at 0x0000000003B63750>

之前的二叉树迭代器可以改写为更简单的生成器模式来进行中序遍历:

In [7]:

  1. class BinaryTree(object):
  2. def __init__(self, value, left=None, right=None):
  3. self.value = value
  4. self.left = left
  5. self.right = right
  6.  
  7. def __iter__(self):
  8. # 将迭代器设为生成器方法
  9. return self.inorder()
  10.  
  11. def inorder(self):
  12. # traverse the left branch
  13. if self.left is not None:
  14. for value in self.left:
  15. yield value
  16.  
  17. # yield node's value
  18. yield self.value
  19.  
  20. # traverse the right branch
  21. if self.right is not None:
  22. for value in self.right:
  23. yield value

非递归的实现:

In [9]:

  1. def inorder(self):
  2. node = self
  3. stack = []
  4. while len(stack) > 0 or node is not None:
  5. while node is not None:
  6. stack.append(node)
  7. node = node.left
  8. node = stack.pop()
  9. yield node.value
  10. node = node.right

In [10]:

  1. tree = BinaryTree(
  2. left=BinaryTree(
  3. left=BinaryTree(1),
  4. value=2,
  5. right=BinaryTree(
  6. left=BinaryTree(3),
  7. value=4,
  8. right=BinaryTree(5)
  9. ),
  10. ),
  11. value=6,
  12. right=BinaryTree(
  13. value=7,
  14. right=BinaryTree(8)
  15. )
  16. )
  17. for value in tree:
  18. print value,
  1. 1 2 3 4 5 6 7 8

原文: https://nbviewer.jupyter.org/github/lijin-THU/notes-python/blob/master/05-advanced-python/05.10-generators.ipynb