Llava Example

Source: vllm-project/vllm.

This example runs LLaVA-1.5-7B offline with vLLM. The image input is passed to LLM.generate through MultiModalData, either as raw pixel values or as precomputed image features, and the example tensors are first synced down from S3.

import argparse
import os
import subprocess

import torch

from vllm import LLM
from vllm.sequence import MultiModalData

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.


def run_llava_pixel_values():
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

    # The prompt needs one `<image>` placeholder token per image feature
    # (image_feature_size=576), followed by the text prompt.
    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # This should be provided by another online or offline component.
    images = torch.load("images/stop_sign_pixel_values.pt")

    outputs = llm.generate(prompt,
                           multi_modal_data=MultiModalData(
                               type=MultiModalData.Type.IMAGE, data=images))
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def run_llava_image_features():
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="image_features",
        image_token_id=32000,
        image_input_shape="1,576,1024",
        image_feature_size=576,
    )

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # This should be provided by another online or offline component.
    images = torch.load("images/stop_sign_image_features.pt")

    outputs = llm.generate(prompt,
                           multi_modal_data=MultiModalData(
                               type=MultiModalData.Type.IMAGE, data=images))
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args):
    if args.type == "pixel_values":
        run_llava_pixel_values()
    else:
        run_llava_image_features()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Demo on Llava")
    parser.add_argument("--type",
                        type=str,
                        choices=["pixel_values", "image_features"],
                        default="pixel_values",
                        help="image input type")
    args = parser.parse_args()

    # Download the example assets from S3.
    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
    local_directory = "images"

    # Make sure the local directory exists or create it.
    os.makedirs(local_directory, exist_ok=True)

    # Use the AWS CLI to sync the directory, assuming anonymous access.
    subprocess.check_call([
        "aws",
        "s3",
        "sync",
        s3_bucket_path,
        local_directory,
        "--no-sign-request",
    ])
    main(args)
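
The script assumes the image tensors are prepared by some other component; here they are simply downloaded from S3. As a rough illustration of how a tensor matching image_input_shape="1,3,336,336" could be produced for the pixel_values path, the sketch below uses the Hugging Face transformers CLIP image processor for the 336-pixel ViT-L/14 vision tower. The checkpoint name, file paths, and the use of torch.save are assumptions for illustration, not part of the original example or the pipeline that produced the S3 assets.

# Sketch: build a (1, 3, 336, 336) pixel_values tensor for the
# "pixel_values" input type. Assumes `transformers` and `Pillow` are
# installed; the CLIP checkpoint name and file paths are illustrative.
import torch
from PIL import Image
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor.from_pretrained(
    "openai/clip-vit-large-patch14-336")

image = Image.open("my_image.jpg").convert("RGB")
pixel_values = processor(images=image, return_tensors="pt").pixel_values

print(pixel_values.shape)  # torch.Size([1, 3, 336, 336])
torch.save(pixel_values, "images/my_image_pixel_values.pt")

Once the assets are in place, the example can be run with either input type via the --type flag, i.e. --type pixel_values (the default) or --type image_features.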