itertools —- 为高效循环而创建迭代器的函数


本模块实现一系列 iterator ,这些迭代器受到APL,Haskell和SML的启发。为了适用于Python,它们都被重新写过。

本模块标准化了一个快速、高效利用内存的核心工具集,这些工具本身或组合都很有用。它们一起形成了“迭代器代数”,这使得在纯Python中有可能创建简洁又高效的专用工具。

例如,SML有一个制表工具: tabulate(f),它可产生一个序列 f(0), f(1), …。在Python中可以组合 map()count() 实现: map(f, count())

这些内置工具同时也能很好地与 operator 模块中的高效函数配合使用。例如,我们可以将两个向量的点积映射到乘法运算符: sum(map(operator.mul, vector1, vector2))

无穷迭代器:

迭代器实参结果示例
count()start, [step]start, start+step, start+2*step, …count(10) —> 10 11 12 13 14 …
cycle()pp0, p1, … plast, p0, p1, …cycle('ABCD') —> A B C D A B C D …
repeat()elem [,n]elem, elem, elem, … 重复无限次或n次repeat(10, 3) —> 10 10 10

根据最短输入序列长度停止的迭代器:

迭代器实参结果示例
accumulate()p [,func]p0, p0+p1, p0+p1+p2, …accumulate([1,2,3,4,5]) —> 1 3 6 10 15
chain()p, q, …p0, p1, … plast, q0, q1, …chain('ABC', 'DEF') —> A B C D E F
chain.from_iterable()iterablep0, p1, … plast, q0, q1, …chain.from_iterable(['ABC', 'DEF']) —> A B C D E F
compress()data, selectors(d[0] if s[0]), (d[1] if s[1]), …compress('ABCDEF', [1,0,1,0,1,1]) —> A C E F
dropwhile()pred, seqseq[n], seq[n+1], … 从pred首次真值测试失败开始dropwhile(lambda x: x<5, [1,4,6,4,1]) —> 6 4 1
filterfalse()pred, seqseq中pred(x)为假值的元素,x是seq中的元素。filterfalse(lambda x: x%2, range(10)) —> 0 2 4 6 8
groupby()iterable[, key]根据key(v)值分组的迭代器
islice()seq, [start,] stop [, step]seq[start:stop:step]中的元素islice('ABCDEFG', 2, None) —> C D E F G
starmap()func, seqfunc(seq[0]), func(seq[1]), …starmap(pow, [(2,5), (3,2), (10,3)]) —> 32 9 1000
takewhile()pred, seqseq[0], seq[1], …, 直到pred真值测试失败takewhile(lambda x: x<5, [1,4,6,4,1]) —> 1 4
tee()it, nit1, it2, … itn 将一个迭代器拆分为n个迭代器
zip_longest()p, q, …(p[0], q[0]), (p[1], q[1]), …zip_longest('ABCD', 'xy', fillvalue='-') —> Ax By C- D-

排列组合迭代器:

迭代器实参结果
product()p, q, … [repeat=1]笛卡尔积,相当于嵌套的for循环
permutations()p[, r]长度r元组,所有可能的排列,无重复元素
combinations()p, r长度r元组,有序,无重复元素
combinations_with_replacement()p, r长度r元组,有序,元素可重复
例子结果
product('ABCD', repeat=2)AA AB AC AD BA BB BC BD CA CB CC CD DA DB DC DD
permutations('ABCD', 2)AB AC AD BA BC BD CA CB CD DA DB DC
combinations('ABCD', 2)AB AC AD BC BD CD
combinations_with_replacement('ABCD',2)AA AB AC AD BB BC BD CC CD DD

Itertool函数

下列模块函数均创建并返回迭代器。有些迭代器不限制输出流长度,所以它们只应在能截断输出流的函数或循环中使用。

  • itertools.accumulate(iterable[, func, *, initial=None])
  • 创建一个迭代器,返回累积汇总值或其他双目运算函数的累积结果值(通过可选的 func 参数指定)。

如果提供了 func,它应当为带有两个参数的函数。 输入 iterable 的元素可以是能被 func 接受为参数的任意类型。 (例如,对于默认的加法运算,元素可以是任何可相加的类型包括 DecimalFraction。)

