用python实现基本数据结构和算法

1章:ADT抽象数据类型,定义数据和其操作

什么是ADT: 抽象数据类型,学过数据结构的应该都知道。

How to select datastructures for ADT

  1. Does the data structure provide for the storage requirements as specified by the domain of the ADT?
  2. Does the data structure provide the data access and manipulation functionality to fully implement the ADT?
  3. Does the data structure lend itself to an efficient implementation, based on complexity analysis?

下边代码是个简单的示例,比如实现一个简单的Bag类,先定义其具有的操作,然后我们再用类的magic method来实现这些方法:

  class Bag:
      """A simple unordered container backed by a Python list.

      Supported operations:
          Bag()          create an empty bag
          len(bag)       number of stored items
          item in bag    membership test
          add(item)      store one more item (duplicates are allowed)
          remove(item)   drop one occurrence of an item that is present
          iter(bag)      walk the stored items in insertion order
      """

      def __init__(self):
          self._items = []

      def __len__(self):
          return len(self._items)

      def __contains__(self, item):
          return item in self._items

      def add(self, item):
          """Append ``item`` to the bag."""
          self._items.append(item)

      def remove(self, item):
          """Remove one occurrence of ``item``; the item must be in the bag."""
          assert item in self._items, 'item must in the bag'
          # list.remove() returns None, which is also what this method returns.
          outcome = self._items.remove(item)
          return outcome

      def __iter__(self):
          # Iteration is delegated to the dedicated iterator class.
          return _BagIterator(self._items)
  class _BagIterator:
      """Iterator that walks a sequence from index 0 to its end."""

      def __init__(self, seq):
          self._bag_items = seq
          self._cur_item = 0  # index of the next item to hand out

      def __iter__(self):
          return self

      def __next__(self):
          position = self._cur_item
          if position >= len(self._bag_items):
              raise StopIteration
          current = self._bag_items[position]
          self._cur_item = position + 1
          return current
  # Demo: fill a bag and iterate over it.
  b = Bag()
  for i in (1, 2):
      b.add(i)

  # The for statement builds an iterator with __iter__ and advances it with
  # __next__ until StopIteration is raised.
  for i in b:
      print(i)

  # The for loop above is equivalent to:
  #
  #     i = b.__iter__()
  #     while True:
  #         try:
  #             item = i.__next__()
  #             print(item)
  #         except StopIteration:
  #             break

2章:array vs list

array: 定长,操作有限,但是节省内存;貌似我的生涯中还没用过,不过python3.5中我试了确实有array类,可以用import array直接导入

list: 会预先分配内存,操作丰富,但是耗费内存。我用sys.getsizeof做了实验。我个人理解很类似C++ STL里的vector,是使用最频繁的数据结构。

  • list.append: 如果之前没有分配够内存,会重新开辟新区域,然后复制之前的数据,复杂度退化
  • list.insert: 会移动被插入区域后所有元素,O(n)
  • list.pop: pop不同位置需要的复杂度不同,pop()弹出末尾元素是O(1)复杂度,pop(0)弹出首位元素需要移动其后的所有元素,是O(n)复杂度
  • list[]: slice操作copy数据(预留空间)到另一个list

来实现一个array的ADT:

  1. import ctypes
  class Array:
      """Fixed-size one-dimensional array stored in a ctypes py_object array.

      Every slot is initialised to None. Indices must satisfy
      0 <= index < len(array); anything else trips an assertion.
      """

      def __init__(self, size):
          assert size > 0, 'array size must be > 0'
          self._size = size
          # Build the ctypes array type for `size` object references and
          # instantiate it in one step.
          self._elements = (ctypes.py_object * size)()
          self.clear(None)

      def __len__(self):
          return self._size

      def __getitem__(self, index):
          assert 0 <= index < len(self), 'out of range'
          return self._elements[index]

      def __setitem__(self, index, value):
          assert 0 <= index < len(self), 'out of range'
          self._elements[index] = value

      def clear(self, value):
          """Set every element to ``value``."""
          for position in range(len(self)):
              self._elements[position] = value

      def __iter__(self):
          return _ArrayIterator(self._elements)
  class _ArrayIterator:
      """Iterator over an indexable, sized container.

      Yields ``items[0]``, ``items[1]``, ... up to ``len(items) - 1``.
      """

      def __init__(self, items):
          self._items = items
          self._idx = 0  # index of the next element to return

      def __iter__(self):
          return self

      def __next__(self):
          # The attribute was misspelled as `_idex` here, which raised
          # AttributeError on the first call.
          if self._idx < len(self._items):
              val = self._items[self._idx]
              self._idx += 1
              return val
          raise StopIteration

Two-Dimensional Arrays

  class Array2D:
      """Two-dimensional array implemented as an array of row arrays.

      Operations:
          Array2D(nrows, ncols)   constructor
          numRows                 number of rows (property)
          numCols                 number of columns (property)
          clear(value)            set every element to ``value``
          a[i, j]                 read an element
          a[i, j] = val           write an element
      """

      def __init__(self, numrows, numcols):
          self._the_rows = Array(numrows)  # array of arrays
          for i in range(numrows):
              self._the_rows[i] = Array(numcols)

      @property
      def numRows(self):
          return len(self._the_rows)

      @property
      def numCols(self):
          return len(self._the_rows[0])

      # The property was originally spelled `NumCols`; keep that spelling
      # working for existing callers.
      NumCols = numCols

      def clear(self, value):
          """Set every element of every row to ``value``."""
          for row in self._the_rows:
              row.clear(value)

      def _split_index(self, ndx_tuple):
          """Validate an ``(row, col)`` subscript and return its two parts."""
          assert len(ndx_tuple) == 2
          row, col = ndx_tuple[0], ndx_tuple[1]
          assert (0 <= row < self.numRows and
                  0 <= col < self.numCols)
          return row, col

      def __getitem__(self, ndx_tuple):  # ndx_tuple: (row, col)
          row, col = self._split_index(ndx_tuple)
          return self._the_rows[row][col]

      def __setitem__(self, ndx_tuple, value):
          row, col = self._split_index(ndx_tuple)
          self._the_rows[row][col] = value

The Matrix ADT, m行,n列。这个最好用还是用pandas处理矩阵,自己实现比较*疼

  class Matrix:
      """Matrix of numbers stored in an Array2D (pandas is the practical choice).

      Implemented operations:
          Matrix(rows, ncols)   constructor, all elements start at 0
          numRows, numCols      dimensions (properties)
          m[row, col]           read an element
          m[row, col] = val     write an element
          scaleBy(scalar)       multiply every element by ``scalar`` in place
          m + rhsMatrix         element-wise sum; sizes must match

      transpose, subtract and multiply belong to the ADT but are not
      implemented in this class.
      """

      def __init__(self, numRows, numCols):
          self._theGrid = Array2D(numRows, numCols)
          self._theGrid.clear(0)

      @property
      def numRows(self):
          return self._theGrid.numRows

      @property
      def numCols(self):
          # Array2D spells its column-count property `NumCols`.
          return self._theGrid.NumCols

      # Original (capitalised) spelling kept as an alias; every method of
      # this class uses `numCols`, which previously did not exist.
      NumCols = numCols

      def __getitem__(self, ndxTuple):
          return self._theGrid[ndxTuple[0], ndxTuple[1]]

      def __setitem__(self, ndxTuple, scalar):
          self._theGrid[ndxTuple[0], ndxTuple[1]] = scalar

      def scaleBy(self, scalar):
          """Multiply every element by ``scalar`` in place."""
          for r in range(self.numRows):
              for c in range(self.numCols):
                  self[r, c] *= scalar

      def __add__(self, rhsMatrix):
          """Return a new Matrix holding the element-wise sum."""
          assert (rhsMatrix.numRows == self.numRows and
                  rhsMatrix.numCols == self.numCols)
          newMatrix = Matrix(self.numRows, self.numCols)
          for r in range(self.numRows):
              for c in range(self.numCols):
                  newMatrix[r, c] = self[r, c] + rhsMatrix[r, c]
          # The result was built but never returned before.
          return newMatrix

3章:Sets and Maps

除了list之外,最常用的应该就是python内置的set和dict了。

sets ADT

A set is a container that stores a collection of unique values over a given comparable domain in which the stored values have no particular ordering.

  class Set:
      """Set ADT implemented on top of a Python list.

      Operations:
          Set()               create an empty set
          len(s)              number of elements
          element in s        membership test
          add(element)        insert if not already present
          remove(element)     delete an element that must be present
          s == setB           True when both hold the same elements
          isSubsetOf(setB)    True when every element is also in setB
          union(setB)         new set with the elements of both
          intersect(setB)     new set with the elements common to both
          difference(setB)    new set with the elements not in setB
          iter(s)             iterate over the elements
      """

      def __init__(self):
          self._theElements = list()

      def __len__(self):
          return len(self._theElements)

      def __contains__(self, element):
          return element in self._theElements

      def __iter__(self):
          # Needed by isSubsetOf/union/__eq__, which loop over sets; without
          # it `for element in self` raised TypeError.
          return iter(self._theElements)

      def add(self, element):
          if element not in self:
              self._theElements.append(element)

      def remove(self, element):
          assert element in self, 'The element must be in the set'
          self._theElements.remove(element)

      def __eq__(self, setB):
          # Equal sizes plus the subset relation imply equality.
          if len(self) != len(setB):
              return False
          return self.isSubsetOf(setB)

      def isSubsetOf(self, setB):
          return all(element in setB for element in self)

      def union(self, setB):
          newSet = Set()
          newSet._theElements.extend(self._theElements)
          for element in setB:
              newSet.add(element)  # add() skips anything already present
          return newSet

      def intersect(self, setB):
          newSet = Set()
          for element in self:
              if element in setB:
                  newSet._theElements.append(element)
          return newSet

      def difference(self, setB):
          newSet = Set()
          for element in self:
              if element not in setB:
                  newSet._theElements.append(element)
          return newSet

