Search Images from Text via the CLIP Model

In this tutorial, we will create an image search system that retrieves images using short text descriptions as queries.

The motivation is that in regular search, a description or metadata describing the content of each image has to be indexed before the images can be retrieved via a text query. This is expensive because it requires a person to write those descriptions, and information about image content is not always available.

We need to look for another solution! What if we could directly compare text with images?

To do so, we need to figure out a way to match images and text. One way is finding related images with similar semantics to the query text. This requires us to represent both images and query text in the same embedding space to be able to do the matching. In this case, pre-trained cross-modal models can help us out.

For example, when we type the word “dog” as a query, we want to retrieve pictures of dogs solely by using embedding similarity.
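
To make the matching idea concrete, here is a minimal sketch of similarity-based retrieval. The embeddings are hypothetical random vectors standing in for real CLIP embeddings; the 512-dimensional size matches the model we use later:

import numpy as np

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Hypothetical embeddings that live in the same 512-dimensional space
text_embedding = np.random.rand(512)        # embedding of the query "dog"
image_embeddings = np.random.rand(10, 512)  # embeddings of 10 indexed images

scores = [cosine(text_embedding, e) for e in image_embeddings]
best_image = int(np.argmax(scores))         # index of the most similar image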

Tip

The full source code of this tutorial is available in this Google Colab notebook.

Now that we understand the problem and have an idea of how to solve it, let’s imagine what the solution would look like:

  1. We have a bunch of images with no text description about the content.

  2. We use a model to create an embedding that represents those images.

  3. Now we index and save these embeddings, stored inside Documents, in a workspace folder.

This is what we call the index Flow.

../../../_images/index_flow_text2image.svg

Now, to search for images using text, we do the following:

  1. We embed the query text into the same embedding space as the image.

  2. We compute similarity between the query embedding and previously saved embeddings.

  3. We return the best results.

This is our query Flow.

../../../_images/query_flow_text2image.svg

If we had to build this from scratch, implementing these Flows would take a long time. Luckily, we can leverage Jina’s building blocks, namely Executors, Documents and Flows, to build such a system easily.
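
If you have not used these building blocks before, the toy snippet below (not part of the tutorial itself, just an illustration of the primitives, assuming Jina 2.x) shows how Documents travel through an Executor inside a Flow:

from jina import Document, DocumentArray, Executor, Flow, requests

class UpperCaser(Executor):
    """A toy Executor that upper-cases the text of every Document."""

    @requests
    def upper(self, docs: DocumentArray, **kwargs):
        for doc in docs:
            doc.text = doc.text.upper()

with Flow().add(uses=UpperCaser) as f:
    f.post(on='/', inputs=DocumentArray([Document(text='hello jina')]))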

Prerequisites

Before we begin building our Flows, we need to do a few things.

  • Install the following dependencies:

pip install Pillow jina torch==1.9.0 torchvision==0.10.0 transformers==4.9.1 matplotlib git+https://github.com/jina-ai/jina-commons.git#egg=jina-commons

  • Download the dataset. You can use the link or the following commands:

wget https://open-images.s3.eu-central-1.amazonaws.com/data.zip
unzip data.zip

You should find two folders after unzipping:

  • images: this folder contains the images that we will index.

  • query: this folder contains small images that we will use as search queries.

Building Executors

In this section, we will develop the necessary Executors for both the index and query Flows.

To encode images and query text into the same space, we choose the pre-trained CLIP model from OpenAI.

What is CLIP?

The CLIP model is trained to learn visual concepts from natural language. This is done using pairs of text snippets and images collected from the internet. In the original CLIP paper, the model performs zero-shot classification by encoding text labels and images with separate encoders and then computing the similarities between the encoded vectors.

In this tutorial, we use the image and the text encoding parts from CLIP to calculate the embeddings.

How does CLIP help?

Given the short text this is a dog, the CLIP text model can encode it into a vector. Meanwhile, the CLIP image model can encode one image of a dog and one image of a cat into the same vector space. We will find that the distance between the text vector and the vector of the dog image is smaller than the distance between the same text vector and the vector of the cat image, as the sketch below illustrates.
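
The sketch below shows this comparison directly with the transformers library; the image file names are hypothetical placeholders:

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Hypothetical local files: one dog picture and one cat picture
images = [Image.open("dog.jpg"), Image.open("cat.jpg")]
inputs = processor(
    text=["this is a dog"], images=images, return_tensors="pt", padding=True
)
with torch.no_grad():
    outputs = model(**inputs)

# logits_per_text has shape (num_texts, num_images); a higher score means
# the text and the image are closer in the shared embedding space
print(outputs.logits_per_text)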

CLIPImageEncoder

This encoder encodes images into embeddings using the CLIP model. We want an Executor that loads the CLIP model and encodes images during the index Flow.

Our Executor should:

  • Support both GPU and CPU: That’s why we will provision the device parameter and use it when encoding.

  • Be able to process Documents in batches in order to use our resources effectively: To do so, we will use the parameter batch_size.

from typing import Optional, Tuple

import torch
from jina import DocumentArray, Executor, requests
from transformers import CLIPFeatureExtractor, CLIPModel


class CLIPImageEncoder(Executor):
    """Encode image into embeddings using the CLIP model."""

    def __init__(
        self,
        pretrained_model_name_or_path: str = "openai/clip-vit-base-patch32",
        base_feature_extractor: Optional[str] = None,
        use_default_preprocessing: bool = True,
        device: str = "cpu",
        batch_size: int = 32,
        traversal_paths: Tuple = ("r",),
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.batch_size = batch_size
        self.traversal_paths = traversal_paths
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.use_default_preprocessing = use_default_preprocessing
        self.base_feature_extractor = (
            base_feature_extractor or pretrained_model_name_or_path
        )
        self.device = device
        self.preprocessor = CLIPFeatureExtractor.from_pretrained(
            self.base_feature_extractor
        )
        self.model = CLIPModel.from_pretrained(self.pretrained_model_name_or_path)
        self.model.to(self.device).eval()

    @requests
    def encode(self, docs: Optional[DocumentArray], parameters: dict, **kwargs):
        if docs is None:
            return
        # runtime parameters override the defaults set in the constructor
        traversal_paths = parameters.get("traversal_paths", self.traversal_paths)
        batch_size = parameters.get("batch_size", self.batch_size)
        document_batches_generator = docs.traverse_flat(traversal_paths).batch(
            batch_size=batch_size
        )
        with torch.inference_mode():
            for batch_docs in document_batches_generator:
                blob_batch = [d.blob for d in batch_docs]
                if self.use_default_preprocessing:
                    # let the CLIP feature extractor resize and normalize the images
                    tensor = self._generate_input_features(blob_batch)
                else:
                    tensor = {
                        "pixel_values": torch.tensor(
                            blob_batch, dtype=torch.float32, device=self.device
                        )
                    }
                embeddings = self.model.get_image_features(**tensor)
                embeddings = embeddings.cpu().numpy()
                for doc, embed in zip(batch_docs, embeddings):
                    doc.embedding = embed

    def _generate_input_features(self, images):
        input_tokens = self.preprocessor(
            images=images,
            return_tensors="pt",
        )
        input_tokens = {
            k: v.to(torch.device(self.device)) for k, v in input_tokens.items()
        }
        return input_tokens
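
As a quick sanity check, you can run the encoder outside of a Flow; the image path below is a hypothetical placeholder:

from jina import Document, DocumentArray

encoder = CLIPImageEncoder(device="cpu")
doc = Document(uri="images/some_image.jpg")  # hypothetical path
doc.load_uri_to_image_blob()
docs = DocumentArray([doc])
encoder.encode(docs, parameters={})
print(docs[0].embedding.shape)  # (512,) for the ViT-B/32 CLIP model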

CLIPTextEncoder

This encoder encodes text into embeddings using the CLIP model. We want an Executor that loads the CLIP model and encodes the query text during the query Flow.

Our Executor should:

  • Support both GPU and CPU: That’s why we will provision the device parameter and use it when encoding.

  • Be able to process Documents in batches in order to use our resources effectively: To do so, we will use the parameter batch_size.

from typing import Dict, Sequence

from transformers import CLIPTokenizer


class CLIPTextEncoder(Executor):
    """Encode text into embeddings using the CLIP model."""

    def __init__(
        self,
        pretrained_model_name_or_path: str = 'openai/clip-vit-base-patch32',
        base_tokenizer_model: Optional[str] = None,
        max_length: int = 77,
        device: str = 'cpu',
        traversal_paths: Sequence[str] = ('r',),
        batch_size: int = 32,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.traversal_paths = traversal_paths
        self.batch_size = batch_size
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.base_tokenizer_model = (
            base_tokenizer_model or pretrained_model_name_or_path
        )
        self.max_length = max_length
        self.device = device
        self.tokenizer = CLIPTokenizer.from_pretrained(self.base_tokenizer_model)
        self.model = CLIPModel.from_pretrained(self.pretrained_model_name_or_path)
        self.model.eval().to(device)

    @requests
    def encode(self, docs: Optional[DocumentArray], parameters: Dict, **kwargs):
        if docs is None:
            return
        for docs_batch in docs.traverse_flat(
            parameters.get('traversal_paths', self.traversal_paths)
        ).batch(batch_size=parameters.get('batch_size', self.batch_size)):
            text_batch = docs_batch.get_attributes('text')
            with torch.inference_mode():
                input_tokens = self._generate_input_tokens(text_batch)
                embeddings = self.model.get_text_features(**input_tokens).cpu().numpy()
                for doc, embedding in zip(docs_batch, embeddings):
                    doc.embedding = embedding

    def _generate_input_tokens(self, texts: Sequence[str]):
        input_tokens = self.tokenizer(
            texts,
            max_length=self.max_length,
            padding='longest',
            truncation=True,
            return_tensors='pt',
        )
        input_tokens = {k: v.to(self.device) for k, v in input_tokens.items()}
        return input_tokens
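
The same kind of sanity check works for the text encoder, and shows that text lands in the same 512-dimensional space as the images:

encoder = CLIPTextEncoder(device="cpu")
docs = DocumentArray([Document(text="this is a dog")])
encoder.encode(docs, parameters={})
print(docs[0].embedding.shape)  # also (512,)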

SimpleIndexer

To implement SimpleIndexer, we can leverage Jina’s DocumentArrayMemmap. You can read about this data type here.

Our indexer creates an instance of DocumentArrayMemmap when it is initialized. We want to store the indexed Documents inside the workspace folder, which is why we pass the Executor's workspace attribute to DocumentArrayMemmap.

To index, we implement the method index, which is bound to the /index endpoint invoked during the index Flow. Indexing is as simple as extending the DocumentArrayMemmap instance with the received docs.

On the other hand, for search, we implement the method search. We bind it to the query Flow using the decorator @requests(on='/search'). In Jina, searching for query Documents is done by adding the results to the matches attribute of each query Document. Since docs is a DocumentArray, we can use its match method to match the queries against the indexed Documents, as sketched below. Read more about match here.
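
Before looking at the full Executor, here is a minimal standalone sketch of what match does, assuming the Jina 2.x signature with metric and limit and using random vectors in place of real embeddings:

import numpy as np
from jina import Document, DocumentArray

index = DocumentArray([Document(embedding=np.random.rand(512)) for _ in range(10)])
queries = DocumentArray([Document(embedding=np.random.rand(512))])

# For each query, attach the 3 closest indexed Documents as matches
queries.match(index, metric='cosine', limit=3)
print(len(queries[0].matches))  # 3

With that in mind, here is the full SimpleIndexer: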

import inspect
from copy import deepcopy
from typing import Dict, Optional

from jina import DocumentArray, Executor, requests
from jina.types.arrays.memmap import DocumentArrayMemmap


class SimpleIndexer(Executor):
    """
    A simple indexer that stores all the Document data together,
    in a DocumentArrayMemmap object

    To be used as a unified indexer, combining both indexing and searching
    """

    def __init__(
        self,
        match_args: Optional[Dict] = None,
        **kwargs,
    ):
        """
        Initializer function for the simple indexer

        :param match_args: the arguments to `DocumentArray`'s match function
        """
        super().__init__(**kwargs)
        self._match_args = match_args or {}
        # persist Documents in the Executor's workspace folder
        self._storage = DocumentArrayMemmap(
            self.workspace, key_length=kwargs.get('key_length', 64)
        )

    @requests(on='/index')
    def index(
        self,
        docs: Optional['DocumentArray'] = None,
        **kwargs,
    ):
        """Add all Documents to the DocumentArray

        :param docs: the docs to add
        """
        if docs:
            self._storage.extend(docs)

    @requests(on='/search')
    def search(
        self,
        docs: Optional['DocumentArray'] = None,
        parameters: Optional[Dict] = None,
        **kwargs,
    ):
        """Perform a vector similarity search and retrieve the full Document match

        :param docs: the Documents to search with
        :param parameters: the runtime arguments to `DocumentArray`'s match
            function. They overwrite the original match_args arguments.
        """
        if not docs:
            return
        match_args = deepcopy(self._match_args)
        if parameters:
            match_args.update(parameters)
        match_args = SimpleIndexer._filter_parameters(docs, match_args)
        docs.match(self._storage, **match_args)

    @staticmethod
    def _filter_parameters(docs, match_args):
        # keep only those arguments that exist in .match
        args = set(inspect.getfullargspec(docs.match).args)
        args.discard('self')
        match_args = {k: v for k, v in match_args.items() if k in args}
        return match_args
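
If you want to tune how matching is performed, you can pass match_args when adding the indexer to a Flow. The sketch below assumes that DocumentArray's match accepts metric and limit, as in Jina 2.x:

from jina import Flow

flow = Flow() \
    .add(uses=CLIPTextEncoder, name="encoder") \
    .add(
        uses=SimpleIndexer,
        name="indexer",
        workspace="workspace",
        uses_with={"match_args": {"metric": "cosine", "limit": 3}},
    )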

Building Flows

Indexing

Now that we have created the Executors, it’s time to use them to build the index Flow and index our data.

Building the index Flow

We create a Flow object and add Executors one after the other with the right parameters:

  1. CLIPImageEncoder: We specify the device.

  2. SimpleIndexer: We need to specify the workspace parameter.

from jina import Flow

device = "cpu"  # or "cuda" if a GPU is available

flow_index = Flow() \
    .add(uses=CLIPImageEncoder, name="encoder", uses_with={"device": device}) \
    .add(uses=SimpleIndexer, name="indexer", workspace='workspace')
flow_index.plot()

../../../_images/index_flow_text2image.svg

Now it’s time to index the dataset that we downloaded, namely the images inside the images folder. The helper function below converts the image files into Documents and yields them one by one:

import glob
import os

from jina import Document


def input_docs(data_path):
    for fn in glob.glob(os.path.join(data_path, '*')):
        doc = Document(uri=fn, tags={'filename': fn})
        doc.load_uri_to_image_blob()  # load the image file into doc.blob
        yield doc
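
You can quickly check that the generator produces what we expect (the path is the one used for indexing below):

sample = next(input_docs("/content/images"))
print(sample.tags['filename'], sample.blob.shape)  # e.g. a (height, width, 3) array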

The final step in this section is to send the input Documents to the index Flow. Note that indexing can take a while:

with flow_index:
    flow_index.post(on='/index', inputs=input_docs("/content/images"), request_size=1)
🎉 Flow is ready to use!
    🔗 Protocol:         GRPC
    🏠 Local access:     0.0.0.0:33367
    🔒 Private network:  172.28.0.2:33367
    🌐 Public address:   34.125.186.176:33367
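
The runtime parameters we provisioned in the encoder, such as batch_size and traversal_paths, can also be overridden per request through the parameters argument. A minimal sketch:

with flow_index:
    flow_index.post(
        on='/index',
        inputs=input_docs("/content/images"),
        parameters={'batch_size': 64},  # overrides the encoder's default batch_size
        request_size=1,
    )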

Searching

Now, let’s build the search Flow and use it to search with sample query images.

Our Flow contains the following Executors:

  1. CLIPTextEncoder: We specify the device.

  2. SimpleIndexer: We need to specify the workspace parameter.

flow_search = Flow() \
    .add(uses=CLIPTextEncoder, name="encoder", uses_with={"device": device}) \
    .add(uses=SimpleIndexer, name="indexer", workspace="workspace")
flow_search.plot()

Query Flow:

../../../_images/query_flow_text2image.svg

We create a helper function to plot our images:

import matplotlib.pyplot as plt


def show_docs(docs):
    for doc in docs:
        plt.imshow(doc.blob)
        plt.show()

and one last function to show us the top three matches to our text query:

from jina.types.request import Request


def plot_search_results(resp: Request):
    for doc in resp.docs:
        print(f'Query text: {doc.text}')
        print('Matches:')
        print('-' * 10)
        show_docs(doc.matches[:3])

Now we wrap some text queries in Documents, send them to the search Flow, and let the callback plot the results:

with flow_search:
    resp = flow_search.post(
        on='/search',
        inputs=DocumentArray([
            Document(text='dog'),
            Document(text='cat'),
            Document(text='kids on their bikes'),
        ]),
        on_done=plot_search_results,
    )

Sample results:

Query: Dog
Results:

../../../_images/dog1.png

../../../_images/dog2.png

../../../_images/dog3.png

Query: Cat
Results:

../../../_images/cat1.png

../../../_images/cat2.png

../../../_images/cat3.png

Query: Kids riding bikes
Results:

../../../_images/bike1.png

../../../_images/bike2.png

../../../_images/bike3.png
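
You can also override the indexer's match arguments per query, for example to retrieve more matches. A minimal sketch, assuming DocumentArray's match accepts limit as in Jina 2.x:

with flow_search:
    flow_search.post(
        on='/search',
        inputs=DocumentArray([Document(text='dog')]),
        parameters={'limit': 6},  # forwarded to match by SimpleIndexer
        on_done=plot_search_results,
    )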

Congratulations! You have built a text-to-image search engine. You can check the full source code here and experiment with your own text queries.