Offline Inference MLPSpeculator

Source: vllm-project/vllm.

  1. 1import gc
  2. 2import time
  3. 3from typing import List
  4. 4
  5. 5from vllm import LLM, SamplingParams
  6. 6
  7. 7
  8. 8def time_generation(llm: LLM, prompts: List[str],
  9. 9 sampling_params: SamplingParams):
  10. 10 # Generate texts from the prompts. The output is a list of RequestOutput
  11. 11 # objects that contain the prompt, generated text, and other information.
  12. 12 # Warmup first
  13. 13 llm.generate(prompts, sampling_params)
  14. 14 llm.generate(prompts, sampling_params)
  15. 15 start = time.time()
  16. 16 outputs = llm.generate(prompts, sampling_params)
  17. 17 end = time.time()
  18. 18 print((end - start) / sum([len(o.outputs[0].token_ids) for o in outputs]))
  19. 19 # Print the outputs.
  20. 20 for output in outputs:
  21. 21 generated_text = output.outputs[0].text
  22. 22 print(f"text: {generated_text!r}")
  23. 23
  24. 24
  25. 25if __name__ == "__main__":
  26. 26
  27. 27 template = (
  28. 28 "Below is an instruction that describes a task. Write a response "
  29. 29 "that appropriately completes the request.\n\n### Instruction:\n{}"
  30. 30 "\n\n### Response:\n")
  31. 31
  32. 32 # Sample prompts.
  33. 33 prompts = [
  34. 34 "Write about the president of the United States.",
  35. 35 ]
  36. 36 prompts = [template.format(prompt) for prompt in prompts]
  37. 37 # Create a sampling params object.
  38. 38 sampling_params = SamplingParams(temperature=0.0, max_tokens=200)
  39. 39
  40. 40 # Create an LLM without spec decoding
  41. 41 llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")
  42. 42
  43. 43 print("Without speculation")
  44. 44 time_generation(llm, prompts, sampling_params)
  45. 45
  46. 46 del llm
  47. 47 gc.collect()
  48. 48
  49. 49 # Create an LLM with spec decoding
  50. 50 llm = LLM(
  51. 51 model="meta-llama/Llama-2-13b-chat-hf",
  52. 52 speculative_model="ibm-fms/llama-13b-accelerator",
  53. 53 # These are currently required for MLPSpeculator decoding
  54. 54 use_v2_block_manager=True,
  55. 55 )
  56. 56
  57. 57 print("With speculation")
  58. 58 time_generation(llm, prompts, sampling_params)