Maps or Dict: 键值对,python内部采用hash实现。

  class Map:
      """Map ADT implemented with a list of key/value entries.

      Operations:
          Map()              create an empty map
          len(m)             number of entries
          key in m           membership test on keys
          add(key, value)    insert or overwrite; True if the key was new
          remove(key)        delete an existing key
          valueOf(key)       value stored for an existing key
          iter(m)            iterate over the keys in insertion order
      """

      def __init__(self):
          self._entryList = list()

      def __len__(self):
          return len(self._entryList)

      def __contains__(self, key):
          return self._findPosition(key) is not None

      def add(self, key, value):
          ndx = self._findPosition(key)
          if ndx is not None:
              # Existing key: overwrite its value in place.
              self._entryList[ndx].value = value
              return False
          self._entryList.append(_MapEntry(key, value))
          return True

      def valueOf(self, key):
          ndx = self._findPosition(key)
          assert ndx is not None, 'Invalid map key'
          return self._entryList[ndx].value

      def remove(self, key):
          ndx = self._findPosition(key)
          assert ndx is not None, 'Invalid map key'
          self._entryList.pop(ndx)

      def __iter__(self):
          # The original returned `_MapIterator`, a class that is not defined;
          # a generator over the stored keys fills that role.
          return (entry.key for entry in self._entryList)

      def _findPosition(self, key):
          """Return the index of the entry holding ``key``, or None."""
          for i, entry in enumerate(self._entryList):
              if entry.key == key:
                  return i
          return None
  class _MapEntry:
      """One key/value pair stored by Map.

      collections.namedtuple('_MapEntry', 'key,value') would be an alternative
      when entries never need to be mutated.
      """

      def __init__(self, key, value):
          self.key, self.value = key, value

The multiArray ADT, 多维数组,一般是使用一个一维数组模拟,然后通过计算下标获取元素

  class MultiArray:
      """Multi-dimensional array stored in a 1-D array, row-major order.

      Operations:
          MultiArray(d1, d2, ..., dn)   constructor, n >= 2
          numDims                       number of dimensions (property)
          length(dim)                   size of dimension ``dim`` (1-based)
          clear(value)                  set every element to ``value``
          a[i1, i2, ..., in]            read an element
          a[i1, i2, ..., in] = val      write an element

      Index computation:
          index(i1, ..., in) = i1*f1 + i2*f2 + ... + i(n-1)*f(n-1) + in*1
      where fj is the product of the dimensions after the j-th one, e.g. for
      three dimensions index(i1, i2, i3) = i1*(d2*d3) + i2*d3 + i3.
      """

      def __init__(self, *dimensions):
          assert len(dimensions) > 1, 'The array must have 2 or more dimensions'
          self._dims = dimensions
          # Total number of elements is the product of all dimensions.
          size = 1
          for d in dimensions:
              assert d > 0, 'Dimensions must be > 0'
              size *= d
          # 1-D array holding the elements.
          self._elements = Array(size)
          # 1-D array holding the factors of the index equation.
          self._factors = Array(len(dimensions))
          self._computeFactors()

      @property
      def numDims(self):
          return len(self._dims)

      def length(self, dim):
          # `dim` is 1-based, so the last dimension (dim == numDims) is valid;
          # the original `dim < len(...)` check rejected it.
          assert 1 <= dim <= len(self._dims), 'Dimension component out of range'
          return self._dims[dim - 1]

      def clear(self, value):
          self._elements.clear(value)

      def __getitem__(self, ndxTuple):
          assert len(ndxTuple) == self.numDims, 'Invalid # of array subscripts'
          index = self._computeIndex(ndxTuple)
          assert index is not None, 'Array subscript out of range'
          return self._elements[index]

      def __setitem__(self, ndxTuple, value):
          assert len(ndxTuple) == self.numDims, 'Invalid # of array subscripts'
          index = self._computeIndex(ndxTuple)
          assert index is not None, 'Array subscript out of range'
          self._elements[index] = value

      def _computeIndex(self, ndxTuple):
          """Return the 1-D offset for ``ndxTuple``, or None if out of range."""
          offset = 0
          for j, ndx in enumerate(ndxTuple):
              if ndx < 0 or ndx >= self._dims[j]:
                  return None
              offset += ndx * self._factors[j]
          return offset

      def _computeFactors(self):
          """Fill ``_factors`` so that factor j is the product of dims j+1..n."""
          factor = 1
          for j in reversed(range(len(self._dims))):
              self._factors[j] = factor
              factor *= self._dims[j]

4章:Algorithm Analysis

一般使用大O标记法来衡量算法的平均时间复杂度, 1 < log(n) < n < nlog(n) < n^2 < n^3 < a^n。 了解常用数据结构操作的平均时间复杂度有利于使用更高效的数据结构,当然有时候需要在时间和空间上进行衡量,有些操作甚至还会退化,比如list的append操作,如果list空间不够,会去开辟新的空间,操作复杂度退化到O(n),有时候还需要使用均摊分析(amortized)


5章:Searching and Sorting

排序和查找是最基础和频繁的操作,python内置了in操作符和bisect二分操作模块实现查找,内置了sorted方法来实现排序操作。二分和快排也是面试中经常考到的,本章讲的是基本的排序和查找。

  def binary_search(sorted_seq, val):
      """Return the leftmost insertion point for ``val`` in ``sorted_seq``.

      Mirrors bisect.bisect_left from the standard library: the result ``i``
      satisfies ``all(x < val for x in sorted_seq[:i])`` and
      ``all(x >= val for x in sorted_seq[i:])``. When ``val`` is present the
      result is the index of its first occurrence. The earlier version returned
      whichever matching index the probe hit first, which is not the leftmost
      one when the sequence holds duplicates.
      """
      low, high = 0, len(sorted_seq)
      while low < high:
          mid = (low + high) // 2
          if sorted_seq[mid] < val:
              low = mid + 1   # everything up to mid is smaller than val
          else:
              high = mid      # mid may still be the first element >= val
      return low
  def bubble_sort(seq):
      """Sort ``seq`` in place with bubble sort.

      O(n^2): the passes perform (n-1) + (n-2) + ... + 1 = n(n-1)/2 comparisons.
      """
      last = len(seq) - 1
      # After each pass the largest remaining element has bubbled up to
      # position `last`, so the next pass can stop one position earlier.
      while last > 0:
          for left in range(last):
              right = left + 1
              if seq[left] > seq[right]:
                  seq[left], seq[right] = seq[right], seq[left]
          last -= 1
  def select_sort(seq):
      """Sort ``seq`` in place with selection sort.

      Can be seen as a refinement of bubble sort: each round locates the
      smallest remaining element and performs at most one swap.
      """
      size = len(seq)
      for start in range(size - 1):
          # Index of the first smallest element in seq[start:].
          smallest = min(range(start, size), key=lambda k: seq[k])
          if smallest != start:
              seq[start], seq[smallest] = seq[smallest], seq[start]
  def insertion_sort(seq):
      """Sort ``seq`` in place with insertion sort.

      The prefix seq[:boundary] is always sorted (initially one element); each
      round takes the next element and inserts it into that prefix.
      """
      for boundary in range(1, len(seq)):
          pending = seq[boundary]  # element to be positioned
          hole = boundary
          # Slide larger elements one step right until the slot for
          # `pending` is found.
          while hole > 0 and pending < seq[hole - 1]:
              seq[hole] = seq[hole - 1]
              hole -= 1
          seq[hole] = pending
  def merge_sorted_list(listA, listB):
      """Merge two sorted sequences into a new sorted list.

      The merge is stable: when elements compare equal, the one from ``listA``
      is emitted first (the previous version emitted the ``listB`` element
      first). Neither input is modified.
      """
      new_list = list()
      a = b = 0
      while a < len(listA) and b < len(listB):
          # Take from listB only when it is strictly smaller, so ties keep the
          # listA-before-listB order.
          if listB[b] < listA[a]:
              new_list.append(listB[b])
              b += 1
          else:
              new_list.append(listA[a])
              a += 1
      # At most one of the inputs still has elements left; copy them over.
      new_list.extend(listA[a:])
      new_list.extend(listB[b:])
      return new_list

6章: Linked Structure