通常,输出的元素数量与输入的可迭代对象是一致的。 但是,如果提供了关键字参数 initial,则累加会以 initial 值开始,这样输出就比输入的可迭代对象多一个元素。

大致相当于:

  1. def accumulate(iterable, func=operator.add, *, initial=None):
  2. 'Return running totals'
  3. # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
  4. # accumulate([1,2,3,4,5], initial=100) --> 100 101 103 106 110 115
  5. # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
  6. it = iter(iterable)
  7. total = initial
  8. if initial is None:
  9. try:
  10. total = next(it)
  11. except StopIteration:
  12. return
  13. yield total
  14. for element in it:
  15. total = func(total, element)
  16. yield total

func 参数有几种用法。它可以被设为 min() 最终得到一个最小值,或者设为 max() 最终得到一个最大值,或设为 operator.mul() 最终得到一个乘积。摊销表可通过累加利息和支付款项得到。给iterable设置初始值并只将参数 func 设为累加总数可以对一阶 递归关系 建模。

  1. >>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
  2. >>> list(accumulate(data, operator.mul)) # running product
  3. [3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]
  4. >>> list(accumulate(data, max)) # running maximum
  5. [3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
  6.  
  7. # Amortize a 5% loan of 1000 with 4 annual payments of 90
  8. >>> cashflows = [1000, -90, -90, -90, -90]
  9. >>> list(accumulate(cashflows, lambda bal, pmt: bal*1.05 + pmt))
  10. [1000, 960.0, 918.0, 873.9000000000001, 827.5950000000001]
  11.  
  12. # Chaotic recurrence relation https://en.wikipedia.org/wiki/Logistic_map
  13. >>> logistic_map = lambda x, _: r * x * (1 - x)
  14. >>> r = 3.8
  15. >>> x0 = 0.4
  16. >>> inputs = repeat(x0, 36) # only the initial value is used
  17. >>> [format(x, '.2f') for x in accumulate(inputs, logistic_map)]
  18. ['0.40', '0.91', '0.30', '0.81', '0.60', '0.92', '0.29', '0.79', '0.63',
  19. '0.88', '0.39', '0.90', '0.33', '0.84', '0.52', '0.95', '0.18', '0.57',
  20. '0.93', '0.25', '0.71', '0.79', '0.63', '0.88', '0.39', '0.91', '0.32',
  21. '0.83', '0.54', '0.95', '0.20', '0.60', '0.91', '0.30', '0.80', '0.60']

参考一个类似函数 functools.reduce() ,它只返回一个最终累积值。

3.2 新版功能.

在 3.3 版更改: 增加可选参数 func

在 3.8 版更改: 添加了可选的 initial 形参。

  • itertools.chain(*iterables)
  • 创建一个迭代器,它首先返回第一个可迭代对象中所有元素,接着返回下一个可迭代对象中所有元素,直到耗尽所有可迭代对象中的元素。可将多个序列处理为单个序列。大致相当于:
  1. def chain(*iterables):
  2. # chain('ABC', 'DEF') --> A B C D E F
  3. for it in iterables:
  4. for element in it:
  5. yield element
  • classmethod chain.fromiterable(_iterable)
  • 构建类似 chain() 迭代器的另一个选择。从一个单独的可迭代参数中得到链式输入,该参数是延迟计算的。大致相当于:
  1. def from_iterable(iterables):
  2. # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
  3. for it in iterables:
  4. for element in it:
  5. yield element
  • itertools.combinations(iterable, r)
  • 返回由输入 iterable 中元素组成长度为 r 的子序列。

组合按照字典序返回。所以如果输入 iterable 是有序的,生成的组合元组也是有序的。

即使元素的值相同,不同位置的元素也被认为是不同的。如果元素各自不同,那么每个组合中没有重复元素。

大致相当于:

  1. def combinations(iterable, r):
  2. # combinations('ABCD', 2) --> AB AC AD BC BD CD
  3. # combinations(range(4), 3) --> 012 013 023 123
  4. pool = tuple(iterable)
  5. n = len(pool)
  6. if r > n:
  7. return
  8. indices = list(range(r))
  9. yield tuple(pool[i] for i in indices)
  10. while True:
  11. for i in reversed(range(r)):
  12. if indices[i] != i + n - r:
  13. break
  14. else:
  15. return
  16. indices[i] += 1
  17. for j in range(i+1, r):
  18. indices[j] = indices[j-1] + 1
  19. yield tuple(pool[i] for i in indices)

combinations() 的代码可被改写为 permutations() 过滤后的子序列,(相对于元素在输入中的位置)元素不是有序的。

  1. def combinations(iterable, r):
  2. pool = tuple(iterable)
  3. n = len(pool)
  4. for indices in permutations(range(n), r):
  5. if sorted(indices) == list(indices):
  6. yield tuple(pool[i] for i in indices)

0 <= r <= n 时,返回项的个数是 n! / r! / (n-r)!;当 r > n 时,返回项个数为0。

  • itertools.combinationswith_replacement(_iterable, r)
  • 返回由输入 iterable 中元素组成的长度为 r 的子序列,允许每个元素可重复出现。

组合按照字典序返回。所以如果输入 iterable 是有序的,生成的组合元组也是有序的。

不同位置的元素是不同的,即使它们的值相同。因此如果输入中的元素都是不同的话,返回的组合中元素也都会不同。

大致相当于:

  1. def combinations_with_replacement(iterable, r):
  2. # combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC
  3. pool = tuple(iterable)
  4. n = len(pool)
  5. if not n and r:
  6. return
  7. indices = [0] * r
  8. yield tuple(pool[i] for i in indices)
  9. while True:
  10. for i in reversed(range(r)):
  11. if indices[i] != n - 1:
  12. break
  13. else:
  14. return
  15. indices[i:] = [indices[i] + 1] * (r - i)
  16. yield tuple(pool[i] for i in indices)

combinations_with_replacement() 的代码可被改写为 production() 过滤后的子序列,(相对于元素在输入中的位置)元素不是有序的。

  1. def combinations_with_replacement(iterable, r):
  2. pool = tuple(iterable)
  3. n = len(pool)
  4. for indices in product(range(n), repeat=r):
  5. if sorted(indices) == list(indices):
  6. yield tuple(pool[i] for i in indices)

n > 0 时,返回项个数为 (n+r-1)! / r! / (n-1)!.

3.1 新版功能.

  • itertools.compress(data, selectors)
  • 创建一个迭代器,它返回 data 中经 selectors 真值测试为 True 的元素。迭代器在两者较短的长度处停止。大致相当于:
  1. def compress(data, selectors):
  2. # compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F
  3. return (d for d, s in zip(data, selectors) if s)

3.1 新版功能.

  • itertools.count(start=0, step=1)
  • 创建一个迭代器,它从 start 值开始,返回均匀间隔的值。常用于 map() 中的实参来生成连续的数据点。此外,还用于 zip() 来添加序列号。大致相当于:
  1. def count(start=0, step=1):
  2. # count(10) --> 10 11 12 13 14 ...
  3. # count(2.5, 0.5) -> 2.5 3.0 3.5 ...
  4. n = start
  5. while True:
  6. yield n
  7. n += step

当对浮点数计数时,替换为乘法代码有时精度会更好,例如: (start + step * i for i in count())

在 3.1 版更改: 增加参数 step ,允许非整型。

  • itertools.cycle(iterable)
  • 创建一个迭代器,返回 iterable 中所有元素并保存一个副本。当取完 iterable 中所有元素,返回副本中的所有元素。无限重复。大致相当于:
  1. def cycle(iterable):
  2. # cycle('ABCD') --> A B C D A B C D A B C D ...
  3. saved = []
  4. for element in iterable:
  5. yield element
  6. saved.append(element)
  7. while saved:
  8. for element in saved:
  9. yield element

注意,该函数可能需要相当大的辅助空间(取决于 iterable 的长度)。

  • itertools.dropwhile(predicate, iterable)
  • 创建一个迭代器,如果 predicate 为true,迭代器丢弃这些元素,然后返回其他元素。注意,迭代器在 predicate 首次为false之前不会产生任何输出,所以可能需要一定长度的启动时间。大致相当于:
  1. def dropwhile(predicate, iterable):
  2. # dropwhile(lambda x: x<5, [1,4,6,4,1]) --> 6 4 1
  3. iterable = iter(iterable)
  4. for x in iterable:
  5. if not predicate(x):
  6. yield x
  7. break
  8. for x in iterable:
  9. yield x
  • itertools.filterfalse(predicate, iterable)
  • 创建一个迭代器,只返回 iterablepredicateFalse 的元素。如果 predicateNone,返回真值测试为false的元素。大致相当于:
  1. def filterfalse(predicate, iterable):
  2. # filterfalse(lambda x: x%2, range(10)) --> 0 2 4 6 8
  3. if predicate is None:
  4. predicate = bool
  5. for x in iterable:
  6. if not predicate(x):
  7. yield x
  • itertools.groupby(iterable, key=None)
  • 创建一个迭代器,返回 iterable 中连续的键和组。key 是一个计算元素键值函数。如果未指定或为 Nonekey 缺省为恒等函数(identity function),返回元素不变。一般来说,iterable 需用同一个键值函数预先排序。

groupby() 操作类似于Unix中的 uniq。当每次 key 函数产生的键值改变时,迭代器会分组或生成一个新组(这就是为什么通常需要使用同一个键值函数先对数据进行排序)。这种行为与SQL的GROUP BY操作不同,SQL的操作会忽略输入的顺序将相同键值的元素分在同组中。

返回的组本身也是一个迭代器,它与 groupby() 共享底层的可迭代对象。因为源是共享的,当 groupby() 对象向后迭代时,前一个组将消失。因此如果稍后还需要返回结果,可保存为列表:

  1. groups = []
  2. uniquekeys = []
  3. data = sorted(data, key=keyfunc)
  4. for k, g in groupby(data, keyfunc):
  5. groups.append(list(g)) # Store group iterator as a list
  6. uniquekeys.append(k)

groupby() 大致相当于:

  1. class groupby:
  2. # [k for k, g in groupby('AAAABBBCCDAABBB')] --> A B C D A B
  3. # [list(g) for k, g in groupby('AAAABBBCCD')] --> AAAA BBB CC D
  4. def __init__(self, iterable, key=None):
  5. if key is None:
  6. key = lambda x: x
  7. self.keyfunc = key
  8. self.it = iter(iterable)
  9. self.tgtkey = self.currkey = self.currvalue = object()
  10. def __iter__(self):
  11. return self
  12. def __next__(self):
  13. self.id = object()
  14. while self.currkey == self.tgtkey:
  15. self.currvalue = next(self.it) # Exit on StopIteration
  16. self.currkey = self.keyfunc(self.currvalue)
  17. self.tgtkey = self.currkey
  18. return (self.currkey, self._grouper(self.tgtkey, self.id))
  19. def _grouper(self, tgtkey, id):
  20. while self.id is id and self.currkey == tgtkey:
  21. yield self.currvalue
  22. try:
  23. self.currvalue = next(self.it)
  24. except StopIteration:
  25. return
  26. self.currkey = self.keyfunc(self.currvalue)
  • itertools.islice(iterable, stop)
  • itertools.islice(iterable, start, stop[, step])
  • 创建一个迭代器,返回从 iterable 里选中的元素。如果 start 不是0,跳过 iterable 中的元素,直到到达 start 这个位置。之后迭代器连续返回元素,除非 step 设置的值很高导致被跳过。如果 stopNone,迭代器耗光为止;否则,在指定的位置停止。与普通的切片不同,islice() 不支持将 startstop ,或 step 设为负值。可用来从内部数据结构被压平的数据中提取相关字段(例如一个多行报告,它的名称字段出现在每三行上)。大致相当于:
  1. def islice(iterable, *args):
  2. # islice('ABCDEFG', 2) --> A B
  3. # islice('ABCDEFG', 2, 4) --> C D
  4. # islice('ABCDEFG', 2, None) --> C D E F G
  5. # islice('ABCDEFG', 0, None, 2) --> A C E G
  6. s = slice(*args)
  7. start, stop, step = s.start or 0, s.stop or sys.maxsize, s.step or 1
  8. it = iter(range(start, stop, step))
  9. try:
  10. nexti = next(it)
  11. except StopIteration:
  12. # Consume *iterable* up to the *start* position.
  13. for i, element in zip(range(start), iterable):
  14. pass
  15. return
  16. try:
  17. for i, element in enumerate(iterable):
  18. if i == nexti:
  19. yield element
  20. nexti = next(it)
  21. except StopIteration:
  22. # Consume to *stop*.
  23. for i, element in zip(range(i + 1, stop), iterable):
  24. pass

如果 startNone,迭代从0开始。如果 stepNone ,步长缺省为1。

  • itertools.permutations(iterable, r=None)
  • 连续返回由 iterable 元素生成长度为 r 的排列。

如果 r 未指定或为 Noner 默认设置为 iterable 的长度,这种情况下,生成所有全长排列。

排列依字典序发出。因此,如果 iterable 是已排序的,排列元组将有序地产出。

即使元素的值相同,不同位置的元素也被认为是不同的。如果元素值都不同,每个排列中的元素值不会重复。

大致相当于:

  1. def permutations(iterable, r=None):
  2. # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
  3. # permutations(range(3)) --> 012 021 102 120 201 210
  4. pool = tuple(iterable)
  5. n = len(pool)
  6. r = n if r is None else r
  7. if r > n:
  8. return
  9. indices = list(range(n))
  10. cycles = list(range(n, n-r, -1))
  11. yield tuple(pool[i] for i in indices[:r])
  12. while n:
  13. for i in reversed(range(r)):
  14. cycles[i] -= 1
  15. if cycles[i] == 0:
  16. indices[i:] = indices[i+1:] + indices[i:i+1]
  17. cycles[i] = n - i
  18. else:
  19. j = cycles[i]
  20. indices[i], indices[-j] = indices[-j], indices[i]
  21. yield tuple(pool[i] for i in indices[:r])
  22. break
  23. else:
  24. return

permutations() 的代码也可被改写为 product() 的子序列,只要将含有重复元素(来自输入中同一位置的)的项排除。

  1. def permutations(iterable, r=None):
  2. pool = tuple(iterable)
  3. n = len(pool)
  4. r = n if r is None else r
  5. for indices in product(range(n), repeat=r):
  6. if len(set(indices)) == r:
  7. yield tuple(pool[i] for i in indices)

0 <= r <= n ,返回项个数为 n! / (n-r)! ;当 r > n ,返回项个数为0。

  • itertools.product(*iterables, repeat=1)
  • 可迭代对象输入的笛卡儿积。

大致相当于生成器表达式中的嵌套循环。例如, product(A, B)((x,y) for x in A for y in B) 返回结果一样。

嵌套循环像里程表那样循环变动,每次迭代时将最右侧的元素向后迭代。这种模式形成了一种字典序,因此如果输入的可迭代对象是已排序的,笛卡尔积元组依次序发出。

要计算可迭代对象自身的笛卡尔积,将可选参数 repeat 设定为要重复的次数。例如,product(A, repeat=4)product(A, A, A, A) 是一样的。

该函数大致相当于下面的代码,只不过实际实现方案不会在内存中创建中间结果。

  1. def product(*args, repeat=1):
  2. # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
  3. # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
  4. pools = [tuple(pool) for pool in args] * repeat
  5. result = [[]]
  6. for pool in pools:
  7. result = [x+[y] for x in result for y in pool]
  8. for prod in result:
  9. yield tuple(prod)
  • itertools.repeat(object[, times])
  • 创建一个迭代器,不断重复 object 。除非设定参数 times ,否则将无限重复。可用于 map() 函数中的参数,被调用函数可得到一个不变参数。也可用于 zip() 的参数以在元组记录中创建一个不变的部分。

大致相当于:

  1. def repeat(object, times=None):
  2. # repeat(10, 3) --> 10 10 10
  3. if times is None:
  4. while True:
  5. yield object
  6. else:
  7. for i in range(times):
  8. yield object

repeat 最常见的用途就是在 mapzip 提供一个常量流:

  1. >>> list(map(pow, range(10), repeat(2)))
  2. [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
  • itertools.starmap(function, iterable)
  • 创建一个迭代器,使用从可迭代对象中获取的参数来计算该函数。当参数对应的形参已从一个单独可迭代对象组合为元组时(数据已被“预组对”)可用此函数代替 map()map()starmap() 之间的区别可以类比 function(a,b)function(*c) 的区别。大致相当于:
  1. def starmap(function, iterable):
  2. # starmap(pow, [(2,5), (3,2), (10,3)]) --> 32 9 1000
  3. for args in iterable:
  4. yield function(*args)
  • itertools.takewhile(predicate, iterable)
  • 创建一个迭代器,只要 predicate 为真就从可迭代对象中返回元素。大致相当于:
  1. def takewhile(predicate, iterable):
  2. # takewhile(lambda x: x<5, [1,4,6,4,1]) --> 1 4
  3. for x in iterable:
  4. if predicate(x):
  5. yield x
  6. else:
  7. break
  • itertools.tee(iterable, n=2)
  • 从一个可迭代对象中返回 n 个独立的迭代器。

下面的Python代码能帮助解释 tee 做了什么(尽管实际的实现更复杂,而且仅使用了一个底层的 FIFO 队列)。

大致相当于:

  1. def tee(iterable, n=2):
  2. it = iter(iterable)
  3. deques = [collections.deque() for i in range(n)]
  4. def gen(mydeque):
  5. while True:
  6. if not mydeque: # when the local deque is empty
  7. try:
  8. newval = next(it) # fetch a new value and
  9. except StopIteration:
  10. return
  11. for d in deques: # load it to all the deques
  12. d.append(newval)
  13. yield mydeque.popleft()
  14. return tuple(gen(d) for d in deques)

一旦 tee() 实施了一次分裂,原有的 iterable 不应再被使用;否则tee对象无法得知 iterable 可能已向后迭代。

tee 迭代器不是线程安全的。当同时使用由同一个 tee() 调用所返回的迭代器时可能引发 RuntimeError,即使原本的 iterable 是线程安全的。

该迭代工具可能需要相当大的辅助存储空间(这取决于要保存多少临时数据)。通常,如果一个迭代器在另一个迭代器开始之前就要使用大部份或全部数据,使用 list() 会比 tee() 更快。

  • itertools.ziplongest(*iterables, _fillvalue=None)
  • 创建一个迭代器,从每个可迭代对象中收集元素。如果可迭代对象的长度未对齐,将根据 fillvalue 填充缺失值。迭代持续到耗光最长的可迭代对象。大致相当于:
  1. def zip_longest(*args, fillvalue=None):
  2. # zip_longest('ABCD', 'xy', fillvalue='-') --> Ax By C- D-
  3. iterators = [iter(it) for it in args]
  4. num_active = len(iterators)
  5. if not num_active:
  6. return
  7. while True:
  8. values = []
  9. for i, it in enumerate(iterators):
  10. try:
  11. value = next(it)
  12. except StopIteration:
  13. num_active -= 1
  14. if not num_active:
  15. return
  16. iterators[i] = repeat(fillvalue)
  17. value = fillvalue
  18. values.append(value)
  19. yield tuple(values)

如果其中一个可迭代对象有无限长度,zip_longest() 函数应封装在限制调用次数的场景中(例如 islice()takewhile())。除非指定, fillvalue 默认为 None

itertools 配方

本节将展示如何使用现有的 itertools 作为基础构件来创建扩展的工具集。

基本上所有这些西方和许许多多其他的配方都可以通过 Python Package Index 上的 more-itertools 项目 来安装:

  1. pip install more-itertools

扩展的工具提供了与底层工具集相同的高性能。保持了超棒的内存利用率,因为一次只处理一个元素,而不是将整个可迭代对象加载到内存。代码量保持得很小,以函数式风格将这些工具连接在一起,有助于消除临时变量。速度依然很快,因为倾向于使用“矢量化”构件来取代解释器开销大的 for 循环和 generator

  1. def take(n, iterable):
  2. "Return first n items of the iterable as a list"
  3. return list(islice(iterable, n))
  4.  
  5. def prepend(value, iterator):
  6. "Prepend a single value in front of an iterator"
  7. # prepend(1, [2, 3, 4]) -> 1 2 3 4
  8. return chain([value], iterator)
  9.  
  10. def tabulate(function, start=0):
  11. "Return function(0), function(1), ..."
  12. return map(function, count(start))
  13.  
  14. def tail(n, iterable):
  15. "Return an iterator over the last n items"
  16. # tail(3, 'ABCDEFG') --> E F G
  17. return iter(collections.deque(iterable, maxlen=n))
  18.  
  19. def consume(iterator, n=None):
  20. "Advance the iterator n-steps ahead. If n is None, consume entirely."
  21. # Use functions that consume iterators at C speed.
  22. if n is None:
  23. # feed the entire iterator into a zero-length deque
  24. collections.deque(iterator, maxlen=0)
  25. else:
  26. # advance to the empty slice starting at position n
  27. next(islice(iterator, n, n), None)
  28.  
  29. def nth(iterable, n, default=None):
  30. "Returns the nth item or a default value"
  31. return next(islice(iterable, n, None), default)
  32.  
  33. def all_equal(iterable):
  34. "Returns True if all the elements are equal to each other"
  35. g = groupby(iterable)
  36. return next(g, True) and not next(g, False)
  37.  
  38. def quantify(iterable, pred=bool):
  39. "Count how many times the predicate is true"
  40. return sum(map(pred, iterable))
  41.  
  42. def padnone(iterable):
  43. """Returns the sequence elements and then returns None indefinitely.
  44.  
  45. Useful for emulating the behavior of the built-in map() function.
  46. """
  47. return chain(iterable, repeat(None))
  48.  
  49. def ncycles(iterable, n):
  50. "Returns the sequence elements n times"
  51. return chain.from_iterable(repeat(tuple(iterable), n))
  52.  
  53. def dotproduct(vec1, vec2):
  54. return sum(map(operator.mul, vec1, vec2))
  55.  
  56. def flatten(list_of_lists):
  57. "Flatten one level of nesting"
  58. return chain.from_iterable(list_of_lists)
  59.  
  60. def repeatfunc(func, times=None, *args):
  61. """Repeat calls to func with specified arguments.
  62.  
  63. Example: repeatfunc(random.random)
  64. """
  65. if times is None:
  66. return starmap(func, repeat(args))
  67. return starmap(func, repeat(args, times))
  68.  
  69. def pairwise(iterable):
  70. "s -> (s0,s1), (s1,s2), (s2, s3), ..."
  71. a, b = tee(iterable)
  72. next(b, None)
  73. return zip(a, b)
  74.  
  75. def grouper(iterable, n, fillvalue=None):
  76. "Collect data into fixed-length chunks or blocks"
  77. # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
  78. args = [iter(iterable)] * n
  79. return zip_longest(*args, fillvalue=fillvalue)
  80.  
  81. def roundrobin(*iterables):
  82. "roundrobin('ABC', 'D', 'EF') --> A D E B F C"
  83. # Recipe credited to George Sakkis
  84. num_active = len(iterables)
  85. nexts = cycle(iter(it).__next__ for it in iterables)
  86. while num_active:
  87. try:
  88. for next in nexts:
  89. yield next()
  90. except StopIteration:
  91. # Remove the iterator we just exhausted from the cycle.
  92. num_active -= 1
  93. nexts = cycle(islice(nexts, num_active))
  94.  
  95. def partition(pred, iterable):
  96. 'Use a predicate to partition entries into false entries and true entries'
  97. # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
  98. t1, t2 = tee(iterable)
  99. return filterfalse(pred, t1), filter(pred, t2)
  100.  
  101. def powerset(iterable):
  102. "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
  103. s = list(iterable)
  104. return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
  105.  
  106. def unique_everseen(iterable, key=None):
  107. "List unique elements, preserving order. Remember all elements ever seen."
  108. # unique_everseen('AAAABBBCCDAABBB') --> A B C D
  109. # unique_everseen('ABBCcAD', str.lower) --> A B C D
  110. seen = set()
  111. seen_add = seen.add
  112. if key is None:
  113. for element in filterfalse(seen.__contains__, iterable):
  114. seen_add(element)
  115. yield element
  116. else:
  117. for element in iterable:
  118. k = key(element)
  119. if k not in seen:
  120. seen_add(k)
  121. yield element
  122.  
  123. def unique_justseen(iterable, key=None):
  124. "List unique elements, preserving order. Remember only the element just seen."
  125. # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
  126. # unique_justseen('ABBCcAD', str.lower) --> A B C A D
  127. return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
  128.  
  129. def iter_except(func, exception, first=None):
  130. """ Call a function repeatedly until an exception is raised.
  131.  
  132. Converts a call-until-exception interface to an iterator interface.
  133. Like builtins.iter(func, sentinel) but uses an exception instead
  134. of a sentinel to end the loop.
  135.  
  136. Examples:
  137. iter_except(functools.partial(heappop, h), IndexError) # priority queue iterator
  138. iter_except(d.popitem, KeyError) # non-blocking dict iterator
  139. iter_except(d.popleft, IndexError) # non-blocking deque iterator
  140. iter_except(q.get_nowait, Queue.Empty) # loop over a producer Queue
  141. iter_except(s.pop, KeyError) # non-blocking set iterator
  142.  
  143. """
  144. try:
  145. if first is not None:
  146. yield first() # For database APIs needing an initial cast to db.first()
  147. while True:
  148. yield func()
  149. except exception:
  150. pass
  151.  
  152. def first_true(iterable, default=False, pred=None):
  153. """Returns the first true value in the iterable.
  154.  
  155. If no true value is found, returns *default*
  156.  
  157. If *pred* is not None, returns the first item
  158. for which pred(item) is true.
  159.  
  160. """
  161. # first_true([a,b,c], x) --> a or b or c or x
  162. # first_true([a,b], x, f) --> a if f(a) else b if f(b) else x
  163. return next(filter(pred, iterable), default)
  164.  
  165. def random_product(*args, repeat=1):
  166. "Random selection from itertools.product(*args, **kwds)"
  167. pools = [tuple(pool) for pool in args] * repeat
  168. return tuple(random.choice(pool) for pool in pools)
  169.  
  170. def random_permutation(iterable, r=None):
  171. "Random selection from itertools.permutations(iterable, r)"
  172. pool = tuple(iterable)
  173. r = len(pool) if r is None else r
  174. return tuple(random.sample(pool, r))
  175.  
  176. def random_combination(iterable, r):
  177. "Random selection from itertools.combinations(iterable, r)"
  178. pool = tuple(iterable)
  179. n = len(pool)
  180. indices = sorted(random.sample(range(n), r))
  181. return tuple(pool[i] for i in indices)
  182.  
  183. def random_combination_with_replacement(iterable, r):
  184. "Random selection from itertools.combinations_with_replacement(iterable, r)"
  185. pool = tuple(iterable)
  186. n = len(pool)
  187. indices = sorted(random.randrange(n) for i in range(r))
  188. return tuple(pool[i] for i in indices)
  189.  
  190. def nth_combination(iterable, r, index):
  191. 'Equivalent to list(combinations(iterable, r))[index]'
  192. pool = tuple(iterable)
  193. n = len(pool)
  194. if r < 0 or r > n:
  195. raise ValueError
  196. c = 1
  197. k = min(r, n-r)
  198. for i in range(1, k+1):
  199. c = c * (n - k + i) // i
  200. if index < 0:
  201. index += c
  202. if index < 0 or index >= c:
  203. raise IndexError
  204. result = []
  205. while r:
  206. c, n, r = c*r//n, n-1, r-1
  207. while index >= c:
  208. index -= c
  209. c, n = c*(n-r)//n, n-1
  210. result.append(pool[-1-n])
  211. return tuple(result)