LoRA With Quantization Inference
Source: vllm-project/vllm.
"""
This example shows how to use LoRA with different quantization techniques
for offline inference.

Requires HuggingFace credentials for access.
"""

import gc
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
    lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Build the (prompt, sampling params, LoRA request) triples to run.

    The first entry runs the quantized model without any adapter; the
    remaining three attach a LoRA adapter loaded from ``lora_path``.

    Args:
        lora_path: Local path of the LoRA adapter handed to each
            ``LoRARequest``.

    Returns:
        A list of four ``(prompt, SamplingParams, LoRARequest or None)``
        tuples, in the order they should be submitted.
    """

    def _default_params() -> SamplingParams:
        # A new object per prompt, so no two requests share one instance.
        return SamplingParams(temperature=0.0,
                              logprobs=1,
                              prompt_logprobs=1,
                              max_tokens=128)

    # this is an example of using quantization without LoRA
    prompts: List[Tuple[str, SamplingParams, Optional[LoRARequest]]] = [
        ("My name is", _default_params(), None),
    ]

    # the next three examples use quantization with LoRA; every request
    # passes the same integer id (1) and the same adapter path.
    lora_cases = (
        ("lora-test-1", "my name is"),
        ("lora-test-2", "The capital of USA is"),
        ("lora-test-3", "The capital of France is"),
    )
    for adapter_name, text in lora_cases:
        prompts.append(
            (text, _default_params(), LoRARequest(adapter_name, 1,
                                                  lora_path)))
    return prompts


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Submit prompts to the engine and print each finished output.

    At most one pending prompt is added per loop iteration, followed by a
    single ``engine.step()``. The loop ends once no prompts remain and the
    engine reports no unfinished requests.

    Args:
        engine: The engine that receives the requests and is stepped.
        test_prompts: ``(prompt, sampling params, LoRA request)`` tuples.
            Entries are popped from the front, so the caller's list is
            empty when this function returns.
    """
    next_request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            text, params, adapter = test_prompts.pop(0)
            engine.add_request(str(next_request_id),
                               text,
                               params,
                               lora_request=adapter)
            next_request_id += 1

        step_outputs: List[RequestOutput] = engine.step()
        for result in step_outputs:
            if not result.finished:
                continue
            print("----------------------------------------------------")
            print(f"Prompt: {result.prompt}")
            print(f"Output: {result.outputs[0].text}")


def initialize_engine(model: str, quantization: str,
                      lora_repo: Optional[str]) -> LLMEngine:
    """Create an ``LLMEngine`` with LoRA enabled for the given quantization.

    Args:
        model: Model name or path passed to ``EngineArgs``.
        quantization: Quantization method name; ``"bitsandbytes"`` selects
            the QLoRA configuration, anything else the generic one.
        lora_repo: LoRA adapter repo; only used in the ``"bitsandbytes"``
            case, as ``qlora_adapter_name_or_path``.

    Returns:
        The engine built from the assembled ``EngineArgs``.
    """
    # Settings shared by both configurations.
    engine_kwargs = dict(
        model=model,
        quantization=quantization,
        enable_lora=True,
        # set it only in GPUs of limited memory
        enforce_eager=True,
    )

    if quantization == "bitsandbytes":
        # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization
        # technique. It quantizes the model when loading, with some config
        # info from the LoRA adapter repo, so load_format and
        # qlora_adapter_name_or_path have to be set as below.
        engine_kwargs.update(
            qlora_adapter_name_or_path=lora_repo,
            load_format="bitsandbytes",
            max_lora_rank=64,
        )
    else:
        engine_kwargs.update(max_loras=4)

    return LLMEngine.from_engine_args(EngineArgs(**engine_kwargs))


def main():
    """Run every quantization + LoRA configuration, one engine at a time."""

    test_configs = [{
        "name": "qlora_inference_example",
        'model': "huggyllama/llama-7b",
        'quantization': "bitsandbytes",
        'lora_repo': 'timdettmers/qlora-flan-7b'
    }, {
        "name": "AWQ_inference_with_lora_example",
        'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
        'quantization': "awq",
        'lora_repo': 'jashing/tinyllama-colorist-lora'
    }, {
        "name": "GPTQ_inference_with_lora_example",
        'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
        'quantization': "gptq",
        'lora_repo': 'jashing/tinyllama-colorist-lora'
    }]

    for cfg in test_configs:
        banner = f"~~~~~~~~~~~~~~~~ Running: {cfg['name']} ~~~~~~~~~~~~~~~~"
        print(banner)

        # The engine is created first; the adapter snapshot is fetched
        # afterwards to obtain the local path for the LoRA requests.
        engine = initialize_engine(cfg['model'], cfg['quantization'],
                                   cfg['lora_repo'])
        adapter_dir = snapshot_download(repo_id=cfg['lora_repo'])
        process_requests(engine, create_test_prompts(adapter_dir))

        # Clean up the GPU memory for the next test
        del engine
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()