list是最常用的数据结构,但是list在中间增减元素的时候效率会很低,这时候linked list会更适合,缺点就是获取元素的平均时间复杂度变成了O(n)

  # Singly linked list implementation.
  class ListNode:
      """Node of a singly linked list: a payload plus a link to the next node."""

      def __init__(self, data):
          self.data, self.next = data, None
  def travsersal(head, callback):
      """Call ``callback(node.data)`` for every node, starting at ``head``."""

      def _walk(start):
          # Yield each node in turn; the link is followed only after the
          # consumer has finished with the current node.
          while start is not None:
              yield start
              start = start.next

      for node in _walk(head):
          callback(node.data)
  def unorderdSearch(head, target):
      """Return True when some node reachable from ``head`` holds ``target``."""
      node = head
      while node is not None:
          if not (node.data != target):
              return True
          node = node.next
      return False
  # Given the head pointer, prepend an item to an unsorted linked list.
  def prepend(head, item):
      """Put a new node holding ``item`` in front of ``head``.

      Returns the new head node. Rebinding the local ``head`` (as the earlier
      version did) is invisible to the caller, so the caller must use the
      return value: ``head = prepend(head, item)``.
      """
      newNode = ListNode(item)
      newNode.next = head
      return newNode
  # Given the head reference, remove a target from a linked list.
  def remove(head, target):
      """Unlink the first node whose data equals ``target``.

      Returns the head of the resulting list, which differs from ``head`` only
      when the first node was removed (None if the list becomes empty). The
      list is unchanged when ``target`` is absent. Use as
      ``head = remove(head, target)``.
      """
      predNode = None
      curNode = head
      # Search for the target, remembering the predecessor.
      while curNode is not None and curNode.data != target:
          predNode = curNode
          curNode = curNode.next  # was `curNode.data`, which broke the walk
      if curNode is not None:
          if curNode is head:
              head = curNode.next
          else:
              predNode.next = curNode.next
      return head

7章:Stacks

栈也是计算机里用得比较多的数据结构,栈是一种后进先出的数据结构,可以理解为往一个桶里放盘子,先放进去的会被压在底下,拿盘子的时候,后放的会被先拿出来。

  class Stack:
      """Stack ADT built on a Python list (top of the stack = end of the list).

      Operations:
          Stack()      create an empty stack
          isEmpty()    True when the stack holds no items
          len(s)       number of items
          pop()        remove and return the top item; stack must not be empty
          peek()       return the top item without removing it; must not be empty
          push(item)   put ``item`` on top
      """

      def __init__(self):
          self._items = []

      def __len__(self):
          return len(self._items)

      def isEmpty(self):
          return not len(self)

      def push(self, item):
          self._items.append(item)

      def peek(self):
          assert not self.isEmpty()
          top_index = -1
          return self._items[top_index]

      def pop(self):
          assert not self.isEmpty()
          return self._items.pop()
  class Stack:
      """Stack ADT built on a singly linked list.

      A list-based stack is simple, but a push that outgrows the list's
      allocated space costs O(n). With a linked list every push and pop is
      O(1) even in the worst case.

      Operations: isEmpty(), len(s), peek(), pop(), push(item).
      """

      def __init__(self):
          self._top = None  # top node: a _StackNode, or None when empty
          self._size = 0

      def isEmpty(self):
          return self._top is None

      def __len__(self):
          return self._size

      def peek(self):
          """Return the top item without removing it; stack must not be empty."""
          assert not self.isEmpty()
          return self._top.item

      def pop(self):
          """Remove and return the top item; stack must not be empty."""
          assert not self.isEmpty()
          node = self._top
          # Was `self.top = ...`, which created a new attribute and left the
          # real top pointer untouched.
          self._top = self._top.next
          self._size -= 1
          return node.item

      def push(self, item):
          """Put ``item`` on top of the stack."""
          self._top = _StackNode(item, self._top)
          self._size += 1

      # The method was originally (mis)named `_push`; keep that name working.
      _push = push
  class _StackNode:
      """Linked-list node used by Stack: an item plus the node beneath it."""

      def __init__(self, item, link):
          self.item, self.next = item, link

8章:Queues

队列也是经常使用的数据结构,比如发送消息等,celery可以使用redis提供的list实现消息队列。 本章我们用list和linked list来实现队列和优先级队列。

  class Queue:
      """Queue ADT built on a Python list.

      Simple, but enqueue can cost O(n) when the list has to grow and dequeue
      is O(n) because ``pop(0)`` shifts every remaining element.

      Operations:
          Queue()         create an empty queue
          isEmpty()       True when the queue holds no items
          len(q)          number of items
          enqueue(item)   add ``item`` at the back
          dequeue()       remove and return the front item; must not be empty
      """

      def __init__(self):
          self._qList = list()

      def isEmpty(self):
          return len(self) == 0

      def __len__(self):
          return len(self._qList)

      def enqueue(self, item):
          self._qList.append(item)

      # The method was originally misspelled `enquue`; keep that name working.
      enquue = enqueue

      def dequeue(self):
          assert not self.isEmpty()
          return self._qList.pop(0)
  20. from array import Array # Array那一章实现的Array ADT
  class Queue:
      """Queue ADT built on a fixed-size circular array with front/back indices.

      list.append and list.pop(0) can degrade to O(n); a circular array keeps
      both enqueue and dequeue at O(1). The price is a fixed capacity.
      """

      def __init__(self, maxSize):
          self._count = 0              # number of stored items
          self._front = 0              # index of the next item to dequeue
          self._back = maxSize - 1     # index of the most recently added item
          self._qArray = Array(maxSize)

      def isEmpty(self):
          return self._count == 0

      def isFull(self):
          return self._count == len(self._qArray)

      def __len__(self):
          # `_count` is already an int; the original called len() on it.
          return self._count

      def enqueue(self, item):
          """Add ``item`` at the back; the queue must not be full."""
          assert not self.isFull()
          maxSize = len(self._qArray)
          self._back = (self._back + 1) % maxSize  # advance the back index
          self._qArray[self._back] = item
          self._count += 1

      def dequeue(self):
          """Remove and return the front item; the queue must not be empty."""
          # The original asserted `not self.isFull()`, which allowed dequeuing
          # from an empty queue and rejected dequeuing from a full one.
          assert not self.isEmpty()
          item = self._qArray[self._front]
          maxSize = len(self._qArray)
          self._front = (self._front + 1) % maxSize
          self._count -= 1
          return item
  class _QueueNode:
      """Linked-list node used by the linked Queue."""

      def __init__(self, item):
          self.item = item
          # Every node needs a link; without it, reading `.next` on the last
          # node raised AttributeError.
          self.next = None
  class Queue:
      """Queue ADT built on a linked list with head and tail references.

      Removes the fixed-capacity limit of the circular-array version while
      keeping enqueue and dequeue at O(1).
      """

      def __init__(self):
          self._qhead = None   # front node, or None when empty
          self._qtail = None   # back node, or None when empty
          self._qsize = 0      # number of stored items

      def isEmpty(self):
          return self._qhead is None

      def __len__(self):
          # The counter was spelled three different ways (_qsize, _count,
          # _qcount); `_qsize` is now used throughout.
          return self._qsize

      def enqueue(self, item):
          """Add ``item`` at the back of the queue."""
          node = _QueueNode(item)
          if self.isEmpty():
              self._qhead = node
          else:
              self._qtail.next = node  # link the old tail to the new node
          self._qtail = node
          self._qsize += 1

      def dequeue(self):
          """Remove and return the front item; the queue must not be empty."""
          assert not self.isEmpty(), 'Can not dequeue from an empty queue'
          node = self._qhead
          if self._qhead is self._qtail:
              # Last node: the queue becomes empty. Its link is not read.
              self._qhead = self._qtail = None
          else:
              self._qhead = node.next  # advance the head
          self._qsize -= 1
          return node.item
  class UnboundedPriorityQueue:
      """Priority queue ADT: every item carries a priority; best priority leaves first.

      Two flavours exist:
          - bounded: priorities are limited to a range [0, p)
          - unbounded: any priority is accepted (this class)

      Operations:
          isEmpty()
          len(q)
          enqueue(item, priority)
          dequeue()   remove and return the item with the smallest priority
                      value; items of equal priority leave in FIFO order

      Two implementation strategies:
          1. always append at the back and scan for the best priority on
             dequeue, which makes dequeue O(n) (used here);
          2. keep the list ordered by finding the insert position on enqueue,
             which makes dequeue O(1).
      (With a built-in list, append and pop can additionally degrade because
      of memory reallocation.)
      """
      from collections import namedtuple
      _PriorityQEntry = namedtuple('_PriorityQEntry', 'item, priority')

      def __init__(self):
          self._qlist = list()

      def isEmpty(self):
          return len(self) == 0

      def __len__(self):
          return len(self._qlist)

      def enqueue(self, item, priority):
          entry = UnboundedPriorityQueue._PriorityQEntry(item, priority)
          self._qlist.append(entry)

      def dequeue(self):
          assert not self.isEmpty(), 'can not deque from an empty queue'
          # O(n) scan for the INDEX of the best entry. min() keeps the first
          # of several equal priorities, which gives FIFO order among ties.
          # The original tracked the priority value and then used it as a
          # list index.
          best = min(range(len(self._qlist)),
                     key=lambda i: self._qlist[i].priority)
          return self._qlist.pop(best).item

      # The method was originally named `deque`; keep that name working.
      deque = dequeue
  class BoundedPriorityQueue:
      """Bounded priority queue ADT: an array of FIFO queues, one per priority.

      Priorities are limited to [0, numLevels - 1]; a smaller number means a
      higher priority and leaves first. An unbounded priority queue has to
      scan all entries on dequeue (O(n) on average). Here memory is traded
      for time: dequeue returns the front of the first non-empty level, so its
      cost depends on the number of levels rather than the number of items.

          qlist
          [0] -> ["white"]
          [1]
          [2] -> ["black", "green"]
          [3] -> ["purple", "yellow"]
      """
      # NOTE(review): this resolves to the module named `array` on the import
      # path; it is meant to be the Array ADT from chapter 2, not the standard
      # library `array` module, which has no `Array`. Confirm the module
      # layout.
      from array import Array

      def __init__(self, numLevels):
          self._qSize = 0
          # A name imported in the class body is a class attribute, so it
          # has to be reached through the class.
          self._qLevels = BoundedPriorityQueue.Array(numLevels)
          for i in range(numLevels):
              self._qLevels[i] = Queue()  # the linked-list Queue

      def isEmpty(self):
          return len(self) == 0

      def __len__(self):
          # `_qSize` is already an int; the original called len() on it.
          return self._qSize

      def enqueue(self, item, priority):
          assert priority >= 0 and priority < len(self._qLevels), 'invalid priority'
          # Go straight to the queue for this priority level.
          self._qLevels[priority].enqueue(item)
          self._qSize += 1

      def dequeue(self):
          assert not self.isEmpty(), 'can not deque from an empty queue'
          # Skip empty levels to find the first non-empty queue. The original
          # condition was inverted and skipped the non-empty ones.
          i = 0
          p = len(self._qLevels)
          while i < p and self._qLevels[i].isEmpty():
              i += 1
          self._qSize -= 1
          return self._qLevels[i].dequeue()

      # The method was originally named `deque`; keep that name working.
      deque = dequeue

