Source code for examples.nested_sets.nested_sets

  1. """Celko's "Nested Sets" Tree Structure.
  2. http://www.intelligententerprise.com/001020/celko.jhtml
  3. """
  4. from sqlalchemy import case
  5. from sqlalchemy import Column
  6. from sqlalchemy import create_engine
  7. from sqlalchemy import event
  8. from sqlalchemy import func
  9. from sqlalchemy import Integer
  10. from sqlalchemy import select
  11. from sqlalchemy import String
  12. from sqlalchemy.ext.declarative import declarative_base
  13. from sqlalchemy.orm import aliased
  14. from sqlalchemy.orm import Session
  15. Base = declarative_base()
  16. class Employee(Base):
  17. __tablename__ = "personnel"
  18. __mapper_args__ = {
  19. "batch": False # allows extension to fire for each
  20. # instance before going to the next.
  21. }
  22. parent = None
  23. emp = Column(String, primary_key=True)
  24. left = Column("lft", Integer, nullable=False)
  25. right = Column("rgt", Integer, nullable=False)
  26. def __repr__(self):
  27. return "Employee(%s, %d, %d)" % (self.emp, self.left, self.right)
  28. @event.listens_for(Employee, "before_insert")
  29. def before_insert(mapper, connection, instance):
  30. if not instance.parent:
  31. instance.left = 1
  32. instance.right = 2
  33. else:
  34. personnel = mapper.mapped_table
  35. right_most_sibling = connection.scalar(
  36. select(personnel.c.rgt).where(
  37. personnel.c.emp == instance.parent.emp
  38. )
  39. )
  40. connection.execute(
  41. personnel.update(personnel.c.rgt >= right_most_sibling).values(
  42. lft=case(
  43. [
  44. (
  45. personnel.c.lft > right_most_sibling,
  46. personnel.c.lft + 2,
  47. )
  48. ],
  49. else_=personnel.c.lft,
  50. ),
  51. rgt=case(
  52. [
  53. (
  54. personnel.c.rgt >= right_most_sibling,
  55. personnel.c.rgt + 2,
  56. )
  57. ],
  58. else_=personnel.c.rgt,
  59. ),
  60. )
  61. )
  62. instance.left = right_most_sibling
  63. instance.right = right_most_sibling + 1
  64. # before_update() would be needed to support moving of nodes
  65. # after_delete() would be needed to support removal of nodes.
  66. engine = create_engine("sqlite://", echo=True)
  67. Base.metadata.create_all(engine)
  68. session = Session(bind=engine)
  69. albert = Employee(emp="Albert")
  70. bert = Employee(emp="Bert")
  71. chuck = Employee(emp="Chuck")
  72. donna = Employee(emp="Donna")
  73. eddie = Employee(emp="Eddie")
  74. fred = Employee(emp="Fred")
  75. bert.parent = albert
  76. chuck.parent = albert
  77. donna.parent = chuck
  78. eddie.parent = chuck
  79. fred.parent = chuck
  80. # the order of "add" is important here. elements must be added in
  81. # the order in which they should be INSERTed.
  82. session.add_all([albert, bert, chuck, donna, eddie, fred])
  83. session.commit()
  84. print(session.query(Employee).all())
  85. # 1. Find an employee and all their supervisors, no matter how deep the tree.
  86. ealias = aliased(Employee)
  87. print(
  88. session.query(Employee)
  89. .filter(ealias.left.between(Employee.left, Employee.right))
  90. .filter(ealias.emp == "Eddie")
  91. .all()
  92. )
  93. # 2. Find the employee and all their subordinates.
  94. # (This query has a nice symmetry with the first query.)
  95. print(
  96. session.query(Employee)
  97. .filter(Employee.left.between(ealias.left, ealias.right))
  98. .filter(ealias.emp == "Chuck")
  99. .all()
  100. )
  101. # 3. Find the level of each node, so you can print the tree
  102. # as an indented listing.
  103. for indentation, employee in (
  104. session.query(func.count(Employee.emp).label("indentation") - 1, ealias)
  105. .filter(ealias.left.between(Employee.left, Employee.right))
  106. .group_by(ealias.emp)
  107. .order_by(ealias.left)
  108. ):
  109. print(" " * indentation + str(employee))