LoRA With Quantization Inference

Source: vllm-project/vllm.

  1. 1"""
  2. 2This example shows how to use LoRA with different quantization techniques
  3. 3for offline inference.
  4. 4
  5. 5Requires HuggingFace credentials for access.
  6. 6"""
  7. 7
  8. 8import gc
  9. 9from typing import List, Optional, Tuple
  10. 10
  11. 11import torch
  12. 12from huggingface_hub import snapshot_download
  13. 13
  14. 14from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
  15. 15from vllm.lora.request import LoRARequest
  16. 16
  17. 17
  18. 18def create_test_prompts(
  19. 19 lora_path: str
  20. 20) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
  21. 21 return [
  22. 22 # this is an example of using quantization without LoRA
  23. 23 ("My name is",
  24. 24 SamplingParams(temperature=0.0,
  25. 25 logprobs=1,
  26. 26 prompt_logprobs=1,
  27. 27 max_tokens=128), None),
  28. 28 # the next three examples use quantization with LoRA
  29. 29 ("my name is",
  30. 30 SamplingParams(temperature=0.0,
  31. 31 logprobs=1,
  32. 32 prompt_logprobs=1,
  33. 33 max_tokens=128),
  34. 34 LoRARequest("lora-test-1", 1, lora_path)),
  35. 35 ("The capital of USA is",
  36. 36 SamplingParams(temperature=0.0,
  37. 37 logprobs=1,
  38. 38 prompt_logprobs=1,
  39. 39 max_tokens=128),
  40. 40 LoRARequest("lora-test-2", 1, lora_path)),
  41. 41 ("The capital of France is",
  42. 42 SamplingParams(temperature=0.0,
  43. 43 logprobs=1,
  44. 44 prompt_logprobs=1,
  45. 45 max_tokens=128),
  46. 46 LoRARequest("lora-test-3", 1, lora_path)),
  47. 47 ]
  48. 48
  49. 49
  50. 50def process_requests(engine: LLMEngine,
  51. 51 test_prompts: List[Tuple[str, SamplingParams,
  52. 52 Optional[LoRARequest]]]):
  53. 53 """Continuously process a list of prompts and handle the outputs."""
  54. 54 request_id = 0
  55. 55
  56. 56 while test_prompts or engine.has_unfinished_requests():
  57. 57 if test_prompts:
  58. 58 prompt, sampling_params, lora_request = test_prompts.pop(0)
  59. 59 engine.add_request(str(request_id),
  60. 60 prompt,
  61. 61 sampling_params,
  62. 62 lora_request=lora_request)
  63. 63 request_id += 1
  64. 64
  65. 65 request_outputs: List[RequestOutput] = engine.step()
  66. 66 for request_output in request_outputs:
  67. 67 if request_output.finished:
  68. 68 print("----------------------------------------------------")
  69. 69 print(f"Prompt: {request_output.prompt}")
  70. 70 print(f"Output: {request_output.outputs[0].text}")
  71. 71
  72. 72
  73. 73def initialize_engine(model: str, quantization: str,
  74. 74 lora_repo: Optional[str]) -> LLMEngine:
  75. 75 """Initialize the LLMEngine."""
  76. 76
  77. 77 if quantization == "bitsandbytes":
  78. 78 # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique.
  79. 79 # It quantizes the model when loading, with some config info from the
  80. 80 # LoRA adapter repo. So need to set the parameter of load_format and
  81. 81 # qlora_adapter_name_or_path as below.
  82. 82 engine_args = EngineArgs(
  83. 83 model=model,
  84. 84 quantization=quantization,
  85. 85 qlora_adapter_name_or_path=lora_repo,
  86. 86 load_format="bitsandbytes",
  87. 87 enable_lora=True,
  88. 88 max_lora_rank=64,
  89. 89 # set it only in GPUs of limited memory
  90. 90 enforce_eager=True)
  91. 91 else:
  92. 92 engine_args = EngineArgs(
  93. 93 model=model,
  94. 94 quantization=quantization,
  95. 95 enable_lora=True,
  96. 96 max_loras=4,
  97. 97 # set it only in GPUs of limited memory
  98. 98 enforce_eager=True)
  99. 99 return LLMEngine.from_engine_args(engine_args)
  100. 100
  101. 101
  102. 102def main():
  103. 103 """Main function that sets up and runs the prompt processing."""
  104. 104
  105. 105 test_configs = [{
  106. 106 "name": "qlora_inference_example",
  107. 107 'model': "huggyllama/llama-7b",
  108. 108 'quantization': "bitsandbytes",
  109. 109 'lora_repo': 'timdettmers/qlora-flan-7b'
  110. 110 }, {
  111. 111 "name": "AWQ_inference_with_lora_example",
  112. 112 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
  113. 113 'quantization': "awq",
  114. 114 'lora_repo': 'jashing/tinyllama-colorist-lora'
  115. 115 }, {
  116. 116 "name": "GPTQ_inference_with_lora_example",
  117. 117 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
  118. 118 'quantization': "gptq",
  119. 119 'lora_repo': 'jashing/tinyllama-colorist-lora'
  120. 120 }]
  121. 121
  122. 122 for test_config in test_configs:
  123. 123 print(
  124. 124 f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~"
  125. 125 )
  126. 126 engine = initialize_engine(test_config['model'],
  127. 127 test_config['quantization'],
  128. 128 test_config['lora_repo'])
  129. 129 lora_path = snapshot_download(repo_id=test_config['lora_repo'])
  130. 130 test_prompts = create_test_prompts(lora_path)
  131. 131 process_requests(engine, test_prompts)
  132. 132
  133. 133 # Clean up the GPU memory for the next test
  134. 134 del engine
  135. 135 gc.collect()
  136. 136 torch.cuda.empty_cache()
  137. 137
  138. 138
  139. 139if __name__ == '__main__':
  140. 140 main()
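
The example above drives `LLMEngine` directly so that requests can be added incrementally and processed with `engine.step()`. For a quick experiment, the same combination of a quantized base model plus a LoRA adapter can also be run through the higher-level `LLM` API, which blocks until generation finishes. The snippet below is a minimal sketch, not part of the example above; it assumes the AWQ model and adapter repos from the `AWQ_inference_with_lora_example` config and reuses the positional `LoRARequest(name, id, path)` form shown in the example.

```python
# Minimal sketch: the same AWQ + LoRA setup via the high-level LLM API.
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Model and adapter repos taken from the AWQ config in the example above.
llm = LLM(model="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
          quantization="awq",
          enable_lora=True,
          enforce_eager=True)

# Download the LoRA adapter locally, as the example does with snapshot_download.
lora_path = snapshot_download(repo_id="jashing/tinyllama-colorist-lora")

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("colorist-lora", 1, lora_path))

for output in outputs:
    print(output.outputs[0].text)
```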