9章:Advanced Linked Lists

之前曾经介绍过单链表,一个链表节点只有data和next字段,本章介绍高级的链表。

Doubly Linked List,双链表,每个节点多了个prev指向前一个节点。双链表可以用来编写文本编辑器的buffer。

  class DListNode:
      """Node of a doubly linked list: payload plus links in both directions."""

      def __init__(self, data):
          self.data = data
          self.prev = self.next = None
  def revTraversa(tail):
      """Print the data of every node from ``tail`` back to the first node."""
      curNode = tail
      while curNode is not None:  # was misspelled `cruNode` (NameError)
          print(curNode.data)
          curNode = curNode.prev
  def search_sorted_doubly_linked_list(head, tail, probe, target):
      """Search a sorted doubly linked list using the probing technique.

      Instead of always starting at the head, the search starts at ``probe``
      and walks backward or forward depending on how ``target`` compares with
      the probe's data. The worst case is still O(n).

      Args:
          head (DListNode or None): first node of the list
          tail (DListNode or None): last node of the list (not used by the
              search itself)
          probe (DListNode or None): node to start from; the head is used
              when it is None
          target: the data value to look for

      Returns:
          bool: True when a node holding ``target`` was found.
      """
      if head is None:  # empty list
          return False
      if probe is None:  # no probe yet: start at the first node
          probe = head
      if target < probe.data:
          # The target can only lie before the probe: walk backward.
          while probe is not None and target <= probe.data:
              if target == probe.data:  # was misspelled `probe.dta`
                  return True
              probe = probe.prev
      else:
          # The target lies at or after the probe: walk forward.
          while probe is not None and target >= probe.data:
              if target == probe.data:
                  return True
              probe = probe.next
      return False
  def insert_node_into_ordered_doubly_linekd_list(value, head=None, tail=None):
      """Insert ``value`` into a sorted doubly linked list, keeping it sorted.

      Drawing the links on paper helps: the order of the assignments matters.

      Args:
          value: data for the new node
          head (DListNode or None): current first node; None for an empty list
          tail (DListNode or None): current last node; located by walking from
              ``head`` when omitted

      Returns:
          tuple: ``(head, tail)`` of the list after the insertion. The earlier
          version read ``head``/``tail`` as unbound locals and could not run.
      """
      newnode = DListNode(value)
      if head is None:  # empty list
          return newnode, newnode
      if tail is None:  # tail not supplied: find it
          tail = head
          while tail.next is not None:
              tail = tail.next
      if value <= head.data:  # insert before head (ties included)
          newnode.next = head
          head.prev = newnode
          head = newnode
      elif value >= tail.data:  # insert after tail (ties included)
          newnode.prev = tail
          tail.next = newnode
          tail = newnode
      else:  # insert into the middle
          # head.data < value < tail.data here, so the walk stops at an
          # interior node that has a predecessor.
          node = head
          while node.data < value:
              node = node.next
          newnode.next = node
          newnode.prev = node.prev
          node.prev.next = newnode
          node.prev = newnode
      return head, tail

循环链表

  def travrseCircularList(listRef):
      """Print every node of a circular linked list once.

      The walk starts at the node after ``listRef`` and ends with ``listRef``
      itself. Nothing is printed for an empty list (``listRef`` is None).
      """
      curNode = listRef
      done = listRef is None
      # The original looped on `while not None`, which is always true: it
      # never terminated and crashed on an empty list.
      while not done:
          curNode = curNode.next
          print(curNode.data)
          done = curNode is listRef  # back at the starting point
  def searchCircularList(listRef, target):
      """Return True when ``target`` is stored in the sorted circular list.

      The walk starts at the node after ``listRef`` and gives up once it is
      back at ``listRef`` or has passed a value greater than ``target``.
      """
      if listRef is None:
          return False
      node = listRef
      while True:
          node = node.next
          if node.data == target:
              return True
          if node is listRef or node.data > target:
              return False
  def add_newnode_into_ordered_circular_linked_list(listRef, value):
      """Insert ``value`` into a sorted circular linked list, keeping it sorted.

      ``listRef`` refers to the last (largest) node, whose ``next`` is the
      first (smallest) node. Four cases are handled: empty list, insert in
      front, insert at the back, insert in the middle.

      Returns:
          The list reference after the insertion. It changes when the list
          was empty or the new node became the last node, so use
          ``listRef = add_newnode_into_ordered_circular_linked_list(listRef, v)``.
      """
      newnode = ListNode(value)
      if listRef is None:  # empty list: the node links to itself
          newnode.next = newnode
          return newnode
      if value < listRef.next.data:  # insert in front
          newnode.next = listRef.next
          listRef.next = newnode
          return listRef
      if value > listRef.data:  # insert at the back: new node becomes listRef
          newnode.next = listRef.next
          listRef.next = newnode
          return newnode
      # Insert in the middle: advance until the first node that is larger
      # than value, or until the walk is back at listRef. The original loop
      # overwrote preNode and never advanced curNode.
      preNode, curNode = listRef, listRef.next
      while curNode is not listRef and curNode.data <= value:
          preNode, curNode = curNode, curNode.next
      newnode.next = curNode
      preNode.next = newnode
      return listRef

利用循环双端链表我们可以实现一个经典的缓存失效算法,lru:

  1. # -*- coding: utf-8 -*-
  class Node(object):
      """Doubly linked node that carries a cache key and its value."""

      def __init__(self, prev=None, next=None, key=None, value=None):
          self.prev = prev
          self.next = next
          self.key = key
          self.value = value
  class CircularDoubleLinkedList(object):
      """Circular doubly linked list with a sentinel root node.

      ``rootnode`` is never removed. ``rootnode.next`` is the first real node
      and ``rootnode.prev`` the last one; in an empty list both are the root.
      """

      def __init__(self):
          node = Node()
          node.prev, node.next = node, node
          self.rootnode = node

      def headnode(self):
          """Return the first node (the root itself when the list is empty)."""
          return self.rootnode.next

      def tailnode(self):
          """Return the last node (the root itself when the list is empty)."""
          return self.rootnode.prev

      def remove(self, node):
          """Unlink ``node`` from the list; the sentinel root is never removed."""
          if node is self.rootnode:
              return
          node.prev.next = node.next
          node.next.prev = node.prev

      def append(self, node):
          """Link ``node`` in as the new last node."""
          tailnode = self.tailnode()
          tailnode.next = node
          # Previously `node.prev` was left untouched, so a node that was
          # removed and appended again kept a stale back-link.
          node.prev = tailnode
          node.next = self.rootnode
          self.rootnode.prev = node
  class LRUCache(object):
      """Decorator that memoises a one-argument function with LRU eviction.

      ``cache`` maps an argument to its list node; ``access`` orders the nodes
      from least recently used (head) to most recently used (tail). Once
      ``maxsize`` entries are stored, the head entry is evicted to make room.
      With ``maxsize <= 0`` nothing is cached.
      """

      def __init__(self, maxsize=16):
          self.maxsize = maxsize
          self.cache = {}
          self.access = CircularDoubleLinkedList()
          self.isfull = len(self.cache) >= self.maxsize

      def _touch(self, node):
          """Move ``node`` to the most-recently-used end of the access list."""
          self.access.remove(node)
          # Refresh the back-link explicitly so the node cannot keep the
          # predecessor it had at its old position.
          node.prev = self.access.tailnode()
          self.access.append(node)

      def __call__(self, func):
          import functools

          @functools.wraps(func)  # keep the wrapped function's name/docstring
          def wrapper(n):
              cachenode = self.cache.get(n)
              if cachenode is not None:  # hit
                  self._touch(cachenode)
                  return cachenode.value

              # miss
              value = func(n)
              if self.maxsize <= 0:
                  # Nothing may be stored; the original tried to evict from an
                  # empty list here and raised KeyError.
                  return value
              if self.isfull:
                  # Evict the least recently used entry.
                  lru_node = self.access.headnode()
                  del self.cache[lru_node.key]
                  self.access.remove(lru_node)
              tailnode = self.access.tailnode()
              newnode = Node(tailnode, self.access.rootnode, n, value)
              self.access.append(newnode)
              self.cache[n] = newnode
              self.isfull = len(self.cache) >= self.maxsize
              return value

          return wrapper
  @LRUCache()
  def fib(n):
      """Return the n-th Fibonacci number, with fib(1) == fib(2) == 1."""
      return 1 if n <= 2 else fib(n - 1) + fib(n - 2)


  # Print the first 34 Fibonacci numbers.
  for i in range(1, 35):
      print(fib(i))

