Offline Inference Distributed

Source: vllm-project/vllm.

  1. 1"""
  2. 2This example shows how to use Ray Data for running offline batch inference
  3. 3distributively on a multi-nodes cluster.
  4. 4
  5. 5Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
  6. 6"""
  7. 7
  8. 8from typing import Dict
  9. 9
  10. 10import numpy as np
  11. 11import ray
  12. 12from packaging.version import Version
  13. 13from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
  14. 14
  15. 15from vllm import LLM, SamplingParams
  16. 16
  17. 17assert Version(ray.__version__) >= Version(
  18. 18 "2.22.0"), "Ray version must be at least 2.22.0"
  19. 19
  20. 20# Create a sampling params object.
  21. 21sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
  22. 22
  23. 23# Set tensor parallelism per instance.
  24. 24tensor_parallel_size = 1
  25. 25
  26. 26# Set number of instances. Each instance will use tensor_parallel_size GPUs.
  27. 27num_instances = 1
  28. 28
  29. 29
# Create a class to do batch inference.
class LLMPredictor:

    def __init__(self):
        # Create an LLM.
        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
                       tensor_parallel_size=tensor_parallel_size)
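        # NOTE: Llama-2 is a gated model on the Hugging Face Hub, so every node
        # needs valid Hugging Face credentials (e.g. the HF_TOKEN environment
        # variable) to be able to download the weights.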

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        # Generate texts from the prompts.
        # The output is a list of RequestOutput objects that contain the prompt,
        # generated text, and other information.
        outputs = self.llm.generate(batch["text"], sampling_params)
        prompt = []
        generated_text = []
        for output in outputs:
            prompt.append(output.prompt)
            generated_text.append(' '.join([o.text for o in output.outputs]))
        return {
            "prompt": prompt,
            "generated_text": generated_text,
        }

# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
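# For other formats or local files, Ray Data provides readers such as
# ray.data.read_json(), ray.data.read_parquet() and ray.data.read_csv().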


# For tensor_parallel_size > 1, we need to create placement groups for vLLM
# to use. Every actor has to have its own placement group.
def scheduling_strategy_fn():
    # One bundle per tensor parallel worker
    pg = ray.util.placement_group(
        [{
            "GPU": 1,
            "CPU": 1
        }] * tensor_parallel_size,
        strategy="STRICT_PACK",
    )
    return dict(scheduling_strategy=PlacementGroupSchedulingStrategy(
        pg, placement_group_capture_child_tasks=True))
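# NOTE: STRICT_PACK places all bundles of the placement group on the same node,
# so the tensor-parallel workers of one instance are co-located.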


resources_kwarg = {}
if tensor_parallel_size == 1:
    # For tensor_parallel_size == 1, we simply set num_gpus=1.
    resources_kwarg["num_gpus"] = 1
else:
    # Otherwise, we have to set num_gpus=0 and provide
    # a function that will create a placement group for
    # each instance.
    resources_kwarg["num_gpus"] = 0
    resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn

# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=num_instances,
    # Specify the batch size for inference.
    batch_size=32,
    **resources_kwarg,
)
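# NOTE: Ray Data executes lazily, so map_batches() above only builds the pipeline;
# inference actually runs when the dataset is consumed, e.g. by take() or
# write_parquet() below.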

# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# write the full result out as shown below.
outputs = ds.take(limit=10)
for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
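
For a quick smoke test without S3 access, the input dataset can also be built from an in-memory list of prompts. The sketch below is an illustration rather than part of the original example; it uses ray.data.from_items, whose dict keys become column names, so the "text" key matches the batch["text"] lookup in LLMPredictor:

# Minimal sketch (illustrative): replace the read_text() call above with an
# in-memory dataset for local testing.
prompts = [
    "What is Ray?",
    "Explain tensor parallelism in one sentence.",
]
ds = ray.data.from_items([{"text": p} for p in prompts])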