MultiLoRA Inference
Source: vllm-project/vllm.
"""
This example shows how to use the multi-LoRA functionality
for offline inference.

Requires HuggingFace credentials for access to Llama2.
"""

from collections import deque
from typing import List, Optional, Tuple

from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest

# One test case: the prompt text, how to sample from it, and the LoRA adapter
# to apply (``None`` means the request runs on the base model only).
PromptSpec = Tuple[str, SamplingParams, Optional[LoRARequest]]


def create_test_prompts(lora_path: str) -> List[PromptSpec]:
    """Create a list of test prompts with their sampling parameters.

    2 requests for base model, 4 requests for the LoRA. We define 2
    different LoRA adapters (using the same model for demo purposes).
    Since we also set `max_loras=1`, the expectation is that the requests
    with the second LoRA adapter will be run after all requests with the
    first adapter have finished.

    Note that the last request uses the first adapter ("sql-lora", id 1)
    even though it is submitted after the "sql-lora2" request.

    Args:
        lora_path: Local path of the downloaded LoRA weights; both adapters
            point at the same weights.

    Returns:
        A new list of ``(prompt, sampling_params, lora_request)`` tuples.
    """
    return [
        ("A robot may not injure a human being",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        ("To be or not to be,",
         SamplingParams(temperature=0.8,
                        top_k=5,
                        presence_penalty=0.2,
                        max_tokens=128), None),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
            SamplingParams(temperature=0.0,
                           logprobs=1,
                           prompt_logprobs=1,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora", 1, lora_path)),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
            SamplingParams(n=3,
                           best_of=3,
                           use_beam_search=True,
                           temperature=0,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora", 1, lora_path)),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
            SamplingParams(temperature=0.0,
                           logprobs=1,
                           prompt_logprobs=1,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora2", 2, lora_path)),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
            SamplingParams(n=3,
                           best_of=3,
                           use_beam_search=True,
                           temperature=0,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora", 1, lora_path)),
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: List[PromptSpec]) -> None:
    """Continuously process a list of prompts and handle the outputs.

    At most one new request is submitted per engine step, so requests are
    fed in gradually while earlier ones are still generating. Every request
    output reporting ``finished`` is printed.

    Args:
        engine: The engine that schedules and runs the requests.
        test_prompts: Test cases to submit, in order. The caller's list is
            not modified; a private queue is drained instead.
    """
    # deque.popleft() is O(1) (list.pop(0) is O(n)), and copying into a
    # private queue avoids emptying the caller's list as a side effect.
    pending = deque(test_prompts)
    request_id = 0

    while pending or engine.has_unfinished_requests():
        if pending:
            prompt, sampling_params, lora_request = pending.popleft()
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

        request_outputs: List[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)


def initialize_engine() -> LLMEngine:
    """Initialize the LLMEngine."""
    # max_loras: controls the number of LoRAs that can be used in the same
    # batch. Larger numbers will cause higher memory usage, as each LoRA
    # slot requires its own preallocated tensor.
    # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
    # numbers will cause higher memory usage. If you know that all LoRAs will
    # use the same rank, it is recommended to set this as low as possible.
    # max_cpu_loras: controls the size of the CPU LoRA cache.
    engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
                             enable_lora=True,
                             max_loras=1,
                             max_lora_rank=8,
                             max_cpu_loras=2,
                             max_num_seqs=256)
    return LLMEngine.from_engine_args(engine_args)


def main() -> None:
    """Set up the engine, download the LoRA weights and run all prompts."""
    engine = initialize_engine()
    # Both demo adapters are served from this single downloaded snapshot.
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)
    process_requests(engine, test_prompts)


if __name__ == '__main__':
    main()