10章:Recursion

Recursion is a process for solving problems by subdividing a larger problem into smaller cases of the problem itself and then solving the smaller, more trivial parts.

递归函数:调用自己的函数

  # A recursive function is a function that calls itself. The simplest
  # example prints the numbers from n down to 1.
  def printRev(n):
      """Print n, n-1, ..., 1, one number per line."""
      if not n > 0:
          return
      print(n)
      printRev(n - 1)


  printRev(3)  # prints 3, 2, 1
  # A small change, moving the print after the recursive call, gives a
  # function that prints in ascending order.
  def printInOrder(n):
      """Print 1, 2, ..., n, one number per line."""
      if not n > 0:
          return
      printInOrder(n - 1)
      # The smallest value is printed first because the calls recurse all the
      # way down to n == 1 before any print runs; each return to an outer
      # frame then prints the next larger value.
      print(n)


  printInOrder(3)  # prints 1, 2, 3

Properties of Recursion: 使用stack解决的问题都能用递归解决

  • A recursive solution must contain a base case; 递归出口,代表最小子问题(n == 0退出打印)
  • A recursive solution must contain a recursive case; 可以分解的子问题
  • A recursive solution must make progress toward the base case. 递减n使得n向递归出口靠近

Tail Recursion: occurs when a function includes a single recursive call as the last statement of the function. In this case, a stack is not needed to store values to be used upon the return of the recursive call, and thus a solution can be implemented using an iterative loop instead.

  # Recursive binary search.
  def recBinarySearch(target, theSeq, first, last):
      """Return True when ``target`` occurs in the sorted ``theSeq[first:last+1]``.

      Writing a few unit tests is a good way to convince yourself that it works.
      """
      if first > last:  # base case 1: empty range
          return False
      middle = (first + last) // 2
      candidate = theSeq[middle]
      if candidate == target:  # base case 2: found
          return True
      if candidate > target:
          lower, upper = first, middle - 1   # continue in the left half
      else:
          lower, upper = middle + 1, last    # continue in the right half
      return recBinarySearch(target, theSeq, lower, upper)

11章:Hash Tables

基于比较的搜索(线性搜索,有序数组的二分搜索)最好的时间复杂度只能达到O(logn),利用hash可以实现O(1)查找,python内置dict的实现方式就是hash,你会发现dict的key必须要是实现了 __hash__ 和 __eq__ 方法的。

Hashing: hashing is the process of mapping a search key to a limited range of array indices with the goal of providing direct access to the keys.

hash方法有个hash函数用来给key计算一个hash值,作为数组下标,放到该下标对应的槽中。当不同key根据hash函数计算得到的下标相同时,就出现了冲突。解决冲突有很多方式,比如让每个槽成为链表,每次冲突以后放到该槽链表的尾部,但是查询时间就会退化,不再是O(1)。还有一种探查方式,当key的槽冲突时候,就会根据一种计算方式去寻找下一个空的槽存放,探查方式有线性探查,二次方探查法等,cpython解释器使用的是二次方探查法。还有一个问题就是当python使用的槽数量大于预分配的2/3时候,会重新分配内存并拷贝以前的数据,所以有时候dict的add操作代价还是比较高的,牺牲空间但是可以始终保证O(1)的查询效率。如果有大量的数据,建议还是使用bloomfilter或者redis提供的HyperLogLog。

如果你感兴趣,可以看看这篇文章,介绍c解释器如何实现的python dict对象:Python dictionary implementation。我们使用Python来实现一个类似的hash结构。

  1. import ctypes
  class Array:
      """Fixed-size array ADT (same as in chapter 2), used as the HashMap slot table.

      Elements live in a ctypes py_object array and start out as None. Valid
      indices satisfy 0 <= index < len(array).
      """

      def __init__(self, size):
          assert size > 0, 'array size must be > 0'
          self._size = size
          self._elements = (ctypes.py_object * size)()
          self.clear(None)

      def __len__(self):
          return self._size

      def __getitem__(self, index):
          assert 0 <= index < len(self), 'out of range'
          return self._elements[index]

      def __setitem__(self, index, value):
          assert 0 <= index < len(self), 'out of range'
          self._elements[index] = value

      def clear(self, value):
          """Set every element to ``value``."""
          for position in range(len(self)):
              self._elements[position] = value

      def __iter__(self):
          return _ArrayIterator(self._elements)
  class _ArrayIterator:
      """Iterator that yields the elements of an indexable container in order."""

      def __init__(self, items):
          self._items = items
          self._idx = 0  # index of the next element to return

      def __iter__(self):
          return self

      def __next__(self):
          position = self._idx
          if position >= len(self._items):
              raise StopIteration
          current = self._items[position]
          self._idx = position + 1
          return current
  36. class HashMap:
  37. """ HashMap ADT实现,类似于python内置的dict
  38. 一个槽有三种状态:
  39. 1.从未使用 HashMap.UNUSED。此槽没有被使用和冲突过,查找时只要找到UNUSEd就不用再继续探查了
  40. 2.使用过但是remove了,此时是 HashMap.EMPTY,该探查点后边的元素扔可能是有key
  41. 3.槽正在使用 _MapEntry节点
  42. """
  43. class _MapEntry: # 槽里存储的数据
  44. def __init__(self, key, value):
  45. self.key = key
  46. self.value = value
  47. UNUSED = None # 没被使用过的槽,作为该类变量的一个单例,下边都是is 判断
  48. EMPTY = _MapEntry(None, None) # 使用过但是被删除的槽
  49. def __init__(self):
  50. self._table = Array(7) # 初始化7个槽
  51. self._count = 0
  52. # 超过2/3空间被使用就重新分配,load factor = 2/3
  53. self._maxCount = len(self._table) - len(self._table) // 3
  54. def __len__(self):
  55. return self._count
  56. def __contains__(self, key):
  57. slot = self._findSlot(key, False)
  58. return slot is not None
  59. def add(self, key, value):
  60. if key in self: # 覆盖原有value
  61. slot = self._findSlot(key, False)
  62. self._table[slot].value = value
  63. return False
  64. else:
  65. slot = self._findSlot(key, True)
  66. self._table[slot] = HashMap._MapEntry(key, value)
  67. self._count += 1
  68. if self._count == self._maxCount: # 超过2/3使用就rehash
  69. self._rehash()
  70. return True
  71. def valueOf(self, key):
  72. slot = self._findSlot(key, False)
  73. assert slot is not None, 'Invalid map key'
  74. return self._table[slot].value
  75. def remove(self, key):
  76. """ remove操作把槽置为EMPTY"""
  77. assert key in self, 'Key error %s' % key
  78. slot = self._findSlot(key, forInsert=False)
  79. value = self._table[slot].value
  80. self._count -= 1
  81. self._table[slot] = HashMap.EMPTY
  82. return value
  83. def __iter__(self):
  84. return _HashMapIteraotr(self._table)
  85. def _slot_can_insert(self, slot):
  86. return (self._table[slot] is HashMap.EMPTY or
  87. self._table[slot] is HashMap.UNUSED)
  def _findSlot(self, key, forInsert=False):
      """Probe the table for ``key`` using double hashing.

      Args:
          key: the key to look up or to insert.
          forInsert (bool): True to search for a slot that can receive a
              new entry, False to search for the slot holding ``key``.

      Returns:
          The slot index, or None when a lookup does not find ``key``.

      Raises:
          RuntimeError: when ``forInsert`` is True and every slot on the
              probe sequence is occupied.
      """
      _len = len(self._table)
      slot = self._hash1(key)
      step = self._hash2(key)
      # With a fixed step the probe sequence repeats after at most _len
      # probes, so bounding the loop keeps a lookup from spinning forever
      # when no UNUSED slot is left on the sequence.
      for _ in range(_len):
          if forInsert:
              if self._slot_can_insert(slot):
                  return slot
          else:
              entry = self._table[slot]
              if entry is HashMap.UNUSED:
                  # A never-used slot ends the probe chain: key is absent.
                  return None
              # EMPTY marks a removed entry; skip it and keep probing.
              if entry is not HashMap.EMPTY and entry.key == key:
                  return slot
          slot = (slot + step) % _len
      if forInsert:
          raise RuntimeError('no free slot on the probe sequence')
      return None
  def _rehash(self):
      """Move every live entry into a new table of size ``2 * n + 1``."""
      old_entries = self._table
      grown = 2 * len(old_entries) + 1
      self._table = Array(grown)
      self._count = 0
      # Same load-factor rule as the original table: two thirds of the slots.
      self._maxCount = grown - grown // 3

      for entry in old_entries:
          if entry is HashMap.UNUSED or entry is HashMap.EMPTY:
              continue  # nothing stored in this slot
          destination = self._findSlot(entry.key, True)
          self._table[destination] = entry
          self._count += 1
  124. def _hash1(self, key):
  125. """ 计算key的hash值"""
  126. return abs(hash(key)) % len(self._table)
  127. def _hash2(self, key):
  128. """ key冲突时候用来计算新槽的位置"""
  129. return 1 + abs(hash(key)) % (len(self._table)-2)
  130. class _HashMapIteraotr:
  131. def __init__(self, array):
  132. self._array = array
  133. self._idx = 0
  134. def __iter__(self):
  135. return self
  136. def __next__(self):
  137. if self._idx < len(self._array):
  138. if self._array[self._idx] is not None and self._array[self._idx].key is not None:
  139. key = self._array[self._idx].key
  140. self._idx += 1
  141. return key
  142. else:
  143. self._idx += 1
  144. else:
  145. raise StopIteration
  def print_h(h):
      """Print each item of ``h`` with its position, then a blank separator."""
      for numbered in enumerate(h):
          print(*numbered)
      print('\n')
  def test_HashMap():
      """Basic unit test for HashMap; the coverage is not exhaustive."""
      h = HashMap()
      assert len(h) == 0

      # One key: add, read back, remove.
      h.add('a', 'a')
      assert h.valueOf('a') == 'a'
      assert len(h) == 1
      a_v = h.remove('a')
      assert a_v == 'a'
      assert len(h) == 0

      # Two keys, removed one after the other.
      h.add('a', 'a')
      h.add('b', 'b')
      assert len(h) == 2
      assert h.valueOf('b') == 'b'
      b_v = h.remove('b')
      assert b_v == 'b'
      assert len(h) == 1
      h.remove('a')
      assert len(h) == 0

      # A batch of keys.
      n = 10
      keys = [str(i) for i in range(n)]
      for position, key in enumerate(keys):
          h.add(key, position)
      assert len(h) == n
      print_h(h)
      for key in keys:
          assert key in h
      for key in keys:
          h.remove(key)
      assert len(h) == 0

