Search Similar 3D Meshes

In this tutorial, we will learn how to build a 3D mesh search pipeline with Jina. In particular, we will be building a search pipeline for 3D models in GLB format.

As with other data types, the 3D mesh search pipeline consists of loading, encoding and indexing the data. Once the data is indexed, we can search it.

Prerequisites

Let’s first install the following PyPI dependencies (Jina itself is assumed to be installed already; the code in this tutorial uses the Jina 2.x API):

    pip install tensorflow trimesh pyrender

Load GLB data

First, given a GLB file, how do we load it and craft it into a Document that we can process and encode? Let’s use trimesh to build an Executor for this.

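    from typing import Optional

    import trimesh
    from jina import DocumentArray, Executor, requests


    def as_mesh(scene: trimesh.Scene) -> Optional[trimesh.Trimesh]:
        # merge all geometries of the Scene into a single Trimesh
        if len(scene.geometry) == 0:
            return None
        return trimesh.util.concatenate(
            tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
                  for g in scene.geometry.values()))


    class GlbCrafter(Executor):
        @requests(on=['/index', '/search'])
        def craft(self, docs: DocumentArray, **kwargs):
            for d in docs:
                # a GLB file is loaded as a trimesh.Scene
                mesh = as_mesh(trimesh.load_mesh(d.uri))
                if mesh is None:
                    raise ValueError(f'no geometry found in {d.uri}')
                # 2048 points sampled on the surface -> ndarray of shape (2048, 3)
                d.blob = mesh.sample(2048)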

We first load each GLB file as a Python object, using the trimesh package to represent the GLB data as triangular meshes. The loaded object is a trimesh.Scene, which may contain one or more triangular mesh geometries, so as_mesh combines all the meshes in the Scene into a single Trimesh. We then sample 2048 points from the surface of this mesh, so each 3D model is represented by an ndarray of shape (2048, 3).
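To make these two steps concrete, here is a minimal NumPy sketch of what concatenation and surface sampling do conceptually. Concatenating meshes means stacking their vertex arrays and shifting each mesh's face indices by the number of vertices that precede it; uniform surface sampling picks triangles with probability proportional to their area and then draws a random point inside each chosen triangle. The helpers `concatenate_meshes` and `sample_surface` are hypothetical names for illustration, not trimesh's actual implementation.

```python
import numpy as np


def concatenate_meshes(meshes):
    """Merge (vertices, faces) pairs into one mesh, offsetting face indices."""
    vertices, faces, offset = [], [], 0
    for v, f in meshes:
        vertices.append(v)
        faces.append(f + offset)  # faces must index into the stacked vertex array
        offset += len(v)
    return np.vstack(vertices), np.vstack(faces)


def sample_surface(vertices, faces, count, seed=0):
    """Draw `count` points uniformly (by area) from the surface of a triangle mesh."""
    rng = np.random.default_rng(seed)
    tri = vertices[faces]  # (F, 3, 3): the three corners of every face
    a, b, c = tri[:, 0], tri[:, 1], tri[:, 2]
    areas = 0.5 * np.linalg.norm(np.cross(b - a, c - a), axis=1)
    # choose faces with probability proportional to their area
    idx = rng.choice(len(faces), size=count, p=areas / areas.sum())
    # random barycentric coordinates; points falling outside the triangle are folded back in
    u, v = rng.random((2, count, 1))
    outside = (u + v) > 1
    u = np.where(outside, 1 - u, u)
    v = np.where(outside, 1 - v, v)
    return a[idx] + u * (b[idx] - a[idx]) + v * (c[idx] - a[idx])


# two single-triangle "geometries", as a GLB scene might contain
tri1 = (np.array([[0., 0, 0], [1, 0, 0], [0, 1, 0]]), np.array([[0, 1, 2]]))
tri2 = (np.array([[0., 0, 1], [1, 0, 1], [0, 1, 1]]), np.array([[0, 1, 2]]))
verts, faces = concatenate_meshes([tri1, tri2])
points = sample_surface(verts, faces, 2048)
print(faces.tolist())  # [[0, 1, 2], [3, 4, 5]]
print(points.shape)    # (2048, 3)
```

Like as_mesh, which reads `g.vertices` straight from `scene.geometry`, this sketch does not apply any per-node transforms from a scene graph.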

Encode 3D Model

Once each GLB model is converted into an ndarray, encoding the inputs is straightforward. We use our pre-trained PointNet model to encode the data. The model looks like this:

    def get_model(ckpt_path):
        import numpy as np
        import tensorflow as tf
        from tensorflow import keras
        from tensorflow.keras import layers

        def conv_bn(x, filters):
            # kernel_size=1: the same dense layer is applied to every point
            x = layers.Conv1D(filters, kernel_size=1, padding='valid')(x)
            x = layers.BatchNormalization(momentum=0.0)(x)
            return layers.Activation('relu')(x)

        def dense_bn(x, filters):
            x = layers.Dense(filters)(x)
            x = layers.BatchNormalization(momentum=0.0)(x)
            return layers.Activation('relu')(x)

        def tnet(inputs, num_features):
            # T-Net: predicts a num_features x num_features transform of its input
            class OrthogonalRegularizer(keras.regularizers.Regularizer):
                # penalty intended to keep the predicted transform close to orthogonal
                def __init__(self, num_features_, l2reg=0.001):
                    self.num_features = num_features_
                    self.l2reg = l2reg
                    self.eye = tf.eye(self.num_features)

                def __call__(self, x):
                    x = tf.reshape(x, (-1, self.num_features, self.num_features))
                    xxt = tf.tensordot(x, x, axes=(2, 2))
                    xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
                    return tf.reduce_sum(self.l2reg * tf.square(xxt - self.eye))

                def get_config(self):
                    return {'num_features': self.num_features,
                            'l2reg': self.l2reg,
                            'eye': self.eye.numpy()}

            # the transform is initialised to the identity matrix
            bias = keras.initializers.Constant(np.eye(num_features).flatten())
            reg = OrthogonalRegularizer(num_features)
            x = conv_bn(inputs, 32)
            x = conv_bn(x, 64)
            x = conv_bn(x, 512)
            x = layers.GlobalMaxPooling1D()(x)
            x = dense_bn(x, 256)
            x = dense_bn(x, 128)
            x = layers.Dense(
                num_features * num_features,
                kernel_initializer='zeros',
                bias_initializer=bias,
                activity_regularizer=reg,
            )(x)
            feat_T = layers.Reshape((num_features, num_features))(x)
            # apply the predicted transform to the input features
            return layers.Dot(axes=(2, 1))([inputs, feat_T])

        inputs = keras.Input(shape=(2048, 3))  # 2048 sampled points, xyz each
        x = tnet(inputs, 3)
        x = conv_bn(x, 32)
        x = conv_bn(x, 32)
        x = tnet(x, 32)
        x = conv_bn(x, 32)
        x = conv_bn(x, 64)
        x = layers.GlobalMaxPooling1D()(x)
        x = dense_bn(x, 128)
        x = layers.Dropout(0.3)(x)
        outputs = layers.Dense(1, activation='softmax')(x)
        model = keras.Model(inputs=inputs, outputs=outputs, name='pointnet')
        # the embeddings are the output of the layer named 'dense_1'
        intermediate_layer_model = keras.Model(inputs=model.input,
                                               outputs=model.get_layer('dense_1').output)
        intermediate_layer_model.load_weights(ckpt_path)
        return intermediate_layer_model
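Two ideas carry much of this architecture. A `Conv1D` layer with `kernel_size=1` applies the same dense layer to every point independently, and `GlobalMaxPooling1D` then takes a per-feature maximum over all points, so the pooled feature does not depend on the order of the 2048 points. The NumPy sketch below checks both properties with random weights; it illustrates the principle only and is not the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
points = rng.normal(size=(2048, 3))  # one point cloud, as produced by GlbCrafter
W, b = rng.normal(size=(3, 64)), rng.normal(size=64)


def shared_mlp(x):
    # kernel_size=1 convolution: the same affine map + ReLU applied to each point
    return np.maximum(x @ W + b, 0.0)


def per_point_loop(x):
    # explicit loop over the points, for comparison with the vectorised version
    return np.stack([np.maximum(p @ W + b, 0.0) for p in x])


def global_feature(x):
    # GlobalMaxPooling1D: maximum over the point axis
    return shared_mlp(x).max(axis=0)


shuffled = points[rng.permutation(len(points))]
print(np.allclose(shared_mlp(points), per_point_loop(points)))        # True
print(np.allclose(global_feature(points), global_feature(shuffled)))  # True
```

Permutation invariance matters here because `mesh.sample` returns points in no meaningful order, so the same model sampled twice should still map to nearby embeddings.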

With this model, we can build our PointNet encoder Executor:

    from jina import DocumentArray, Executor, requests


    class PNEncoder(Executor):
        def __init__(self, ckpt_path: str, **kwargs):
            super().__init__(**kwargs)
            # load the PointNet sub-model once, when the Executor starts
            self.embedding_model = get_model(ckpt_path=ckpt_path)

        @requests(on=['/index', '/search'])
        def encode(self, docs: DocumentArray, **kwargs):
            # docs.blobs: (N, 2048, 3) -> docs.embeddings: one row per doc
            docs.embeddings = self.embedding_model.predict(docs.blobs)

Tips

Instead of iterating over each doc to set its embedding, we can use the blobs attribute to get the blobs of all docs at once, and the embeddings attribute to set all their embeddings at once.
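The difference can be sketched without Jina: stack the per-document arrays into one batch of shape (N, 2048, 3), run the model once, and assign row i of the result to document i. `ToyDoc` and `toy_model` below are hypothetical stand-ins for Jina's `Document` and the Keras model.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class ToyDoc:
    blob: np.ndarray
    embedding: Optional[np.ndarray] = None


def toy_model(batch):
    # stand-in for embedding_model.predict: (N, 2048, 3) -> (N, 3)
    return batch.max(axis=1)


rng = np.random.default_rng(0)
docs = [ToyDoc(rng.normal(size=(2048, 3))) for _ in range(4)]

# one model call per document
per_doc = [toy_model(d.blob[None])[0] for d in docs]

# one model call for the whole batch, like docs.embeddings = predict(docs.blobs)
batch = np.stack([d.blob for d in docs])  # (4, 2048, 3)
for d, emb in zip(docs, toy_model(batch)):
    d.embedding = emb

print(all(np.allclose(d.embedding, e) for d, e in zip(docs, per_doc)))  # True
```

Both paths give the same embeddings; with a real Keras model, a single `predict` call on the batch avoids the per-call overhead of N separate calls.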

Index the data

Let’s also build an indexer to index the data.

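    from jina import DocumentArray, Executor, requests


    class MyIndexer(Executor):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # per-instance storage (a class attribute would be shared by all instances)
            self._docs = DocumentArray()

        @requests(on='/index')
        def index(self, docs: DocumentArray, **kwargs):
            self._docs.extend(docs)

        @requests(on='/search')
        def search(self, docs: DocumentArray, **kwargs):
            # attach the 5 closest indexed docs to each query as its matches
            docs.match(self._docs, limit=5)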

The indexer above simply stores all indexed docs in a DocumentArray and uses the match function of DocumentArray to find, for each query, the 5 closest indexed docs.
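Conceptually, matching compares every query embedding with every indexed embedding and keeps the `limit` closest ones. Below is a minimal NumPy sketch, assuming cosine distance as the metric (please check the default of `DocumentArray.match` in your Jina version); `cosine_top_k` is a hypothetical helper, not Jina's implementation.

```python
import numpy as np


def cosine_top_k(queries, index, k):
    """Return (indices, distances) of the k nearest index rows for each query."""
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    x = index / np.linalg.norm(index, axis=1, keepdims=True)
    dist = 1.0 - q @ x.T                   # cosine distance, shape (n_queries, n_index)
    idx = np.argsort(dist, axis=1)[:, :k]  # smallest distance first
    return idx, np.take_along_axis(dist, idx, axis=1)


rng = np.random.default_rng(0)
index = rng.normal(size=(10, 128))  # 10 indexed embeddings
query = index[3:4]                  # a query whose embedding is also indexed
idx, dist = cosine_top_k(query, index, k=5)
print(idx[0][0])  # 3
```

This also shows why an indexed query ranks first among its own matches: its distance to itself is zero.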

Visualize 3D Model

Finally, let’s also build the GlbVisualizer to visualize the results.

    from typing import List, Optional

    import trimesh
    import pyrender
    from pyrender import Viewer
    # pyglet dependencies should be imported AFTER pyrender
    import pyglet
    from pyglet import clock
    from pyglet.gl import Config

    # as_mesh is the helper defined in the "Load GLB data" step


    def _init_and_start_app(self):
        # pyrender's Viewer._init_and_start_app without the final pyglet.app.run(),
        # so that creating a Viewer opens its window but does not block
        TARGET_OPEN_GL_MAJOR = 4  # Target OpenGL Major Version
        TARGET_OPEN_GL_MINOR = 1  # Target OpenGL Minor Version
        MIN_OPEN_GL_MAJOR = 3  # Minimum OpenGL Major Version
        MIN_OPEN_GL_MINOR = 3  # Minimum OpenGL Minor Version
        # try configs from most to least demanding (multisampling may be unavailable)
        confs = [Config(sample_buffers=1, samples=4,
                        depth_size=24,
                        double_buffer=True,
                        major_version=TARGET_OPEN_GL_MAJOR,
                        minor_version=TARGET_OPEN_GL_MINOR),
                 Config(depth_size=24,
                        double_buffer=True,
                        major_version=TARGET_OPEN_GL_MAJOR,
                        minor_version=TARGET_OPEN_GL_MINOR),
                 Config(sample_buffers=1, samples=4,
                        depth_size=24,
                        double_buffer=True,
                        major_version=MIN_OPEN_GL_MAJOR,
                        minor_version=MIN_OPEN_GL_MINOR),
                 Config(depth_size=24,
                        double_buffer=True,
                        major_version=MIN_OPEN_GL_MAJOR,
                        minor_version=MIN_OPEN_GL_MINOR)]
        for conf in confs:
            try:
                super(Viewer, self).__init__(config=conf, resizable=True,
                                             width=self._viewport_size[0],
                                             height=self._viewport_size[1])
                break
            except pyglet.window.NoSuchConfigException:
                pass
        if not self.context:
            raise ValueError('Unable to initialize an OpenGL 3+ context')
        clock.schedule_interval(
            Viewer._time_event, 1.0 / self.viewer_flags['refresh_rate'], self
        )
        self.switch_to()
        self.set_caption(self.viewer_flags['window_title'])


    class GlbVisualizer:
        def __init__(self, search_doc, matches: Optional[List] = None):
            self.search_doc = search_doc
            self.matches = matches
            # patch pyrender so each Viewer opens its window without starting the event loop
            self.orig_func = pyrender.Viewer._init_and_start_app
            pyrender.Viewer._init_and_start_app = _init_and_start_app

        def visualize(self):
            self.add(self.search_doc.uri, 'Query Doc')
            if self.matches:
                for i, match in enumerate(self.matches, start=1):
                    self.add(match.uri, f'Top {i} Match')
            # start a single event loop for all windows
            pyglet.app.run()

        def add(self, uri, title):
            fuze_trimesh = as_mesh(trimesh.load(uri))
            mesh = pyrender.Mesh.from_trimesh(fuze_trimesh)
            scene = pyrender.Scene()
            scene.add(mesh)
            pyrender.Viewer(
                scene,
                use_raymond_lighting=True,
                viewer_flags={
                    'rotate': True,
                    'window_title': title,
                    'caption': [{
                        'font_name': 'OpenSans-Regular',
                        'font_pt': 30,
                        'color': None,
                        'scale': 1.0,
                        'location': 4,
                        'text': title
                    }]
                },
            )

        def __del__(self):
            # restore pyrender's original method
            pyrender.Viewer._init_and_start_app = self.orig_func

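The visualizer uses pyrender to render the query and its matches. By default, creating a pyrender Viewer starts the pyglet event loop, which blocks until the window is closed. Since we want to display multiple models at once, we patch Viewer._init_and_start_app with a version that opens the window without starting the loop; visualize() then starts the loop once, after all viewers are initialized.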
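The same patch-then-restore idea can be written as a context manager, which restores the original method even if an exception is raised and does not depend on when `__del__` happens to run. The sketch below uses a hypothetical `ToyViewer` class in place of `pyrender.Viewer`, so it runs without a display; it is an alternative pattern, not the code used in this tutorial.

```python
from contextlib import contextmanager


class ToyViewer:
    """Hypothetical stand-in for pyrender.Viewer: its constructor starts a blocking loop."""
    events = []

    def __init__(self, title):
        self.title = title
        self._init_and_start_app()

    def _init_and_start_app(self):
        ToyViewer.events.append(f'open {self.title}')
        ToyViewer.events.append('run loop')  # stands in for the blocking event loop


def _open_without_running(self):
    # same as the original, minus the call that starts the event loop
    ToyViewer.events.append(f'open {self.title}')


@contextmanager
def patched_method(cls, name, replacement):
    original = getattr(cls, name)
    setattr(cls, name, replacement)
    try:
        yield
    finally:
        setattr(cls, name, original)  # restored even if the body raises


with patched_method(ToyViewer, '_init_and_start_app', _open_without_running):
    for title in ['Query Doc', 'Top 1 Match', 'Top 2 Match']:
        ToyViewer(title)
    ToyViewer.events.append('run loop')  # one shared loop for all windows

print(ToyViewer.events)
# ['open Query Doc', 'open Top 1 Match', 'open Top 2 Match', 'run loop']
```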

Index, Search and Visualize Data

Download the pre-trained PNEncoder model here into model/ckpt/, and store the GLB files to index and search in data/. We can then chain the Executors into a Flow and use it to index and search. Finally, we use the GlbVisualizer built earlier to display the query and its top matches.

    flow = (
        Flow()
        .add(uses=GlbCrafter)
        .add(uses=PNEncoder, uses_with={'ckpt_path': 'model/ckpt/ckpt_True'})
        .add(uses=MyIndexer)
    )
    with flow as f:
        f.index(from_files('data/*.glb'))
        results = f.search(Document(uri='data/rifle_16.glb'), return_results=True)
        doc = results[0].docs[0]
        # visualize the top 3 matches; since the query doc is also indexed,
        # the top-1 match is the query itself, so we skip it
        GlbVisualizer(doc, matches=doc.matches[1:4]).visualize()
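Slicing `doc.matches[1:4]` relies on the query always being ranked first. If the index can contain an exact duplicate of the query (also at distance 0), the duplicate may be ranked ahead of it. A slightly more defensive sketch filters by URI instead; `Match` is a hypothetical stand-in for a matched Document with only the fields needed here, and the file names are illustrative.

```python
from typing import List, NamedTuple


class Match(NamedTuple):
    uri: str
    distance: float


def top_matches_excluding_query(query_uri: str, matches: List[Match], k: int = 3) -> List[Match]:
    # drop the query itself wherever it appears in the ranking, then keep the top k
    return [m for m in matches if m.uri != query_uri][:k]


matches = [
    Match('data/copy_of_query.glb', 0.0),  # duplicate ranked ahead of the query
    Match('data/query.glb', 0.0),          # the query itself
    Match('data/b.glb', 0.12),
    Match('data/c.glb', 0.15),
    Match('data/d.glb', 0.61),
]
print([m.uri for m in top_matches_excluding_query('data/query.glb', matches)])
# ['data/copy_of_query.glb', 'data/b.glb', 'data/c.glb']
```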

This is what the Flow we built looks like:

../../../_images/flow.png

Putting it all together

Combining the steps above and importing the necessary dependencies gives the following complete code.

Complete source code


    from typing import List, Optional

    from jina import Document, DocumentArray, Executor, Flow, requests
    from jina.types.document.generators import from_files
    import trimesh
    import pyrender
    from pyrender import Viewer
    # pyglet dependencies should be imported AFTER pyrender
    import pyglet
    from pyglet import clock
    from pyglet.gl import Config


    def as_mesh(scene: trimesh.Scene) -> Optional[trimesh.Trimesh]:
        if len(scene.geometry) == 0:
            return None
        return trimesh.util.concatenate(
            tuple(
                trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
                for g in scene.geometry.values()
            )
        )


    class GlbCrafter(Executor):
        @requests(on=['/index', '/search'])
        def craft(self, docs: DocumentArray, **kwargs):
            for d in docs:
                mesh = as_mesh(trimesh.load_mesh(d.uri))
                if mesh is None:
                    raise ValueError(f'no geometry found in {d.uri}')
                d.blob = mesh.sample(2048)


    def get_model(ckpt_path):
        import numpy as np
        import tensorflow as tf
        from tensorflow import keras
        from tensorflow.keras import layers

        def conv_bn(x, filters):
            x = layers.Conv1D(filters, kernel_size=1, padding='valid')(x)
            x = layers.BatchNormalization(momentum=0.0)(x)
            return layers.Activation('relu')(x)

        def dense_bn(x, filters):
            x = layers.Dense(filters)(x)
            x = layers.BatchNormalization(momentum=0.0)(x)
            return layers.Activation('relu')(x)

        def tnet(inputs, num_features):
            class OrthogonalRegularizer(keras.regularizers.Regularizer):
                def __init__(self, num_features_, l2reg=0.001):
                    self.num_features = num_features_
                    self.l2reg = l2reg
                    self.eye = tf.eye(self.num_features)

                def __call__(self, x):
                    x = tf.reshape(x, (-1, self.num_features, self.num_features))
                    xxt = tf.tensordot(x, x, axes=(2, 2))
                    xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
                    return tf.reduce_sum(self.l2reg * tf.square(xxt - self.eye))

                def get_config(self):
                    return {
                        'num_features': self.num_features,
                        'l2reg': self.l2reg,
                        'eye': self.eye.numpy(),
                    }

            bias = keras.initializers.Constant(np.eye(num_features).flatten())
            reg = OrthogonalRegularizer(num_features)
            x = conv_bn(inputs, 32)
            x = conv_bn(x, 64)
            x = conv_bn(x, 512)
            x = layers.GlobalMaxPooling1D()(x)
            x = dense_bn(x, 256)
            x = dense_bn(x, 128)
            x = layers.Dense(
                num_features * num_features,
                kernel_initializer='zeros',
                bias_initializer=bias,
                activity_regularizer=reg,
            )(x)
            feat_T = layers.Reshape((num_features, num_features))(x)
            return layers.Dot(axes=(2, 1))([inputs, feat_T])

        inputs = keras.Input(shape=(2048, 3))
        x = tnet(inputs, 3)
        x = conv_bn(x, 32)
        x = conv_bn(x, 32)
        x = tnet(x, 32)
        x = conv_bn(x, 32)
        x = conv_bn(x, 64)
        x = layers.GlobalMaxPooling1D()(x)
        x = dense_bn(x, 128)
        x = layers.Dropout(0.3)(x)
        outputs = layers.Dense(1, activation='softmax')(x)
        model = keras.Model(inputs=inputs, outputs=outputs, name='pointnet')
        intermediate_layer_model = keras.Model(
            inputs=model.input, outputs=model.get_layer('dense_1').output
        )
        intermediate_layer_model.load_weights(ckpt_path)
        return intermediate_layer_model


    class PNEncoder(Executor):
        def __init__(self, ckpt_path: str, **kwargs):
            super().__init__(**kwargs)
            self.embedding_model = get_model(ckpt_path=ckpt_path)

        @requests(on=['/index', '/search'])
        def encode(self, docs: DocumentArray, **kwargs):
            docs.embeddings = self.embedding_model.predict(docs.blobs)


    class MyIndexer(Executor):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._docs = DocumentArray()

        @requests(on='/index')
        def index(self, docs: DocumentArray, **kwargs):
            self._docs.extend(docs)

        @requests(on='/search')
        def search(self, docs: DocumentArray, **kwargs):
            docs.match(self._docs, limit=5)


    def _init_and_start_app(self):
        TARGET_OPEN_GL_MAJOR = 4  # Target OpenGL Major Version
        TARGET_OPEN_GL_MINOR = 1  # Target OpenGL Minor Version
        MIN_OPEN_GL_MAJOR = 3  # Minimum OpenGL Major Version
        MIN_OPEN_GL_MINOR = 3  # Minimum OpenGL Minor Version
        confs = [
            Config(
                sample_buffers=1,
                samples=4,
                depth_size=24,
                double_buffer=True,
                major_version=TARGET_OPEN_GL_MAJOR,
                minor_version=TARGET_OPEN_GL_MINOR,
            ),
            Config(
                depth_size=24,
                double_buffer=True,
                major_version=TARGET_OPEN_GL_MAJOR,
                minor_version=TARGET_OPEN_GL_MINOR,
            ),
            Config(
                sample_buffers=1,
                samples=4,
                depth_size=24,
                double_buffer=True,
                major_version=MIN_OPEN_GL_MAJOR,
                minor_version=MIN_OPEN_GL_MINOR,
            ),
            Config(
                depth_size=24,
                double_buffer=True,
                major_version=MIN_OPEN_GL_MAJOR,
                minor_version=MIN_OPEN_GL_MINOR,
            ),
        ]
        for conf in confs:
            try:
                super(Viewer, self).__init__(
                    config=conf,
                    resizable=True,
                    width=self._viewport_size[0],
                    height=self._viewport_size[1],
                )
                break
            except pyglet.window.NoSuchConfigException:
                pass
        if not self.context:
            raise ValueError('Unable to initialize an OpenGL 3+ context')
        clock.schedule_interval(
            Viewer._time_event, 1.0 / self.viewer_flags['refresh_rate'], self
        )
        self.switch_to()
        self.set_caption(self.viewer_flags['window_title'])


    class GlbVisualizer:
        def __init__(self, search_doc, matches: Optional[List] = None):
            self.search_doc = search_doc
            self.matches = matches
            self.orig_func = pyrender.Viewer._init_and_start_app
            pyrender.Viewer._init_and_start_app = _init_and_start_app

        def visualize(self):
            self.add(self.search_doc.uri, 'Query Doc')
            if self.matches:
                for i, match in enumerate(self.matches, start=1):
                    self.add(match.uri, f'Top {i} Match')
            pyglet.app.run()

        def add(self, uri, title):
            scene = pyrender.Scene()
            scene.add(pyrender.Mesh.from_trimesh(as_mesh(trimesh.load(uri))))
            pyrender.Viewer(
                scene,
                use_raymond_lighting=True,
                viewer_flags={
                    'rotate': True,
                    'window_title': title,
                    'caption': [
                        {
                            'font_name': 'OpenSans-Regular',
                            'font_pt': 30,
                            'color': None,
                            'scale': 1.0,
                            'location': 4,
                            'text': title,
                        }
                    ],
                },
            )

        def __del__(self):
            pyrender.Viewer._init_and_start_app = self.orig_func


    flow = (
        Flow()
        .add(uses=GlbCrafter)
        .add(uses=PNEncoder, uses_with={'ckpt_path': 'model/ckpt/ckpt_True'})
        .add(uses=MyIndexer)
    )
    with flow as f:
        f.index(from_files('data/*.glb'))
        results = f.search(Document(uri='data/rifle_16.glb'), return_results=True)
        doc = results[0].docs[0]
        # the query doc is indexed too, so skip the top-1 match and show the next 3
        GlbVisualizer(doc, matches=doc.matches[1:4]).visualize()

Import warning

Note that pyrender has to be imported before any pyglet modules; otherwise, an error is raised in some environments, such as macOS.

Results

Now let’s take a look at the search results! Below is the rifle_16.glb 3D model we would like to search for:

../../../_images/query_doc.gif

And the following are the top 3 matches:

../../../_images/top_1.gif

../../../_images/top_2.gif

../../../_images/top_3.gif

Congratulations! You have just built a 3D Mesh Search Pipeline!