MultiLoRA Inference

Source: vllm-project/vllm.

```python
"""
This example shows how to use the multi-LoRA functionality
for offline inference.

Requires HuggingFace credentials for access to Llama2.
"""

from typing import List, Optional, Tuple

from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Create a list of test prompts with their sampling parameters.

    2 requests for base model, 4 requests for the LoRA. We define 2
    different LoRA adapters (using the same model for demo purposes).
    Since we also set `max_loras=1`, the expectation is that the requests
    with the second LoRA adapter will be run after all requests with the
    first adapter have finished.
    """
    return [
        ("A robot may not injure a human being",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        ("To be or not to be,",
         SamplingParams(temperature=0.8,
                        top_k=5,
                        presence_penalty=0.2,
                        max_tokens=128), None),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
            SamplingParams(temperature=0.0,
                           logprobs=1,
                           prompt_logprobs=1,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora", 1, lora_path)),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
            SamplingParams(n=3,
                           best_of=3,
                           use_beam_search=True,
                           temperature=0,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora", 1, lora_path)),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
            SamplingParams(temperature=0.0,
                           logprobs=1,
                           prompt_logprobs=1,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora2", 2, lora_path)),
        (
            "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",  # noqa: E501
            SamplingParams(n=3,
                           best_of=3,
                           use_beam_search=True,
                           temperature=0,
                           max_tokens=128,
                           stop_token_ids=[32003]),
            LoRARequest("sql-lora", 1, lora_path)),
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

        # step() runs one decoding iteration and returns the outputs of
        # the requests that made progress in this iteration.
        request_outputs: List[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)


def initialize_engine() -> LLMEngine:
    """Initialize the LLMEngine."""
    # max_loras: controls the number of LoRAs that can be used in the same
    # batch. Larger numbers will cause higher memory usage, as each LoRA
    # slot requires its own preallocated tensor.
    # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
    # numbers will cause higher memory usage. If you know that all LoRAs will
    # use the same rank, it is recommended to set this as low as possible.
    # max_cpu_loras: controls the size of the CPU LoRA cache.
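    # With max_loras=1, only one adapter can be resident on the GPU per
    # batch, which is why the "sql-lora2" requests from create_test_prompts()
    # are scheduled after the "sql-lora" requests finish; max_cpu_loras=2
    # lets both adapters stay cached on the CPU in the meantime.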
    engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
                             enable_lora=True,
                             max_loras=1,
                             max_lora_rank=8,
                             max_cpu_loras=2,
                             max_num_seqs=256)
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""
    engine = initialize_engine()
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)
    process_requests(engine, test_prompts)


if __name__ == '__main__':
    main()
```
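
The engine-level loop above gives full control over request submission and scheduling. For simpler scripts, the same adapter can be used through the higher-level `LLM` API. The sketch below is a minimal variant of the example, assuming `LLM.generate` accepts a `lora_request` argument as described in vLLM's LoRA documentation; the adapter name, ID, and HuggingFace repo are reused from the example above.

```python
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Download the same test adapter used in the example above.
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")

# As with EngineArgs above, LoRA support must be enabled at construction time.
llm = LLM(model="meta-llama/Llama-2-7b-hf",
          enable_lora=True,
          max_loras=1,
          max_lora_rank=8)

prompts = [
    "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",  # noqa: E501
]

# Every prompt in this call is routed through the "sql-lora" adapter; the
# engine loads it on demand and can batch it with base-model requests.
outputs = llm.generate(prompts,
                       SamplingParams(temperature=0.0,
                                      max_tokens=128,
                                      stop_token_ids=[32003]),
                       lora_request=LoRARequest("sql-lora", 1, lora_path))

for output in outputs:
    print(output.outputs[0].text)
```

For online serving, recent vLLM releases expose the same functionality through the OpenAI-compatible server via the `--enable-lora` and `--lora-modules name=path` flags, letting clients select an adapter by passing its name as the `model` field of a request.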