12章: Advanced Sorting

第5章介绍了基本的排序算法,本章介绍高级排序算法。

归并排序(mergesort): 分治法

  def merge_sorted_list(listA, listB):
      """Merge two sorted lists into a new sorted list in O(m + n) steps.

      The merge is stable: when two elements compare equal, the one from
      ``listA`` is placed first. Both inputs are left unmodified. The inputs
      and the result are printed so the sorting process can be followed.

      Args:
          listA: a sorted list.
          listB: a sorted list.

      Returns:
          A new list holding every element of both inputs in sorted order.
      """
      print('merge left right list', listA, listB, end='')
      new_list = list()
      a = b = 0
      while a < len(listA) and b < len(listB):
          # "<=" rather than "<": ties must favour the left list, otherwise
          # equal elements would swap their relative order.
          if listA[a] <= listB[b]:
              new_list.append(listA[a])
              a += 1
          else:
              new_list.append(listB[b])
              b += 1
      # At most one of the two inputs still has elements left.
      new_list.extend(listA[a:])
      new_list.extend(listB[b:])
      print(' ->', new_list)
      return new_list
  def mergesort(theList):
      """Sort ``theList`` by divide and conquer and return the sorted list.

      1. Split the list into ever smaller halves.
      2. Merge the sorted halves back together.

      There are about log(n) levels of recursion with n work per level,
      hence O(n log n). Each sub-list is printed on entry so the process can
      be followed.
      """
      print(theList)
      if len(theList) < 2:
          # Recursion exit: zero or one element is already sorted.
          return theList

      middle = len(theList) // 2
      sorted_left = mergesort(theList[:middle])
      sorted_right = mergesort(theList[middle:])
      return merge_sorted_list(sorted_left, sorted_right)
  38. """ 这是我调用一次打出来的排序过程
  39. [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
  40. [10, 9, 8, 7, 6]
  41. [10, 9]
  42. [10]
  43. [9]
  44. merge left right list [10] [9] -> [9, 10]
  45. [8, 7, 6]
  46. [8]
  47. [7, 6]
  48. [7]
  49. [6]
  50. merge left right list [7] [6] -> [6, 7]
  51. merge left right list [8] [6, 7] -> [6, 7, 8]
  52. merge left right list [9, 10] [6, 7, 8] -> [6, 7, 8, 9, 10]
  53. [5, 4, 3, 2, 1]
  54. [5, 4]
  55. [5]
  56. [4]
  57. merge left right list [5] [4] -> [4, 5]
  58. [3, 2, 1]
  59. [3]
  60. [2, 1]
  61. [2]
  62. [1]
  63. merge left right list [2] [1] -> [1, 2]
  64. merge left right list [3] [1, 2] -> [1, 2, 3]
  65. merge left right list [4, 5] [1, 2, 3] -> [1, 2, 3, 4, 5]
  66. """

快速排序

  def quicksort(theSeq, first, last):
      """Sort ``theSeq[first..last]`` in place; O(n log n) on average.

      Also divide and conquer, but instead of splitting in the middle the
      range is partitioned around a pivot:
      1. partition so that smaller elements are left of the pivot and
         greater-or-equal elements are right of it;
      2. recurse on both sides until a side has fewer than two elements.
      """
      if first >= last:
          return
      pivot_pos = partitionSeq(theSeq, first, last)
      quicksort(theSeq, first, pivot_pos - 1)
      quicksort(theSeq, pivot_pos + 1, last)
  def partitionSeq(theSeq, first, last):
      """Partition ``theSeq[first..last]`` around ``theSeq[first]``.

      Elements smaller than the pivot end up on its left, the others on its
      right. The sequence is printed before and after the partition.

      Returns:
          The final index of the pivot.
      """
      pivot = theSeq[first]
      print('before partitionSeq', theSeq)
      lo, hi = first + 1, last
      while lo <= hi:
          # Advance to the first element that is not smaller than the pivot.
          while lo <= hi and theSeq[lo] < pivot:
              lo += 1
          # Retreat to the first element that is smaller than the pivot.
          while hi >= lo and theSeq[hi] >= pivot:
              hi -= 1
          if hi >= lo:
              theSeq[lo], theSeq[hi] = theSeq[hi], theSeq[lo]
      # hi now marks the last element smaller than the pivot (or first itself).
      theSeq[first], theSeq[hi] = theSeq[hi], theSeq[first]
      print('after partitionSeq {}: {}\t'.format(theSeq, pivot))
      return hi
  def test_partitionSeq():
      """Check the pivot position returned for three fixed inputs."""
      cases = (
          ([0, 1, 2, 3, 4], 0),
          ([4, 3, 2, 1, 0], 4),
          ([2, 3, 0, 1, 4], 2),
      )
      for l, expected_pos in cases:
          assert partitionSeq(l, 0, len(l)-1) == expected_pos


  test_partitionSeq()
  def test_quicksort():
      """Sort 100 random lists with quicksort and check that each is ordered."""
      from random import randint

      def _is_sorted(seq):
          return not any(seq[i] > seq[i+1] for i in range(len(seq)-1))

      for _round in range(100):
          _len = randint(1, 100)
          to_sort = [randint(0, 100) for _ in range(_len)]
          # quicksort sorts in place, so to_sort itself is modified.
          quicksort(to_sort, 0, len(to_sort)-1)
          print(to_sort)
          assert _is_sorted(to_sort)


  test_quicksort()

利用快排中的partitionSeq操作,我们还能实现另一个算法nth_element:快速查找一个无序数组中第k小的元素(k从0开始计数,即排序后下标为k的元素)

  def nth_element(seq, beg, end, k):
      """Return the element that would sit at index ``k`` if ``seq`` were sorted.

      Only ``seq[beg..end]`` is examined, and it is partially reordered in
      place by ``partitionSeq``.

      Args:
          seq: a mutable sequence.
          beg: first index of the range (inclusive).
          end: last index of the range (inclusive).
          k: target index; must satisfy ``beg <= k <= end``.

      Raises:
          IndexError: if ``k`` lies outside ``[beg, end]``. Without this
              check the search would return an unrelated element.
      """
      if not beg <= k <= end:
          raise IndexError('k must lie within [beg, end]')
      # Narrow the range around k; beg <= k <= end holds on every iteration.
      while beg < end:
          pivot_index = partitionSeq(seq, beg, end)
          if pivot_index == k:
              return seq[k]
          if pivot_index > k:
              end = pivot_index - 1
          else:
              beg = pivot_index + 1
      return seq[beg]
  def test_nth_element():
      """Check nth_element for every rank of a shuffled ``range(10)``."""
      from random import shuffle
      n = 10
      l = list(range(n))
      shuffle(l)
      print(l)
      last = len(l) - 1
      for rank in range(len(l)):
          assert nth_element(l, 0, last, rank) == rank


  test_nth_element()

13章: Binary Tree

