itertools: Functions creating iterators for efficient looping


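This module implements a number of iterator building blocks inspired by constructs from APL, Haskell, and SML. Each has been recast in a form suitable for Python.

The module standardizes a core set of fast, memory-efficient tools that are useful by themselves or in combination. Together, they form an "iterator algebra" that makes it possible to construct specialized tools succinctly and efficiently in pure Python.

For instance, SML provides a tabulation tool, tabulate(f), which produces a sequence f(0), f(1), .... The same effect can be achieved in Python by combining map() and count() to form map(f, count()).

These tools and their built-in counterparts also work well with the high-speed functions in the operator module. For example, the multiplication operator can be mapped across two vectors to form an efficient dot product: sum(map(operator.mul, vector1, vector2)).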
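As a small illustration of the two idioms above, the sketch below tabulates a function with map() and count() and computes a dot product with operator.mul. The function name square and the sample vectors are made up for the example, and islice() is used only to truncate the otherwise infinite stream.

```python
from itertools import count, islice
import operator

def square(n):
    # Hypothetical function to tabulate; any one-argument callable works.
    return n * n

# map(f, count()) is an infinite stream f(0), f(1), ...; islice() truncates it.
first_squares = list(islice(map(square, count()), 5))

# Dot product: multiply element-wise, then sum the products.
vector1 = [1, 2, 3]
vector2 = [4, 5, 6]
dot = sum(map(operator.mul, vector1, vector2))

print(first_squares)  # [0, 1, 4, 9, 16]
print(dot)            # 32
```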

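Infinite iterators:

| Iterator | Arguments | Results | Example |
| --- | --- | --- | --- |
| count() | start, [step] | start, start+step, start+2*step, ... | count(10) --> 10 11 12 13 14 ... |
| cycle() | p | p0, p1, ... plast, p0, p1, ... | cycle('ABCD') --> A B C D A B C D ... |
| repeat() | elem [,n] | elem, elem, elem, ... endlessly or up to n times | repeat(10, 3) --> 10 10 10 |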

Iterators terminating on the shortest input sequence:

| Iterator | Arguments | Results | Example |
| --- | --- | --- | --- |
| accumulate() | p [,func] | p0, p0+p1, p0+p1+p2, ... | accumulate([1,2,3,4,5]) --> 1 3 6 10 15 |
| chain() | p, q, ... | p0, p1, ... plast, q0, q1, ... | chain('ABC', 'DEF') --> A B C D E F |
| chain.from_iterable() | iterable | p0, p1, ... plast, q0, q1, ... | chain.from_iterable(['ABC', 'DEF']) --> A B C D E F |
| compress() | data, selectors | (d[0] if s[0]), (d[1] if s[1]), ... | compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F |
| dropwhile() | pred, seq | seq[n], seq[n+1], starting when pred fails | dropwhile(lambda x: x<5, [1,4,6,4,1]) --> 6 4 1 |
| filterfalse() | pred, seq | elements of seq where pred(elem) is false | filterfalse(lambda x: x%2, range(10)) --> 0 2 4 6 8 |
| groupby() | iterable[, key] | sub-iterators grouped by value of key(v) | |
| islice() | seq, [start,] stop [, step] | elements from seq[start:stop:step] | islice('ABCDEFG', 2, None) --> C D E F G |
| pairwise() | iterable | (p[0], p[1]), (p[1], p[2]) | pairwise('ABCDEFG') --> AB BC CD DE EF FG |
| starmap() | func, seq | func(*seq[0]), func(*seq[1]), ... | starmap(pow, [(2,5), (3,2), (10,3)]) --> 32 9 1000 |
| takewhile() | pred, seq | seq[0], seq[1], until pred fails | takewhile(lambda x: x<5, [1,4,6,4,1]) --> 1 4 |
| tee() | it, n | it1, it2, ... itn splits one iterator into n | |
| zip_longest() | p, q, ... | (p[0], q[0]), (p[1], q[1]), ... | zip_longest('ABCD', 'xy', fillvalue='-') --> Ax By C- D- |

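Combinatoric iterators:

| Iterator | Arguments | Results |
| --- | --- | --- |
| product() | p, q, ... [repeat=1] | cartesian product, equivalent to a nested for-loop |
| permutations() | p[, r] | r-length tuples, all possible orderings, no repeated elements |
| combinations() | p, r | r-length tuples, in sorted order, no repeated elements |
| combinations_with_replacement() | p, r | r-length tuples, in sorted order, with repeated elements |

| Examples | Results |
| --- | --- |
| product('ABCD', repeat=2) | AA AB AC AD BA BB BC BD CA CB CC CD DA DB DC DD |
| permutations('ABCD', 2) | AB AC AD BA BC BD CA CB CD DA DB DC |
| combinations('ABCD', 2) | AB AC AD BC BD CD |
| combinations_with_replacement('ABCD', 2) | AA AB AC AD BB BC BD CC CD DD |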

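Itertool functions

The following module functions all construct and return iterators. Some provide streams of infinite length, so they should only be accessed by functions or loops that truncate the stream.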

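itertools.accumulate(iterable[, func, *, initial=None])

Make an iterator that returns accumulated sums, or accumulated results of other binary functions (specified via the optional func argument).

If func is supplied, it should be a function of two arguments. Elements of the input iterable may be any type that can be accepted as arguments to func. (For example, with the default operation of addition, elements may be any addable type including Decimal or Fraction.)

Usually, the number of elements output matches the input iterable. However, if the keyword argument initial is provided, the accumulation leads off with the initial value so that the output has one more element than the input iterable.

Roughly equivalent to:

    import operator

    def accumulate(iterable, func=operator.add, *, initial=None):
        'Return running totals'
        # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
        # accumulate([1,2,3,4,5], initial=100) --> 100 101 103 106 110 115
        # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
        it = iter(iterable)
        total = initial
        if initial is None:
            # No seed value: the first input element starts the accumulation.
            try:
                total = next(it)
            except StopIteration:
                return
        yield total
        for element in it:
            total = func(total, element)
            yield total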

There are a number of uses for the func argument. It can be set to min() for a running minimum, max() for a running maximum, or operator.mul() for a running product. Amortization tables can be built by accumulating interest and applying payments:

    >>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
    >>> list(accumulate(data, operator.mul))     # running product
    [3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]
    >>> list(accumulate(data, max))              # running maximum
    [3, 4, 6, 6, 6, 9, 9, 9, 9, 9]

    # Amortize a 5% loan of 1000 with 4 annual payments of 90
    >>> cashflows = [1000, -90, -90, -90, -90]
    >>> list(accumulate(cashflows, lambda bal, pmt: bal*1.05 + pmt))
    [1000, 960.0, 918.0, 873.9000000000001, 827.5950000000001]

See functools.reduce() for a similar function that returns only the final accumulated value.
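The relationship between accumulate(), its initial argument, and functools.reduce() can be seen side by side in the short sketch below. The sample data is arbitrary and chosen only for illustration.

```python
from functools import reduce
from itertools import accumulate
import operator

data = [1, 2, 3, 4, 5]

running = list(accumulate(data))               # every intermediate total
seeded = list(accumulate(data, initial=100))   # one extra leading element
final = reduce(operator.add, data)             # only the last total

print(running)  # [1, 3, 6, 10, 15]
print(seeded)   # [100, 101, 103, 106, 110, 115]
print(final)    # 15
```

The last element produced by accumulate() equals the value that reduce() returns for the same function and data.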

New in version 3.2.

Changed in version 3.3: Added the optional func parameter.

Changed in version 3.8: Added the optional initial parameter.

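itertools.chain(*iterables)

Make an iterator that returns elements from the first iterable until it is exhausted, then proceeds to the next iterable, until all of the iterables are exhausted. Used for treating consecutive sequences as a single sequence. Roughly equivalent to:

    def chain(*iterables):
        # chain('ABC', 'DEF') --> A B C D E F
        for it in iterables:
            for element in it:
                yield element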

classmethod chain.from_iterable(iterable)

Alternate constructor for chain(). Gets chained inputs from a single iterable argument that is evaluated lazily. Roughly equivalent to:

    def from_iterable(iterables):
        # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
        for it in iterables:
            for element in it:
                yield element
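Lazy evaluation matters most when the outer iterable is itself unbounded. The sketch below uses a hypothetical generator, rows(), that never ends: chain(*rows()) would hang while unpacking its arguments, whereas chain.from_iterable() pulls one inner iterable at a time.

```python
from itertools import chain, count, islice

def rows():
    # Hypothetical infinite source of iterables: range(1), range(2), range(3), ...
    for n in count(1):
        yield range(n)

# from_iterable() only asks rows() for the next inner iterable when needed.
flat = chain.from_iterable(rows())
first_six = list(islice(flat, 6))

print(first_six)  # [0, 0, 1, 0, 1, 2]
```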

itertools.combinations(iterable, r)

Return r length subsequences of elements from the input iterable.

The combination tuples are emitted in lexicographic ordering according to the order of the input iterable. So, if the input iterable is sorted, the output tuples will be produced in sorted order.

Elements are treated as unique based on their position, not on their value. So if the input elements are unique, there will be no repeated values in each combination.

Roughly equivalent to:

    def combinations(iterable, r):
        # combinations('ABCD', 2) --> AB AC AD BC BD CD
        # combinations(range(4), 3) --> 012 013 023 123
        pool = tuple(iterable)
        n = len(pool)
        if r > n:
            return
        indices = list(range(r))
        yield tuple(pool[i] for i in indices)
        while True:
            # Find the rightmost index that can still be advanced.
            for i in reversed(range(r)):
                if indices[i] != i + n - r:
                    break
            else:
                return
            indices[i] += 1
            for j in range(i+1, r):
                indices[j] = indices[j-1] + 1
            yield tuple(pool[i] for i in indices)

The code for combinations() can also be expressed as a subsequence of permutations() after filtering out entries where the elements are not in sorted order (according to their position in the input pool):

    def combinations(iterable, r):
        pool = tuple(iterable)
        n = len(pool)
        for indices in permutations(range(n), r):
            if sorted(indices) == list(indices):
                yield tuple(pool[i] for i in indices)

The number of items returned is n! / r! / (n-r)! when 0 <= r <= n or zero when r > n.
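The counting formula can be checked empirically. The sketch below compares the number of tuples produced by combinations() with the factorial expression and with math.comb(), which computes the same binomial coefficient; the pool 'ABCDE' is an arbitrary choice.

```python
from itertools import combinations
from math import comb, factorial

pool = 'ABCDE'
n = len(pool)

counts = []
for r in range(n + 3):  # the last two values of r exceed n
    generated = sum(1 for _ in combinations(pool, r))
    expected = factorial(n) // factorial(r) // factorial(n - r) if r <= n else 0
    # math.comb() returns 0 when r > n, matching the empty output.
    assert generated == expected == comb(n, r)
    counts.append(generated)

print(counts)  # [1, 5, 10, 10, 5, 1, 0, 0]
```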

itertools.combinations_with_replacement(iterable, r)

Return r length subsequences of elements from the input iterable, allowing individual elements to be repeated more than once.

The combination tuples are emitted in lexicographic ordering according to the order of the input iterable. So, if the input iterable is sorted, the output tuples will be produced in sorted order.

Elements are treated as unique based on their position, not on their value. So if the input elements are unique, the generated combinations will also be unique.

Roughly equivalent to:

    def combinations_with_replacement(iterable, r):
        # combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC
        pool = tuple(iterable)
        n = len(pool)
        if not n and r:
            return
        indices = [0] * r
        yield tuple(pool[i] for i in indices)
        while True:
            # Find the rightmost index that has not reached the last pool position.
            for i in reversed(range(r)):
                if indices[i] != n - 1:
                    break
            else:
                return
            indices[i:] = [indices[i] + 1] * (r - i)
            yield tuple(pool[i] for i in indices)

The code for combinations_with_replacement() can also be expressed as a subsequence of product() after filtering out entries where the elements are not in sorted order (according to their position in the input pool):

    def combinations_with_replacement(iterable, r):
        pool = tuple(iterable)
        n = len(pool)
        for indices in product(range(n), repeat=r):
            if sorted(indices) == list(indices):
                yield tuple(pool[i] for i in indices)

The number of items returned is (n+r-1)! / r! / (n-1)! when n > 0.

New in version 3.1.
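The expression (n+r-1)! / r! / (n-1)! is the binomial coefficient C(n+r-1, r), so the count for combinations_with_replacement() can be checked with math.comb(). This is a quick sanity check on an arbitrary three-element pool, not a proof.

```python
from itertools import combinations_with_replacement
from math import comb

pool = 'ABC'
n = len(pool)

counts = []
for r in range(5):
    generated = sum(1 for _ in combinations_with_replacement(pool, r))
    # (n+r-1)! / r! / (n-1)! written as a binomial coefficient.
    assert generated == comb(n + r - 1, r)
    counts.append(generated)

print(counts)  # [1, 3, 6, 10, 15]
```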

itertools.compress(data, selectors)

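Make an iterator that filters elements from data, returning only those that have a corresponding element in selectors that evaluates to True. Stops when either the data or selectors iterable has been exhausted. Roughly equivalent to:

    def compress(data, selectors):
        # compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F
        return (d for d, s in zip(data, selectors) if s)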

New in version 3.1.
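In practice the selectors passed to compress() are often computed from a parallel sequence rather than written by hand. The sketch below uses made-up names and scores, and also shows that the shorter input decides where iteration stops.

```python
from itertools import compress

names = ['alice', 'bob', 'carol', 'dave']   # illustrative data
scores = [82, 47, 91, 60]

# Build the selectors lazily from the parallel scores sequence.
passed = list(compress(names, (s >= 60 for s in scores)))

# Only two selectors are supplied, so at most two elements are examined.
truncated = list(compress('ABCDEF', [1, 1]))

print(passed)     # ['alice', 'carol', 'dave']
print(truncated)  # ['A', 'B']
```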

itertools.count(start=0, step=1)

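Make an iterator that returns evenly spaced values starting with the number start. Often used as an argument to map() to generate consecutive data points. Also used with zip() to add sequence numbers. Roughly equivalent to:

    def count(start=0, step=1):
        # count(10) --> 10 11 12 13 14 ...
        # count(2.5, 0.5) --> 2.5 3.0 3.5 ...
        n = start
        while True:
            yield n
            n += step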

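When counting with floating point numbers, better accuracy can sometimes be achieved by substituting multiplicative code such as: (start + step * i for i in count()).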
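The difference between the additive and multiplicative forms shows up after only a few steps. The sketch below takes the value at position 10 from each stream, with a step of 0.1 chosen because it has no exact binary representation.

```python
from itertools import count, islice

start, step, position = 0.0, 0.1, 10

# Additive form: rounding error accumulates with every "n += step".
additive = next(islice(count(start, step), position, None))

# Multiplicative form: each value is computed independently from an integer index.
multiplicative = next(islice((start + step * i for i in count()), position, None))

print(additive)        # 0.9999999999999999
print(multiplicative)  # 1.0
```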

Changed in version 3.1: Added the step argument and allowed non-integer arguments.

itertools.cycle(iterable)

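Make an iterator returning elements from the iterable and saving a copy of each. When the iterable is exhausted, return elements from the saved copy. Repeats indefinitely. Roughly equivalent to:

    def cycle(iterable):
        # cycle('ABCD') --> A B C D A B C D A B C D ...
        saved = []
        for element in iterable:
            yield element
            saved.append(element)
        while saved:
            for element in saved:
                yield element

Note, this member of the toolkit may require significant auxiliary storage (depending on the length of the iterable).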

itertools.dropwhile(predicate, iterable)

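Make an iterator that drops elements from the iterable as long as the predicate is true; afterwards, returns every element. Note, the iterator does not produce any output until the predicate first becomes false, so it may have a lengthy start-up time. Roughly equivalent to:

    def dropwhile(predicate, iterable):
        # dropwhile(lambda x: x<5, [1,4,6,4,1]) --> 6 4 1
        iterable = iter(iterable)
        for x in iterable:
            if not predicate(x):
                yield x
                break
        # After the first failure, the predicate is never consulted again.
        for x in iterable:
            yield x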
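One typical use of dropwhile() is skipping a leading run of unwanted items, such as header comments. In the sketch below (the lines are invented for the example), a later line that also matches the predicate is kept, because the predicate is no longer consulted after its first failure.

```python
from itertools import dropwhile

lines = ['# header', '# generated file', 'x = 1', '# inline note', 'y = 2']

# Only the leading comment lines are dropped.
body = list(dropwhile(lambda line: line.startswith('#'), lines))

print(body)  # ['x = 1', '# inline note', 'y = 2']
```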

itertools.filterfalse(predicate, iterable)

Make an iterator that filters elements from the iterable, returning only those for which the predicate is False. If predicate is None, return the items that are false. Roughly equivalent to:

    def filterfalse(predicate, iterable):
        # filterfalse(lambda x: x%2, range(10)) --> 0 2 4 6 8
        if predicate is None:
            predicate = bool
        for x in iterable:
            if not predicate(x):
                yield x
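The predicate=None case is easy to overlook. The sketch below passes None so that filterfalse() keeps exactly the items whose truth value is false; the sample list is arbitrary.

```python
from itertools import filterfalse

values = [0, 1, '', 'x', None, [], [0]]

# With predicate=None, each item's own truth value is tested.
falsy = list(filterfalse(None, values))

print(falsy)  # [0, '', None, []]
```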

itertools.groupby(iterable, key=None)

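Make an iterator that returns consecutive keys and groups from the iterable. The key is a function computing a key value for each element. If not specified or is None, key defaults to an identity function and returns the element unchanged. Generally, the iterable needs to already be sorted on the same key function.

The operation of groupby() is similar to the uniq filter in Unix. It generates a break or new group every time the value of the key function changes (which is why it is usually necessary to have sorted the data using the same key function). That behavior differs from SQL's GROUP BY, which aggregates common elements regardless of their input order.

The returned group is itself an iterator that shares the underlying iterable with groupby(). Because the source is shared, when the groupby() object is advanced, the previous group is no longer visible. So, if that data is needed later, it should be stored as a list:

    groups = []
    uniquekeys = []
    data = sorted(data, key=keyfunc)
    for k, g in groupby(data, keyfunc):
        groups.append(list(g))      # Store group iterator as a list
        uniquekeys.append(k)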
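The effect of sorting first can be seen by running groupby() on the same data with and without a sort. The word list and the helper first_letter() below are invented for the example.

```python
from itertools import groupby

words = ['apple', 'bob', 'avocado', 'banana', 'cherry']

def first_letter(word):
    # Hypothetical key function: group words by their initial letter.
    return word[0]

# Unsorted input: a new group starts every time the key changes.
unsorted_keys = [k for k, _ in groupby(words, first_letter)]

# Sorted input: each key appears once, and each group is saved as a list.
by_letter = sorted(words, key=first_letter)
grouped = {k: list(g) for k, g in groupby(by_letter, first_letter)}

print(unsorted_keys)  # ['a', 'b', 'a', 'b', 'c']
print(grouped)        # {'a': ['apple', 'avocado'], 'b': ['bob', 'banana'], 'c': ['cherry']}
```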

groupby() is roughly equivalent to:

    class groupby:
        # [k for k, g in groupby('AAAABBBCCDAABBB')] --> A B C D A B
        # [list(g) for k, g in groupby('AAAABBBCCD')] --> AAAA BBB CC D

        def __init__(self, iterable, key=None):
            if key is None:
                key = lambda x: x
            self.keyfunc = key
            self.it = iter(iterable)
            self.tgtkey = self.currkey = self.currvalue = object()

        def __iter__(self):
            return self

        def __next__(self):
            self.id = object()
            while self.currkey == self.tgtkey:
                self.currvalue = next(self.it)    # Exit on StopIteration
                self.currkey = self.keyfunc(self.currvalue)
            self.tgtkey = self.currkey
            return (self.currkey, self._grouper(self.tgtkey, self.id))

        def _grouper(self, tgtkey, id):
            while self.id is id and self.currkey == tgtkey:
                yield self.currvalue
                try:
                    self.currvalue = next(self.it)
                except StopIteration:
                    return
                self.currkey = self.keyfunc(self.currvalue)

itertools.islice(iterable, stop)

itertools.islice(iterable, start, stop[, step])

Make an iterator that returns selected elements from the iterable. If start is non-zero, then elements from the iterable are skipped until start is reached. Afterward, elements are returned consecutively unless step is set higher than one which results in items being skipped. If stop is None, then iteration continues until the iterator is exhausted, if at all; otherwise, it stops at the specified position.

If start is None, then iteration starts at zero. If step is None, then the step defaults to one.

Unlike regular slicing, islice() does not support negative values for start, stop, or step. Can be used to extract related fields from data where the internal structure has been flattened (for example, a multi-line report may list a name field on every third line).

Roughly equivalent to:

    import sys

    def islice(iterable, *args):
        # islice('ABCDEFG', 2) --> A B
        # islice('ABCDEFG', 2, 4) --> C D
        # islice('ABCDEFG', 2, None) --> C D E F G
        # islice('ABCDEFG', 0, None, 2) --> A C E G
        s = slice(*args)
        # Only None means "not given", so that a stop of 0 yields nothing.
        start = 0 if s.start is None else s.start
        stop = sys.maxsize if s.stop is None else s.stop
        step = 1 if s.step is None else s.step
        it = iter(range(start, stop, step))
        try:
            nexti = next(it)
        except StopIteration:
            # Consume *iterable* up to the *start* position.
            for i, element in zip(range(start), iterable):
                pass
            return
        try:
            for i, element in enumerate(iterable):
                if i == nexti:
                    yield element
                    nexti = next(it)
        except StopIteration:
            # Consume to *stop*.
            for i, element in zip(range(i + 1, stop), iterable):
                pass
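The flattened-report use case mentioned above can be sketched as follows. The report contents are invented for illustration; a step of 3 picks out the name field that appears on every third line, and a negative argument is rejected when the islice object is created.

```python
from itertools import islice

# Hypothetical flattened report: three lines per record.
report = ['Name: Ann', 'Age: 31', 'City: Oslo',
          'Name: Raj', 'Age: 45', 'City: Pune']

# Start at line 0, no stop, take every third line.
names = list(islice(report, 0, None, 3))

# Unlike regular slicing, negative values are not supported.
try:
    islice(report, -1)
    rejected = False
except ValueError:
    rejected = True

print(names)     # ['Name: Ann', 'Name: Raj']
print(rejected)  # True
```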

itertools.pairwise(iterable)

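Return successive overlapping pairs taken from the input iterable.

The number of 2-tuples in the output iterator will be one fewer than the number of inputs. It will be empty if the input iterable has fewer than two values.

Roughly equivalent to:

    def pairwise(iterable):
        # pairwise('ABCDEFG') --> AB BC CD DE EF FG
        a, b = tee(iterable)
        next(b, None)       # advance the second copy by one element
        return zip(a, b)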

New in version 3.10.

itertools.permutations(iterable, r=None)

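Return successive r length permutations of elements in the iterable.

If r is not specified or is None, then r defaults to the length of the iterable and all possible full-length permutations are generated.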

The permutation tuples are emitted in lexicographic order according to the order of the input iterable. So, if the input iterable is sorted, the output tuples will be produced in sorted order.

Elements are treated as unique based on their position, not on their value. So if the input elements are unique, there will be no repeated values within a permutation.

Roughly equivalent to:

    def permutations(iterable, r=None):
        # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
        # permutations(range(3)) --> 012 021 102 120 201 210
        pool = tuple(iterable)
        n = len(pool)
        r = n if r is None else r
        if r > n:
            return
        indices = list(range(n))
        cycles = list(range(n, n-r, -1))
        yield tuple(pool[i] for i in indices[:r])
        while n:
            for i in reversed(range(r)):
                cycles[i] -= 1
                if cycles[i] == 0:
                    # Position i has used every choice: rotate it to the end and reset.
                    indices[i:] = indices[i+1:] + indices[i:i+1]
                    cycles[i] = n - i
                else:
                    j = cycles[i]
                    indices[i], indices[-j] = indices[-j], indices[i]
                    yield tuple(pool[i] for i in indices[:r])
                    break
            else:
                return

The code for permutations() can also be expressed as a subsequence of product(), filtered to exclude entries with repeated elements (those from the same position in the input pool):

    def permutations(iterable, r=None):
        pool = tuple(iterable)
        n = len(pool)
        r = n if r is None else r
        for indices in product(range(n), repeat=r):
            if len(set(indices)) == r:
                yield tuple(pool[i] for i in indices)

The number of items returned is n! / (n-r)! when 0 <= r <= n or zero when r > n.
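As with combinations(), the count can be checked against the standard library. The sketch below compares the number of tuples generated by permutations() with math.perm(), which computes n! / (n-r)! and returns zero when r > n; the pool is arbitrary.

```python
from itertools import permutations
from math import perm

pool = 'ABCD'
n = len(pool)

# Count the generated tuples for every r from 0 up to n + 1.
counts = [sum(1 for _ in permutations(pool, r)) for r in range(n + 2)]
assert counts == [perm(n, r) for r in range(n + 2)]

print(counts)  # [1, 4, 12, 24, 24, 0]
```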

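itertools.product(*iterables, repeat=1)

Cartesian product of input iterables.

Roughly equivalent to nested for-loops in a generator expression. For example, product(A, B) returns the same as ((x,y) for x in A for y in B).

The nested loops cycle like an odometer, with the rightmost element advancing on every iteration. This pattern creates a lexicographic ordering so that if the input iterables are sorted, the product tuples are emitted in sorted order.

To compute the product of an iterable with itself, specify the number of repetitions with the optional repeat keyword argument. For example, product(A, repeat=4) means the same as product(A, A, A, A).

This function is roughly equivalent to the following code, except that the actual implementation does not build up intermediate results in memory:

    def product(*args, repeat=1):
        # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
        # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
        pools = [tuple(pool) for pool in args] * repeat
        result = [[]]
        for pool in pools:
            result = [x+[y] for x in result for y in pool]
        for prod in result:
            yield tuple(prod)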

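Before product() runs, it completely consumes the input iterables, keeping pools of values in memory to generate the products. Accordingly, it is only useful with finite inputs.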
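The equivalences described for product() can be checked directly. The sketch below compares product() with an explicit nested comprehension and confirms that repeat=3 matches passing the same iterable three times; the inputs A and B are arbitrary.

```python
from itertools import product

A = 'ab'
B = [1, 2, 3]

# Nested loops: the rightmost variable (y) advances fastest, like an odometer.
nested = [(x, y) for x in A for y in B]
assert list(product(A, B)) == nested

# repeat=3 is shorthand for passing the same iterable three times.
assert list(product(A, repeat=3)) == list(product(A, A, A))

first_four = nested[:4]
print(first_four)  # [('a', 1), ('a', 2), ('a', 3), ('b', 1)]
```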

itertools.repeat(object[, times])

Make an iterator that returns object over and over again. Runs indefinitely unless the times argument is specified.

Roughly equivalent to:

    def repeat(object, times=None):
        # repeat(10, 3) --> 10 10 10
        if times is None:
            while True:
                yield object
        else:
            for i in range(times):
                yield object

A common use for repeat is to supply a stream of constant values to map or zip:

    >>> list(map(pow, range(10), repeat(2)))
    [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

itertools.starmap(function, iterable)

Make an iterator that computes the function using arguments obtained from the iterable. Used instead of map() when argument parameters are already grouped in tuples from a single iterable (when the data has been “pre-zipped”).

The difference between map() and starmap() parallels the distinction between function(a,b) and function(*c). Roughly equivalent to:

    def starmap(function, iterable):
        # starmap(pow, [(2,5), (3,2), (10,3)]) --> 32 9 1000
        for args in iterable:
            yield function(*args)
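The parallel between map() and starmap() can be made concrete with the same data in two shapes. In the sketch below, map() receives the arguments as separate parallel iterables, while starmap() receives them already grouped into tuples; the numbers are taken from the pow example above.

```python
from itertools import starmap

bases = [2, 3, 10]
exponents = [5, 2, 3]
pairs = list(zip(bases, exponents))   # "pre-zipped": [(2, 5), (3, 2), (10, 3)]

from_parallel = list(map(pow, bases, exponents))   # function(a, b)
from_pairs = list(starmap(pow, pairs))             # function(*c)

print(from_parallel)  # [32, 9, 1000]
print(from_pairs)     # [32, 9, 1000]
```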

itertools.takewhile(predicate, iterable)

Make an iterator that returns elements from the iterable as long as the predicate is true. Roughly equivalent to:

    def takewhile(predicate, iterable):
        # takewhile(lambda x: x<5, [1,4,6,4,1]) --> 1 4
        for x in iterable:
            if predicate(x):
                yield x
            else:
                break
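One consequence of the loop shown above is that the first element failing the predicate has already been taken from the underlying iterator when takewhile() stops. The sketch below makes this visible by reading the remainder of a shared iterator afterwards.

```python
from itertools import takewhile

it = iter([1, 4, 6, 4, 1])

head = list(takewhile(lambda x: x < 5, it))
rest = list(it)   # whatever is left in the shared iterator

print(head)  # [1, 4]
print(rest)  # [4, 1]  (the 6 was consumed by takewhile() and is lost)
```

When the failing element is still needed, the before_and_after() recipe later in this document keeps it.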

itertools.tee(iterable, n=2)

Return n independent iterators from a single iterable.

The following Python code helps explain what tee does (although the actual implementation is more complex and uses only a single underlying FIFO queue):

    import collections

    def tee(iterable, n=2):
        it = iter(iterable)
        deques = [collections.deque() for i in range(n)]
        def gen(mydeque):
            while True:
                if not mydeque:             # when the local deque is empty
                    try:
                        newval = next(it)   # fetch a new value and
                    except StopIteration:
                        return
                    for d in deques:        # load it to all the deques
                        d.append(newval)
                yield mydeque.popleft()
        return tuple(gen(d) for d in deques)

Once a tee() has been created, the original iterable should not be used anywhere else; otherwise, the iterable could get advanced without the tee objects being informed.

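tee iterators are not threadsafe. A RuntimeError may be raised when simultaneously using iterators returned by the same tee() call, even if the original iterable is threadsafe.

This itertool may require significant auxiliary storage (depending on how much temporary data needs to be stored). In general, if one iterator uses most or all of the data before another iterator starts, it is faster to use list() instead of tee().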
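The independence of the returned iterators can be seen in a small sketch: each one keeps its own position, and values consumed by one remain available to the other. The source range is arbitrary, and after the tee() call only a and b are used, never source itself.

```python
from itertools import tee

source = iter(range(5))
a, b = tee(source)          # from here on, use only a and b

first_two = [next(a), next(a)]
all_of_b = list(b)          # b still starts from the beginning
rest_of_a = list(a)         # a continues from where it stopped

print(first_two)  # [0, 1]
print(all_of_b)   # [0, 1, 2, 3, 4]
print(rest_of_a)  # [2, 3, 4]
```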

itertools.zip_longest(*iterables, fillvalue=None)

Make an iterator that aggregates elements from each of the iterables. If the iterables are of uneven length, missing values are filled in with fillvalue. Iteration continues until the longest iterable is exhausted. Roughly equivalent to:

    def zip_longest(*args, fillvalue=None):
        # zip_longest('ABCD', 'xy', fillvalue='-') --> Ax By C- D-
        iterators = [iter(it) for it in args]
        num_active = len(iterators)
        if not num_active:
            return
        while True:
            values = []
            for i, it in enumerate(iterators):
                try:
                    value = next(it)
                except StopIteration:
                    num_active -= 1
                    if not num_active:
                        return
                    # Replace the exhausted input with an endless supply of fillvalue.
                    iterators[i] = repeat(fillvalue)
                    value = fillvalue
                values.append(value)
            yield tuple(values)

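If one of the iterables is potentially infinite, then the zip_longest() function should be wrapped with something that limits the number of calls (for example islice() or takewhile()). If not specified, fillvalue defaults to None.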
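A minimal sketch of the wrapping described above: count() never ends, so the combined stream is truncated with islice(). The limit of four tuples and the fill value '-' are arbitrary choices for the example.

```python
from itertools import count, islice, zip_longest

labels = 'ab'

# count() is infinite, so zip_longest() alone would never finish.
pairs = list(islice(zip_longest(count(), labels, fillvalue='-'), 4))

print(pairs)  # [(0, 'a'), (1, 'b'), (2, '-'), (3, '-')]
```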

Itertools Recipes

This section shows recipes for creating an extended toolset using the existing itertools as building blocks.

The primary purpose of the itertools recipes is educational. The recipes show various ways of thinking about individual tools — for example, that chain.from_iterable is related to the concept of flattening. The recipes also give ideas about ways that the tools can be combined — for example, how compress() and range() can work together. The recipes also show patterns for using itertools with the operator and collections modules as well as with the built-in itertools such as map(), filter(), reversed(), and enumerate().

A secondary purpose of the recipes is to serve as an incubator. The accumulate(), compress(), and pairwise() itertools started out as recipes. Currently, the iter_index() recipe is being tested to see whether it proves its worth.

Substantially all of these recipes and many, many others can be installed from the more-itertools project found on the Python Package Index:

    python -m pip install more-itertools

Many of the recipes offer the same high performance as the underlying toolset. Superior memory performance is kept by processing elements one at a time rather than bringing the whole iterable into memory all at once. Code volume is kept small by linking the tools together in a functional style which helps eliminate temporary variables. High speed is retained by preferring “vectorized” building blocks over the use of for-loops and generators which incur interpreter overhead.

    import collections
    import math
    import operator
    from itertools import (chain, combinations, compress, count, cycle,
                           filterfalse, groupby, islice, pairwise, repeat,
                           starmap, tee, zip_longest)

    def take(n, iterable):
        "Return first n items of the iterable as a list"
        return list(islice(iterable, n))

    def prepend(value, iterator):
        "Prepend a single value in front of an iterator"
        # prepend(1, [2, 3, 4]) --> 1 2 3 4
        return chain([value], iterator)

    def tabulate(function, start=0):
        "Return function(0), function(1), ..."
        return map(function, count(start))

    def tail(n, iterable):
        "Return an iterator over the last n items"
        # tail(3, 'ABCDEFG') --> E F G
        return iter(collections.deque(iterable, maxlen=n))

    def consume(iterator, n=None):
        "Advance the iterator n-steps ahead. If n is None, consume entirely."
        # Use functions that consume iterators at C speed.
        if n is None:
            # feed the entire iterator into a zero-length deque
            collections.deque(iterator, maxlen=0)
        else:
            # advance to the empty slice starting at position n
            next(islice(iterator, n, n), None)

    def nth(iterable, n, default=None):
        "Returns the nth item or a default value"
        return next(islice(iterable, n, None), default)

    def all_equal(iterable):
        "Returns True if all the elements are equal to each other"
        g = groupby(iterable)
        return next(g, True) and not next(g, False)

    def quantify(iterable, pred=bool):
        "Count how many times the predicate is true"
        return sum(map(pred, iterable))

    def pad_none(iterable):
        "Returns the sequence elements and then returns None indefinitely."
        return chain(iterable, repeat(None))

    def ncycles(iterable, n):
        "Returns the sequence elements n times"
        return chain.from_iterable(repeat(tuple(iterable), n))

    def dotproduct(vec1, vec2):
        return sum(map(operator.mul, vec1, vec2))

    def convolve(signal, kernel):
        # See:  https://betterexplained.com/articles/intuitive-convolution/
        # convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur)
        # convolve(data, [1, -1]) --> 1st finite difference (1st derivative)
        # convolve(data, [1, -2, 1]) --> 2nd finite difference (2nd derivative)
        kernel = tuple(kernel)[::-1]
        n = len(kernel)
        window = collections.deque([0], maxlen=n) * n
        for x in chain(signal, repeat(0, n-1)):
            window.append(x)
            yield sum(map(operator.mul, kernel, window))

    def polynomial_from_roots(roots):
        """Compute a polynomial's coefficients from its roots.

           (x - 5) (x + 4) (x - 3)  expands to:   x³ -4x² -17x + 60
        """
        # polynomial_from_roots([5, -4, 3]) --> [1, -4, -17, 60]
        roots = list(map(operator.neg, roots))
        return [
            sum(map(math.prod, combinations(roots, k)))
            for k in range(len(roots) + 1)
        ]

    def iter_index(iterable, value, start=0):
        "Return indices where a value occurs in a sequence or iterable."
        # iter_index('AABCADEAF', 'A') --> 0 1 4 7
        try:
            seq_index = iterable.index
        except AttributeError:
            # Slow path for general iterables
            it = islice(iterable, start, None)
            for i, element in enumerate(it, start):
                if element is value or element == value:
                    yield i
        else:
            # Fast path for sequences
            i = start - 1
            try:
                while True:
                    yield (i := seq_index(value, i+1))
            except ValueError:
                pass

    def sieve(n):
        "Primes less than n"
        # sieve(30) --> 2 3 5 7 11 13 17 19 23 29
        data = bytearray((0, 1)) * (n // 2)
        data[:3] = 0, 0, 0
        limit = math.isqrt(n) + 1
        for p in compress(range(limit), data):
            data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
        data[2] = 1
        return iter_index(data, 1) if n > 2 else iter([])

    def flatten(list_of_lists):
        "Flatten one level of nesting"
        return chain.from_iterable(list_of_lists)

    def repeatfunc(func, times=None, *args):
        """Repeat calls to func with specified arguments.

        Example:  repeatfunc(random.random)
        """
        if times is None:
            return starmap(func, repeat(args))
        return starmap(func, repeat(args, times))

    def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
        "Collect data into non-overlapping fixed-length chunks or blocks"
        # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
        # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
        # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
        args = [iter(iterable)] * n
        if incomplete == 'fill':
            return zip_longest(*args, fillvalue=fillvalue)
        if incomplete == 'strict':
            return zip(*args, strict=True)
        if incomplete == 'ignore':
            return zip(*args)
        else:
            raise ValueError('Expected fill, strict, or ignore')

    def batched(iterable, n):
        "Batch data into lists of length n. The last batch may be shorter."
        # batched('ABCDEFG', 3) --> ABC DEF G
        if n < 1:
            raise ValueError('n must be at least one')
        it = iter(iterable)
        while (batch := list(islice(it, n))):
            yield batch

    def triplewise(iterable):
        "Return overlapping triplets from an iterable"
        # triplewise('ABCDEFG') --> ABC BCD CDE DEF EFG
        for (a, _), (b, c) in pairwise(pairwise(iterable)):
            yield a, b, c

    def sliding_window(iterable, n):
        # sliding_window('ABCDEFG', 4) --> ABCD BCDE CDEF DEFG
        it = iter(iterable)
        window = collections.deque(islice(it, n), maxlen=n)
        if len(window) == n:
            yield tuple(window)
        for x in it:
            window.append(x)
            yield tuple(window)

    def roundrobin(*iterables):
        "roundrobin('ABC', 'D', 'EF') --> A D E B F C"
        # Recipe credited to George Sakkis
        num_active = len(iterables)
        nexts = cycle(iter(it).__next__ for it in iterables)
        while num_active:
            try:
                for next in nexts:
                    yield next()
            except StopIteration:
                # Remove the iterator we just exhausted from the cycle.
                num_active -= 1
                nexts = cycle(islice(nexts, num_active))

    def partition(pred, iterable):
        "Use a predicate to partition entries into false entries and true entries"
        # partition(is_odd, range(10)) --> 0 2 4 6 8   and  1 3 5 7 9
        t1, t2 = tee(iterable)
        return filterfalse(pred, t1), filter(pred, t2)

    def before_and_after(predicate, it):
        """ Variant of takewhile() that allows complete
            access to the remainder of the iterator.

            >>> it = iter('ABCdEfGhI')
            >>> all_upper, remainder = before_and_after(str.isupper, it)
            >>> ''.join(all_upper)
            'ABC'
            >>> ''.join(remainder)     # takewhile() would lose the 'd'
            'dEfGhI'

            Note that the first iterator must be fully
            consumed before the second iterator can
            generate valid results.
        """
        it = iter(it)
        transition = []
        def true_iterator():
            for elem in it:
                if predicate(elem):
                    yield elem
                else:
                    transition.append(elem)
                    return
        def remainder_iterator():
            yield from transition
            yield from it
        return true_iterator(), remainder_iterator()

    def subslices(seq):
        "Return all contiguous non-empty subslices of a sequence"
        # subslices('ABCD') --> A AB ABC ABCD B BC BCD C CD D
        slices = starmap(slice, combinations(range(len(seq) + 1), 2))
        return map(operator.getitem, repeat(seq), slices)

    def powerset(iterable):
        "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
        s = list(iterable)
        return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

    def unique_everseen(iterable, key=None):
        "List unique elements, preserving order. Remember all elements ever seen."
        # unique_everseen('AAAABBBCCDAABBB') --> A B C D
        # unique_everseen('ABBCcAD', str.lower) --> A B C D
        seen = set()
        if key is None:
            for element in filterfalse(seen.__contains__, iterable):
                seen.add(element)
                yield element
            # Note: The steps shown above are intended to demonstrate
            # filterfalse(). For order preserving deduplication,
            # a better solution is:
            #     yield from dict.fromkeys(iterable)
        else:
            for element in iterable:
                k = key(element)
                if k not in seen:
                    seen.add(k)
                    yield element

    def unique_justseen(iterable, key=None):
        "List unique elements, preserving order. Remember only the element just seen."
        # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
        # unique_justseen('ABBCcAD', str.lower) --> A B C A D
        return map(next, map(operator.itemgetter(1), groupby(iterable, key)))

    def iter_except(func, exception, first=None):
        """ Call a function repeatedly until an exception is raised.

        Converts a call-until-exception interface to an iterator interface.
        Like builtins.iter(func, sentinel) but uses an exception instead
        of a sentinel to end the loop.

        Examples:
            iter_except(functools.partial(heappop, h), IndexError)   # priority queue iterator
            iter_except(d.popitem, KeyError)                         # non-blocking dict iterator
            iter_except(d.popleft, IndexError)                       # non-blocking deque iterator
            iter_except(q.get_nowait, queue.Empty)                   # loop over a producer Queue
            iter_except(s.pop, KeyError)                             # non-blocking set iterator

        """
        try:
            if first is not None:
                yield first()            # For database APIs needing an initial cast to db.first()
            while True:
                yield func()
        except exception:
            pass

    def first_true(iterable, default=False, pred=None):
        """Returns the first true value in the iterable.

        If no true value is found, returns *default*

        If *pred* is not None, returns the first item
        for which pred(item) is true.

        """
        # first_true([a,b,c], x) --> a or b or c or x
        # first_true([a,b], x, f) --> a if f(a) else b if f(b) else x
        return next(filter(pred, iterable), default)

    def nth_combination(iterable, r, index):
        "Equivalent to list(combinations(iterable, r))[index]"
        pool = tuple(iterable)
        n = len(pool)
        c = math.comb(n, r)
        if index < 0:
            index += c
        if index < 0 or index >= c:
            raise IndexError
        result = []
        while r:
            c, n, r = c*r//n, n-1, r-1
            while index >= c:
                index -= c
                c, n = c*(n-r)//n, n-1
            result.append(pool[-1-n])
        return tuple(result)
  259. return tuple(result)