Source code for les_iterables.grouping

from collections import deque
from collections.abc import Sequence
from itertools import tee, zip_longest

def partition_tail(items, n):
    """Lazily split an iterable series into a head and a bounded-length tail.

    Args:
        items: An iterable series of items.
        n: The maximum number of items to be placed in the tail.

    Returns:
        A 2-tuple of iterators ``(head, tail)``. Pulling any item from the
        tail iterator first drains the whole head iterator, so the head
        should normally be consumed before the tail is touched.

    Example:
        To separate the last three items of an iterable series::

            head, tail = partition_tail(range(10), 3)
            for item in head:
                print(item)  # Prints all but the last three
            for item in tail:
                print(item)  # Prints the last three

    Raises:
        ValueError: If n is negative.
    """
    shared = PartitionedTail(items, n)
    # Both iterators register themselves on the same shared state object.
    return HeadPartitionIterator(shared), TailPartitionIterator(shared)
class PartitionedTail:
    """Shared state used by a head/tail pair of partition iterators.

    Stores the source iterator, the tail length limit and an initially empty
    buffer deque. The head and tail iterator slots start out as None.
    """

    def __init__(self, items, n):
        """Set up the shared partition state.

        Args:
            items: An iterable series of items.
            n: The maximum number of items destined for the tail.

        Raises:
            ValueError: If n is negative.
        """
        if n < 0:
            raise ValueError(f"Cannot partition negative number ({n}) items into the tail")
        self._i = iter(items)
        self._n = n
        self._d = deque()
        # Filled in by the iterator objects when they are constructed.
        self._head_iterator = self._tail_iterator = None
class HeadPartitionIterator:
    """Iterator yielding the items that precede the buffered tail.

    On construction it registers itself as ``_head_iterator`` on the shared
    partition state it is given.
    """

    def __init__(self, partition_tail):
        self._pt = partition_tail
        partition_tail._head_iterator = self

    def __iter__(self):
        return self

    def __next__(self):
        state = self._pt
        window = state._d
        # Fill the look-ahead window up to the tail size before yielding.
        while len(window) < state._n:
            window.append(next(state._i))
        assert len(window) == state._n
        # If the source is exhausted, StopIteration propagates from here and
        # whatever is left in the window is the tail.
        window.append(next(state._i))
        oldest = window.popleft()
        assert len(window) == state._n
        return oldest
class TailPartitionIterator:
    """Iterator yielding the items left in the shared tail buffer.

    On construction it registers itself as ``_tail_iterator`` on the shared
    partition state. The first call to ``next`` exhausts the head iterator
    before any tail item is returned.
    """

    def __init__(self, partition_tail):
        self._consumed_head = False
        self._pt = partition_tail
        partition_tail._tail_iterator = self

    def __iter__(self):
        return self

    def __next__(self):
        state = self._pt
        if not self._consumed_head:
            # A zero-length deque discards everything it is fed, so this
            # simply runs the head iterator to exhaustion.
            deque(state._head_iterator, maxlen=0)
            self._consumed_head = True
        if not state._d:
            raise StopIteration
        return state._d.popleft()
def split_around(iterable, predicate, group_factory=None):
    """Split an iterable series into groups around specific items.

    Every item for which the predicate returns True is placed in a group of
    its own; runs of other items between them form the remaining groups.

    Args:
        iterable: An iterable series of items to be grouped.
        predicate: A unary callable to detect items which should be placed
            in their own group.
        group_factory: A callable which creates a group given a list of
            items. When omitted, each group is the list itself.

    Yields:
        A series of groups.
    """
    make_group = (lambda items: items) if group_factory is None else group_factory
    pending = []
    for item in iterable:
        if not predicate(item):
            pending.append(item)
            continue
        # Flush whatever was collected before the matching item.
        if pending:
            yield make_group(pending)
        yield make_group([item])
        pending = []
    if pending:
        yield make_group(pending)
def group_by_terminator(iterable, predicate, group_factory=None):
    """Group the items of an iterable series, closing a group at each terminator.

    Each complete group ends with an item for which the predicate returns
    True; the predicate returns False for every earlier item in that group.
    The final group may be incomplete, i.e. lack a terminator.

    Args:
        iterable: An iterable series of items to be grouped.
        predicate: A unary callable used to detect group-terminating items.
        group_factory: A callable which creates a group given a list of
            items. When omitted, each group is the list itself.

    Yields:
        A series of groups.
    """
    make_group = (lambda items: items) if group_factory is None else group_factory
    current = []
    for item in iterable:
        current.append(item)
        if not predicate(item):
            continue
        yield make_group(current)
        current = []
    # Emit trailing items that were never followed by a terminator.
    if current:
        yield make_group(current)