The Binary Tree: 二叉树,每个节点最多只有两个子节点

  1. class _BinTreeNode:
  2. def __init__(self, data):
  3. self.data = data
  4. self.left = None
  5. self.right = None
  # The three depth-first traversals
  def preorderTrav(subtree):
      """Pre-order traversal: print the node, then its left and right subtree."""
      pending = [subtree]
      while pending:
          node = pending.pop()
          if node is None:
              continue
          print(node.data)
          # Push right first so that the left subtree is processed first.
          pending.append(node.right)
          pending.append(node.left)
  def inorderTrav(subtree):
      """In-order traversal: left subtree, then the node, then right subtree.

      Args:
          subtree: root node of the (sub)tree, or None for an empty tree.
      """
      if subtree is not None:
          # Recurse with inorderTrav itself; delegating to preorderTrav would
          # print the subtrees in pre-order.
          inorderTrav(subtree.left)
          print(subtree.data)
          inorderTrav(subtree.right)
  def postorderTrav(subtree):
      """Post-order traversal: left subtree, right subtree, then the node.

      Args:
          subtree: root node of the (sub)tree, or None for an empty tree.
      """
      if subtree is not None:
          # Recurse with postorderTrav itself; delegating to preorderTrav
          # would print the subtrees in pre-order.
          postorderTrav(subtree.left)
          postorderTrav(subtree.right)
          print(subtree.data)
  # Breadth-first traversal: visit the tree level by level, using a queue
  def breadthFirstTrav(bintree):
      """Print the node data level by level, left to right.

      Args:
          bintree: root node of the tree, or None for an empty tree (in
              which case nothing is printed).
      """
      from collections import deque

      if bintree is None:
          return
      q = deque([bintree])
      while q:
          node = q.popleft()
          print(node.data)
          # Only real children are queued, so every dequeued node is valid.
          if node.left is not None:
              q.append(node.left)
          if node.right is not None:
              q.append(node.right)
  37. class _ExpTreeNode:
  38. __slots__ = ('element', 'left', 'right')
  39. def __init__(self, data):
  40. self.element = data
  41. self.left = None
  42. self.right = None
  43. def __repr__(self):
  44. return '<_ExpTreeNode: {} {} {}>'.format(
  45. self.element, self.left, self.right)
  46. from queue import Queue
  class ExpressionTree:
      r"""Binary expression tree.

      Operators are stored in the interior nodes and operands in the leaves:

                *
               / \
              +   -
             / \ / \
            9  3 8  4
          (9+3) * (8-4)

      The expression string must be fully parenthesized and is read one
      character per token, so operands are single digits or single-letter
      variable names.

      ExpressionTree(expStr): build the tree from a string.
      evaluate(varDict): evaluate the expression and return the numeric result.
      str(tree): build a string representation of the expression.

      Usage:
          vars = {'a': 5, 'b': 12}
          expTree = ExpressionTree("(a/(b-3))")
          print('The result = ', expTree.evaluate(vars))
      """

      # Supported binary operators, built once instead of on every call.
      _OPERATORS = {
          '+': lambda left, right: left + right,
          '-': lambda left, right: left - right,
          '*': lambda left, right: left * right,
          '/': lambda left, right: left / right,
          '%': lambda left, right: left % right,
      }

      def __init__(self, expStr):
          """Build the tree; raises ValueError on a malformed expression."""
          self._expTree = None
          self._buildTree(expStr)

      def evaluate(self, varDict):
          """Evaluate the tree, resolving variable names through ``varDict``."""
          return self._evalTree(self._expTree, varDict)

      def __str__(self):
          return self._buildString(self._expTree)

      def _buildString(self, treeNode):
          """Render a subtree, wrapping every operator node in parentheses."""
          if treeNode.left is None and treeNode.right is None:
              return str(treeNode.element)  # a leaf is an operand
          return '({}{}{})'.format(
              self._buildString(treeNode.left),
              treeNode.element,
              self._buildString(treeNode.right))

      def _evalTree(self, subtree, varDict):
          """Recursively evaluate ``subtree`` and return its value."""
          if subtree.left is None and subtree.right is None:
              # Leaf: either a digit literal or a variable name.
              if '0' <= subtree.element <= '9':
                  return int(subtree.element)
              assert subtree.element in varDict, 'invalid variable.'
              return varDict[subtree.element]
          # Interior node: evaluate both operands, then apply the operator.
          lvalue = self._evalTree(subtree.left, varDict)
          rvalue = self._evalTree(subtree.right, varDict)
          print(subtree.element)
          return self._computeOp(lvalue, subtree.element, rvalue)

      def _computeOp(self, left, op, right):
          """Apply operator ``op``; an unsupported operator raises KeyError."""
          assert op
          return self._OPERATORS[op](left, right)

      def _buildTree(self, expStr):
          """Tokenize ``expStr`` character by character and build the tree."""
          expQ = Queue()
          for token in expStr:
              expQ.put(token)
          self._expTree = _ExpTreeNode(None)  # root node
          self._recBuildTree(self._expTree, expQ)

      def _nextToken(self, expQ):
          """Pop the next token; raise ValueError when the queue is empty.

          A bare ``expQ.get()`` would block forever on an empty queue, so a
          truncated expression must be detected before calling it.
          """
          if expQ.empty():
              raise ValueError('unexpected end of expression')
          return expQ.get()

      def _recBuildTree(self, curNode, expQ):
          """Fill ``curNode`` from the token queue (recursive descent)."""
          token = self._nextToken(expQ)
          if token == '(':
              curNode.left = _ExpTreeNode(None)
              self._recBuildTree(curNode.left, expQ)
              # The next token is the operator: + - * / %
              curNode.element = self._nextToken(expQ)
              curNode.right = _ExpTreeNode(None)
              self._recBuildTree(curNode.right, expQ)
              # The sub-expression must be closed by ')'.
              closing = self._nextToken(expQ)
              if closing != ')':
                  raise ValueError("expected ')' but found %r" % closing)
          else:
              # Operand token, kept as a string; _evalTree converts digits.
              curNode.element = token
  # Demo: build the tree for ((2*7)+8), print it, then evaluate it.
  vars = dict(a=5, b=12)
  expTree = ExpressionTree("((2*7)+8)")
  print(expTree)
  result = expTree.evaluate(vars)
  print('The result = ', result)

Heap(堆):二叉树最直接的一个应用就是实现堆。堆就是一棵完全二叉树,最大堆的非叶子节点的值都比孩子大,最小堆的非叶子结点的值都比孩子小。 python内置了heapq模块帮助我们实现堆操作,比如用内置的heapq模块实现个堆排序:

  # Heap sort built on Python's heapq module
  def heapsort(iterable):
      """Return a new list with the items of ``iterable`` in ascending order."""
      from heapq import heappush, heappop

      heap = []
      for value in iterable:
          heappush(heap, value)

      ordered = []
      while heap:
          ordered.append(heappop(heap))
      return ordered

但是一般实现堆的时候实际上并不是用树节点来实现的,而是使用数组实现,效率比较高。为什么可以用数组实现呢?因为完全二叉树的性质, 可以用下标之间的关系表示节点之间的关系,MaxHeap的docstring中已经说明了

  class MaxHeap:
      """Fixed-capacity max-heap stored in an array.

      A heap is a complete binary tree. In a max-heap no node is smaller
      than its children. Two properties must hold after every operation: the
      order property and the shape property (a complete binary tree).

      add: append at the end, then sift-up to restore the order property.
      extract: return the root, copy the last (bottom-right) node into the
          root, then sift-down to restore the order property.

      Numbering the nodes top to bottom and left to right from 0, the
      neighbours of node i are:
          parent = (i-1) // 2
          left   = 2 * i + 1
          right  = 2 * i + 2
      """

      def __init__(self, maxSize):
          self._elements = Array(maxSize)  # the Array ADT from chapter 2
          self._count = 0

      def __len__(self):
          return self._count

      def capacity(self):
          """Return the size of the backing array."""
          return len(self._elements)

      def add(self, value):
          """Insert ``value``; asserts that the heap is not full."""
          assert self._count < self.capacity(), 'can not add to full heap'
          self._elements[self._count] = value
          self._count += 1
          self._siftUp(self._count - 1)
          self.assert_keep_heap()  # verify the heap property after each add

      def extract(self):
          """Remove and return the largest value; asserts the heap is not empty."""
          assert self._count > 0, 'can not extract from an empty heap'
          value = self._elements[0]  # save the root value
          self._count -= 1
          # Move the last node to the root, then sift it down.
          self._elements[0] = self._elements[self._count]
          self._siftDown(0)
          self.assert_keep_heap()
          return value

      def _siftUp(self, ndx):
          """Move the element at ``ndx`` up while it is larger than its parent."""
          while ndx > 0:
              parent = (ndx - 1) // 2
              if not self._elements[ndx] > self._elements[parent]:
                  break
              self._elements[ndx], self._elements[parent] = (
                  self._elements[parent], self._elements[ndx])
              ndx = parent

      def _siftDown(self, ndx):
          """Move the element at ``ndx`` down below any larger child."""
          while True:
              left = 2 * ndx + 1
              right = 2 * ndx + 2
              largest = ndx
              # Each child is bounds-checked before it is read, so slots at
              # or beyond _count never take part in a comparison.
              if left < self._count and self._elements[left] > self._elements[largest]:
                  largest = left
              if right < self._count and self._elements[right] > self._elements[largest]:
                  largest = right
              if largest == ndx:
                  return
              self._elements[ndx], self._elements[largest] = (
                  self._elements[largest], self._elements[ndx])
              ndx = largest

      def __repr__(self):
          return ' '.join(map(str, self._elements))

      def assert_keep_heap(self):
          """Assert that every node is no larger than its parent.

          Walking the children instead of the parents also covers the last
          internal node and nodes that have only a left child.
          """
          for child in range(1, len(self)):
              parent = (child - 1) // 2
              assert self._elements[parent] >= self._elements[child]
  def test_MaxHeap():
      """Unit test for the MaxHeap implementation."""
      _len = 10
      h = MaxHeap(_len)
      for value in range(_len):
          h.add(value)
      h.assert_keep_heap()
      # Values were added in ascending order, so they must come back out in
      # descending order.
      for expected in range(_len - 1, -1, -1):
          assert h.extract() == expected


  test_MaxHeap()
  def simpleHeapSort(theSeq):
      """Heap sort built on MaxHeap; overwrites ``theSeq`` in ascending order.

      Returns:
          ``theSeq`` itself.
      """
      if not theSeq:
          return theSeq

      _len = len(theSeq)
      heap = MaxHeap(_len)
      for item in theSeq:
          heap.add(item)
      # The largest value comes out first, so fill the sequence from the back.
      for pos in range(_len - 1, -1, -1):
          theSeq[pos] = heap.extract()
      return theSeq
  def test_simpleHeapSort():
      """Run simpleHeapSort on random lists and check that each is ordered."""
      from random import randint

      def _is_sorted(seq):
          return not any(seq[i] > seq[i+1] for i in range(len(seq)-1))

      assert simpleHeapSort([]) == []
      for _round in range(1000):
          _len = randint(1, 100)
          to_sort = [randint(0, 100) for _ in range(_len)]
          simpleHeapSort(to_sort)  # sorts in place, to_sort is modified
          assert _is_sorted(to_sort)


  test_simpleHeapSort()

