Offline Inference Distributed
Source: vllm-project/vllm.
"""
This example shows how to use Ray Data to run offline batch inference
in a distributed fashion on a multi-node cluster.

Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
"""

from typing import Any, Dict, List

import numpy as np
import ray
from packaging.version import Version
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from vllm import LLM, SamplingParams

# Fail with an explicit exception rather than ``assert`` so the check still
# runs when Python is started with -O (which strips assert statements).
if Version(ray.__version__) < Version("2.22.0"):
    raise RuntimeError("Ray version must be at least 2.22.0")

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Set tensor parallelism per instance.
tensor_parallel_size = 1

# Set number of instances. Each instance will use tensor_parallel_size GPUs.
num_instances = 1


# Create a class to do batch inference.
class LLMPredictor:
    """Callable that owns one vLLM engine and runs generation on batches.

    The class (not an instance) is handed to ``Dataset.map_batches`` below,
    which constructs it and then calls the instance once per batch.
    """

    def __init__(self):
        """Load the model with the module-level ``tensor_parallel_size``."""
        # Create an LLM.
        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
                       tensor_parallel_size=tensor_parallel_size)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[str]]:
        """Generate completions for one batch of prompts.

        Args:
            batch: Mapping whose ``"text"`` entry holds the prompts.

        Returns:
            A dict with two parallel lists: ``"prompt"`` (the prompt of each
            request) and ``"generated_text"`` (the texts of all completions
            of that request joined by a single space).
        """
        # Generate texts from the prompts.
        # The output is a list of RequestOutput objects that contain the prompt,
        # generated text, and other information.
        outputs = self.llm.generate(batch["text"], sampling_params)
        return {
            "prompt": [output.prompt for output in outputs],
            "generated_text": [
                " ".join(o.text for o in output.outputs) for output in outputs
            ],
        }


# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")


# For tensor_parallel_size > 1, we need to create placement groups for vLLM
# to use. Every actor has to have its own placement group.
def scheduling_strategy_fn():
    """Create a fresh placement group and return remote args that use it.

    The group has ``tensor_parallel_size`` bundles of one GPU and one CPU
    each, requested with the ``STRICT_PACK`` strategy.

    Returns:
        A dict with a single ``scheduling_strategy`` entry, suitable for use
        as the ``ray_remote_args_fn`` result in ``map_batches``.
    """
    # One bundle per tensor parallel worker
    pg = ray.util.placement_group(
        [{
            "GPU": 1,
            "CPU": 1
        }] * tensor_parallel_size,
        strategy="STRICT_PACK",
    )
    return dict(scheduling_strategy=PlacementGroupSchedulingStrategy(
        pg, placement_group_capture_child_tasks=True))


resources_kwarg: Dict[str, Any] = {}
if tensor_parallel_size == 1:
    # For tensor_parallel_size == 1, we simply set num_gpus=1.
    resources_kwarg["num_gpus"] = 1
else:
    # Otherwise, we have to set num_gpus=0 and provide
    # a function that will create a placement group for
    # each instance.
    resources_kwarg["num_gpus"] = 0
    resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn

# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=num_instances,
    # Specify the batch size for inference.
    batch_size=32,
    **resources_kwarg,
)

# Peek first 10 results.
# NOTE: This is for local testing and debugging. For production use case,
# one should write full result out as shown below.
outputs = ds.take(limit=10)
for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")