Source code for examples.postgis.postgis

import binascii

from sqlalchemy import event
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import Table
from sqlalchemy.sql import expression
from sqlalchemy.sql import type_coerce
from sqlalchemy.types import UserDefinedType
  7. # Python datatypes
  8. class GisElement(object):
  9. """Represents a geometry value."""
  10. def __str__(self):
  11. return self.desc
  12. def __repr__(self):
  13. return "<%s at 0x%x; %r>" % (
  14. self.__class__.__name__,
  15. id(self),
  16. self.desc,
  17. )
  18. class BinaryGisElement(GisElement, expression.Function):
  19. """Represents a Geometry value expressed as binary."""
  20. def __init__(self, data):
  21. self.data = data
  22. expression.Function.__init__(
  23. self, "ST_GeomFromEWKB", data, type_=Geometry(coerce_="binary")
  24. )
  25. @property
  26. def desc(self):
  27. return self.as_hex
  28. @property
  29. def as_hex(self):
  30. return binascii.hexlify(self.data)
  31. class TextualGisElement(GisElement, expression.Function):
  32. """Represents a Geometry value expressed as text."""
  33. def __init__(self, desc, srid=-1):
  34. self.desc = desc
  35. expression.Function.__init__(
  36. self, "ST_GeomFromText", desc, srid, type_=Geometry
  37. )
  38. # SQL datatypes.
  39. class Geometry(UserDefinedType):
  40. """Base PostGIS Geometry column type."""
  41. name = "GEOMETRY"
  42. def __init__(self, dimension=None, srid=-1, coerce_="text"):
  43. self.dimension = dimension
  44. self.srid = srid
  45. self.coerce = coerce_
  46. class comparator_factory(UserDefinedType.Comparator):
  47. """Define custom operations for geometry types."""
  48. # override the __eq__() operator
  49. def __eq__(self, other):
  50. return self.op("~=")(other)
  51. # add a custom operator
  52. def intersects(self, other):
  53. return self.op("&&")(other)
  54. # any number of GIS operators can be overridden/added here
  55. # using the techniques above.
  56. def _coerce_compared_value(self, op, value):
  57. return self
  58. def get_col_spec(self):
  59. return self.name
  60. def bind_expression(self, bindvalue):
  61. if self.coerce == "text":
  62. return TextualGisElement(bindvalue)
  63. elif self.coerce == "binary":
  64. return BinaryGisElement(bindvalue)
  65. else:
  66. assert False
  67. def column_expression(self, col):
  68. if self.coerce == "text":
  69. return func.ST_AsText(col, type_=self)
  70. elif self.coerce == "binary":
  71. return func.ST_AsBinary(col, type_=self)
  72. else:
  73. assert False
  74. def bind_processor(self, dialect):
  75. def process(value):
  76. if isinstance(value, GisElement):
  77. return value.desc
  78. else:
  79. return value
  80. return process
  81. def result_processor(self, dialect, coltype):
  82. if self.coerce == "text":
  83. fac = TextualGisElement
  84. elif self.coerce == "binary":
  85. fac = BinaryGisElement
  86. else:
  87. assert False
  88. def process(value):
  89. if value is not None:
  90. return fac(value)
  91. else:
  92. return value
  93. return process
  94. def adapt(self, impltype):
  95. return impltype(
  96. dimension=self.dimension, srid=self.srid, coerce_=self.coerce
  97. )
  98. # other datatypes can be added as needed.
  99. class Point(Geometry):
  100. name = "POINT"
  101. class Curve(Geometry):
  102. name = "CURVE"
  103. class LineString(Curve):
  104. name = "LINESTRING"
  105. # ... etc.
  106. # DDL integration
  107. # PostGIS historically has required AddGeometryColumn/DropGeometryColumn
  108. # and other management methods in order to create PostGIS columns. Newer
  109. # versions don't appear to require these special steps anymore. However,
  110. # here we illustrate how to set up these features in any case.
  111. def setup_ddl_events():
  112. @event.listens_for(Table, "before_create")
  113. def before_create(target, connection, **kw):
  114. dispatch("before-create", target, connection)
  115. @event.listens_for(Table, "after_create")
  116. def after_create(target, connection, **kw):
  117. dispatch("after-create", target, connection)
  118. @event.listens_for(Table, "before_drop")
  119. def before_drop(target, connection, **kw):
  120. dispatch("before-drop", target, connection)
  121. @event.listens_for(Table, "after_drop")
  122. def after_drop(target, connection, **kw):
  123. dispatch("after-drop", target, connection)
  124. def dispatch(event, table, bind):
  125. if event in ("before-create", "before-drop"):
  126. regular_cols = [
  127. c for c in table.c if not isinstance(c.type, Geometry)
  128. ]
  129. gis_cols = set(table.c).difference(regular_cols)
  130. table.info["_saved_columns"] = table.c
  131. # temporarily patch a set of columns not including the
  132. # Geometry columns
  133. table.columns = expression.ColumnCollection(*regular_cols)
  134. if event == "before-drop":
  135. for c in gis_cols:
  136. bind.execute(
  137. select(
  138. func.DropGeometryColumn(
  139. "public", table.name, c.name
  140. )
  141. ).execution_options(autocommit=True)
  142. )
  143. elif event == "after-create":
  144. table.columns = table.info.pop("_saved_columns")
  145. for c in table.c:
  146. if isinstance(c.type, Geometry):
  147. bind.execute(
  148. select(
  149. func.AddGeometryColumn(
  150. table.name,
  151. c.name,
  152. c.type.srid,
  153. c.type.name,
  154. c.type.dimension,
  155. )
  156. ).execution_options(autocommit=True)
  157. )
  158. elif event == "after-drop":
  159. table.columns = table.info.pop("_saved_columns")
  160. setup_ddl_events()
  161. # illustrate usage
  162. if __name__ == "__main__":
  163. from sqlalchemy import (
  164. create_engine,
  165. MetaData,
  166. Column,
  167. Integer,
  168. String,
  169. func,
  170. select,
  171. )
  172. from sqlalchemy.orm import sessionmaker
  173. from sqlalchemy.ext.declarative import declarative_base
  174. engine = create_engine(
  175. "postgresql://scott:tiger@localhost/test", echo=True
  176. )
  177. metadata = MetaData(engine)
  178. Base = declarative_base(metadata=metadata)
  179. class Road(Base):
  180. __tablename__ = "roads"
  181. road_id = Column(Integer, primary_key=True)
  182. road_name = Column(String)
  183. road_geom = Column(Geometry(2))
  184. metadata.drop_all()
  185. metadata.create_all()
  186. session = sessionmaker(bind=engine)()
  187. # Add objects. We can use strings...
  188. session.add_all(
  189. [
  190. Road(
  191. road_name="Jeff Rd",
  192. road_geom="LINESTRING(191232 243118,191108 243242)",
  193. ),
  194. Road(
  195. road_name="Geordie Rd",
  196. road_geom="LINESTRING(189141 244158,189265 244817)",
  197. ),
  198. Road(
  199. road_name="Paul St",
  200. road_geom="LINESTRING(192783 228138,192612 229814)",
  201. ),
  202. Road(
  203. road_name="Graeme Ave",
  204. road_geom="LINESTRING(189412 252431,189631 259122)",
  205. ),
  206. Road(
  207. road_name="Phil Tce",
  208. road_geom="LINESTRING(190131 224148,190871 228134)",
  209. ),
  210. ]
  211. )
  212. # or use an explicit TextualGisElement
  213. # (similar to saying func.GeomFromText())
  214. r = Road(
  215. road_name="Dave Cres",
  216. road_geom=TextualGisElement(
  217. "LINESTRING(198231 263418,198213 268322)", -1
  218. ),
  219. )
  220. session.add(r)
  221. # pre flush, the TextualGisElement represents the string we sent.
  222. assert str(r.road_geom) == "LINESTRING(198231 263418,198213 268322)"
  223. session.commit()
  224. # after flush and/or commit, all the TextualGisElements
  225. # become PersistentGisElements.
  226. assert str(r.road_geom) == "LINESTRING(198231 263418,198213 268322)"
  227. r1 = session.query(Road).filter(Road.road_name == "Graeme Ave").one()
  228. # illustrate the overridden __eq__() operator.
  229. # strings come in as TextualGisElements
  230. r2 = (
  231. session.query(Road)
  232. .filter(Road.road_geom == "LINESTRING(189412 252431,189631 259122)")
  233. .one()
  234. )
  235. r3 = session.query(Road).filter(Road.road_geom == r1.road_geom).one()
  236. assert r1 is r2 is r3
  237. # core usage just fine:
  238. road_table = Road.__table__
  239. stmt = select(road_table).where(
  240. road_table.c.road_geom.intersects(r1.road_geom)
  241. )
  242. print(session.execute(stmt).fetchall())
  243. # TODO: for some reason the auto-generated labels have the internal
  244. # replacement strings exposed, even though PG doesn't complain
  245. # look up the hex binary version, using SQLAlchemy casts
  246. as_binary = session.scalar(
  247. select(type_coerce(r.road_geom, Geometry(coerce_="binary")))
  248. )
  249. assert as_binary.as_hex == (
  250. "01020000000200000000000000b832084100000000"
  251. "e813104100000000283208410000000088601041"
  252. )
  253. # back again, same method !
  254. as_text = session.scalar(
  255. select(type_coerce(as_binary, Geometry(coerce_="text")))
  256. )
  257. assert as_text.desc == "LINESTRING(198231 263418,198213 268322)"
  258. session.rollback()
  259. metadata.drop_all()