14章: Search Trees

二叉查找树性质:对每个内部节点V, 1. 所有key小于V.key的存储在V的左子树。 2. 所有key大于V.key的存储在V的右子树。 对BST进行中序遍历会得到升序的key序列

  1. class _BSTMapNode:
  2. __slots__ = ('key', 'value', 'left', 'right')
  3. def __init__(self, key, value):
  4. self.key = key
  5. self.value = value
  6. self.left = None
  7. self.right = None
  8. def __repr__(self):
  9. return '<{}:{}> left:{}, right:{}'.format(
  10. self.key, self.value, self.left, self.right)
  11. __str__ = __repr__
  class BSTMap:
      """Map ADT implemented with a binary search tree.

      Every node stores a key and a payload value. For each node V:
        1. all keys smaller than V.key are stored in V's left subtree;
        2. all keys larger than V.key are stored in V's right subtree.
      An in-order traversal therefore yields the keys in ascending order.
      """

      def __init__(self):
          self._root = None
          self._size = 0
          self._rval = None  # carries the return value of remove()

      def __len__(self):
          return self._size

      def __iter__(self):
          return _BSTMapIterator(self._root, self._size)

      def __contains__(self, key):
          return self._bstSearch(self._root, key) is not None

      def valueOf(self, key):
          """Return the value stored under ``key``; asserts that it exists."""
          node = self._bstSearch(self._root, key)
          assert node is not None, 'Invalid map key.'
          return node.value

      def _bstSearch(self, subtree, target):
          """Return the node holding ``target`` or None when it is absent."""
          while subtree is not None:
              if target < subtree.key:
                  subtree = subtree.left
              elif target > subtree.key:
                  subtree = subtree.right
              else:
                  return subtree
          return None  # reached the bottom (or the tree is empty)

      def _bstMinumum(self, subtree):
          """Return the left-most (smallest-key) node of ``subtree``, or None."""
          if subtree is None:
              return None
          while subtree.left is not None:
              subtree = subtree.left
          return subtree

      def add(self, key, value):
          """Add ``key`` or replace its value; O(N) in the worst case.

          Returns:
              False when an existing key was updated, True when a new
              entry was inserted.
          """
          node = self._bstSearch(self._root, key)
          if node is not None:  # key already exists: update the value
              node.value = value
              return False
          self._root = self._bstInsert(self._root, key, value)
          self._size += 1
          return True

      def _bstInsert(self, subtree, key, value):
          """Insert a new leaf below ``subtree`` and return the subtree root."""
          if subtree is None:
              subtree = _BSTMapNode(key, value)
          elif key < subtree.key:
              subtree.left = self._bstInsert(subtree.left, key, value)
          elif key > subtree.key:
              subtree.right = self._bstInsert(subtree.right, key, value)
          # No else branch: add() has already ruled out a duplicate key.
          return subtree

      def remove(self, key):
          """Remove ``key`` and return its value; O(N) in the worst case.

          The node to delete falls into one of three cases:
            1. leaf: its parent's link is set to None;
            2. one child: the parent is linked to that child;
            3. two children: copy the in-order successor S (the smallest
               node of the right subtree) into the node, then delete S
               from the right subtree.
          """
          assert key in self, 'invalid map key'
          self._root = self._bstRemove(self._root, key)
          self._size -= 1
          return self._rval

      def _bstRemove(self, subtree, target):
          """Remove ``target`` from ``subtree`` and return the new subtree root."""
          if subtree is None:
              return subtree
          if target < subtree.key:
              subtree.left = self._bstRemove(subtree.left, target)
              return subtree
          if target > subtree.key:
              subtree.right = self._bstRemove(subtree.right, target)
              return subtree

          # Found the node holding the target.
          self._rval = subtree.value
          if subtree.left is None and subtree.right is None:
              return None  # leaf node
          if subtree.left is None or subtree.right is None:
              # Exactly one child: it takes the removed node's place.
              return subtree.left if subtree.left is not None else subtree.right

          # Two children: replace this node's content with its successor.
          removed_value = self._rval
          successor = self._bstMinumum(subtree.right)
          subtree.key = successor.key
          subtree.value = successor.value
          subtree.right = self._bstRemove(subtree.right, successor.key)
          # The recursive call overwrote _rval with the successor's value;
          # restore the value of the key that was actually removed.
          self._rval = removed_value
          return subtree

      def __repr__(self):
          return '->'.join([str(i) for i in self])

      def assert_keep_bst_property(self, subtree):
          """Assert that the keys of ``subtree`` are strictly ascending in-order.

          This checks the BST ordering over the whole subtree, not only
          between a node and its direct children.
          """
          stack = []
          node = subtree
          has_previous = False
          previous_key = None
          while stack or node is not None:
              while node is not None:
                  stack.append(node)
                  node = node.left
              node = stack.pop()
              if has_previous:
                  assert previous_key < node.key
              previous_key = node.key
              has_previous = True
              node = node.right
  class _BSTMapIterator:
      """Iterator over BST keys in ascending (in-order) order.

      All keys are copied into an array of ``size`` slots up front.
      """

      def __init__(self, root, size):
          self._theKeys = Array(size)
          self._curItem = 0  # used as the write position during traversal
          self._bstTraversal(root)
          self._curItem = 0  # reset: now the read position for __next__

      def __iter__(self):
          return self

      def __next__(self):
          if self._curItem >= len(self._theKeys):
              raise StopIteration
          position = self._curItem
          self._curItem = position + 1
          return self._theKeys[position]

      def _bstTraversal(self, subtree):
          """Store the keys of ``subtree`` into the array, in-order."""
          if subtree is None:
              return
          self._bstTraversal(subtree.left)
          self._theKeys[self._curItem] = subtree.key
          self._curItem += 1
          self._bstTraversal(subtree.right)
  def test_BSTMap():
      """Insert a handful of keys into a BSTMap and check the size and membership."""
      l = [60, 25, 100, 35, 17, 80]
      bst = BSTMap()
      for i in l:
          # add() needs both a key and a value; each key is stored as its
          # own value here.
          bst.add(i, i)
      assert len(bst) == len(l)
      for i in l:
          assert i in bst
  def test_HashMap():
      """The earlier hash-map test, now run against the BST-based map."""
      # h = HashMap()
      h = BSTMap()
      assert len(h) == 0

      # One key: add, read back, remove.
      h.add('a', 'a')
      assert h.valueOf('a') == 'a'
      assert len(h) == 1
      a_v = h.remove('a')
      assert a_v == 'a'
      assert len(h) == 0

      # Two keys, removed one after the other.
      h.add('a', 'a')
      h.add('b', 'b')
      assert len(h) == 2
      assert h.valueOf('b') == 'b'
      b_v = h.remove('b')
      assert b_v == 'b'
      assert len(h) == 1
      h.remove('a')
      assert len(h) == 0

      # A batch of keys; the map is printed before and after each removal.
      _len = 10
      keys = [str(i) for i in range(_len)]
      for number, key in enumerate(keys):
          h.add(key, number)
      assert len(h) == _len
      for key in keys:
          assert key in h
      for number, key in enumerate(keys):
          print(len(h))
          print('bef', h)
          removed = h.remove(key)
          assert removed == number
          print('aft', h)
          print(len(h))
      assert len(h) == 0


  test_HashMap()