def pairwise_padded(iterable, fillvalue=None):
    """Pair each item of an iterable series with its successor.

    The number of pairs produced equals the number of items.

    Args:
        iterable: An iterable series of items to be grouped into pairs.
        fillvalue: The value used as the successor to the last item.

    Returns:
        An iterator of 2-tuples, each containing an item and its successor.
        For the last item the successor is the fillvalue.
    """
    items, successors = tee(iterable)
    # Shift the second stream forward by one; the default keeps an empty
    # input from raising StopIteration here.
    next(successors, fillvalue)
    return zip_longest(items, successors, fillvalue=fillvalue)
def split_after_first(iterable, predicate, group_factory=None):
    """Split the iterable after the first element matching the predicate.

    Always produces at least one group and never more than two. If no
    element matches the predicate, a single group holding every item is
    produced. If the iterable is empty, a single empty group is produced.

    Examples: ::

        group, *groups = split_after_first([1, 2, 3, 1, 2, 3], lambda x: x == 2)
        assert group == [1, 2]
        assert groups == [[3, 1, 2, 3]]

        group, *groups = split_after_first([1, 2, 3], lambda x: x == 3)
        assert group == [1, 2, 3]
        assert groups == []

        group, *groups = split_after_first('abcde', lambda x: x == 'c')
        assert group == 'abc'
        assert groups == ['de']

    Args:
        iterable: An iterable series of items to be split.
        predicate: A unary callable used to detect the item after which
            the split occurs.
        group_factory: An optional callable which creates a group given a
            list of items. When omitted, one is chosen from the type of
            the iterable.

    Yields:
        One or two groups.
    """
    make_group = _make_group_factory(iterable, group_factory)
    source = iter(iterable)
    leading = []
    for item in source:
        leading.append(item)
        if predicate(item):
            break
    yield make_group(leading)
    # Whatever the loop did not consume forms the optional second group.
    trailing = list(source)
    if trailing:
        yield make_group(trailing)
def partition(iterable, predicate, group_factory=None):
    """Split an iterable series around the first item matching the predicate.

    Args:
        iterable: An iterable series of items to be split.
        predicate: A unary callable used to detect the separating item.
        group_factory: An optional callable which creates a group given a
            list of items. When omitted, one is chosen from the type of
            the iterable.

    Returns:
        A 3-tuple of groups ``(before, separator, after)``: the items
        preceding the first match, a group holding only the matching item,
        and the items following it. If nothing matches, ``before`` holds
        every item and the other two groups are built from empty lists.
    """
    make_group = _make_group_factory(iterable, group_factory)
    source = iter(iterable)
    before, separator = [], []
    for item in source:
        if predicate(item):
            separator.append(item)
            break
        before.append(item)
    # The loop stops at the first match, so the rest of the iterator is "after".
    after = list(source)
    return make_group(before), make_group(separator), make_group(after)
def _make_group_factory(iterable, group_factory=None):
    """Choose the callable used to turn a list of items into a group.

    Args:
        iterable: The iterable whose type determines the default factory.
        group_factory: An explicit factory. If not None it is returned
            unchanged.

    Returns:
        A callable accepting a list of items:

        * the given ``group_factory`` when one is supplied;
        * ``''.join`` when the iterable is a ``str``;
        * the iterable's own type when it is any other ``Sequence``
          (for example ``tuple`` or ``bytes``);
        * otherwise an identity function, so groups stay plain lists.

        ``list`` inputs use the identity function to avoid a needless copy.
        ``range`` and ``memoryview`` also use it, because although they are
        Sequences their constructors cannot be called with a list of items
        (doing so raises TypeError).
    """
    if group_factory is not None:
        return group_factory
    if isinstance(iterable, str):
        return ''.join
    # Sequence types that must not be rebuilt via type(iterable)(items).
    keep_as_list = (list, range, memoryview)
    if isinstance(iterable, Sequence) and not isinstance(iterable, keep_as_list):
        return type(iterable)
    return lambda items: items