Offline Inference With Prefix

Source: vllm-project/vllm.

This example compares offline generation with and without vLLM's automatic prefix caching. A long shared prefix is prepended to several short prompts, both engines generate completions for the same batch, and the script checks that the outputs match and reports the speedup obtained from caching the shared prefix.

from time import time

from vllm import LLM, SamplingParams

# Common prefix.
prefix = (
    "You are an expert school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
    "community, joyful discovery, and life-long learning. The candidate is "
    "coming in for a first-round panel interview for a 8th grade Math "
    "teaching role. They have 5 years of previous teaching experience "
    "as an assistant teacher at a co-ed, public school with experience "
    "in middle school math teaching. Based on these information, fulfill "
    "the following paragraph: ")

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

generating_prompts = [prefix + prompt for prompt in prompts]

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)

# Create an LLM.
regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)
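
# Create a second LLM with prefix caching enabled. Both engines cap GPU memory
# utilization at 0.4 so that they can be created side by side on the same GPU.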
prefix_cached_llm = LLM(model="facebook/opt-125m",
                        enable_prefix_caching=True,
                        gpu_memory_utilization=0.4)
print("Results without `enable_prefix_caching`")

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
start_time_regular = time()
outputs = regular_llm.generate(generating_prompts, sampling_params)
duration_regular = time() - start_time_regular

regular_generated_texts = []
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    regular_generated_texts.append(generated_text)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# Warmup so that the shared prompt's KV cache is computed.
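# Subsequent requests that share this prefix can then reuse its cached KV
# blocks instead of recomputing them.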
prefix_cached_llm.generate(generating_prompts[0], sampling_params)

# Generate with prefix caching.
start_time_cached = time()
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
duration_cached = time() - start_time_cached

print("Results with `enable_prefix_caching`")

cached_generated_texts = []
# Print the outputs. You should see the same outputs as before.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    cached_generated_texts.append(generated_text)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# Compare the results and display the speedup
generated_same = all([
    regular_generated_texts[i] == cached_generated_texts[i]
    for i in range(len(prompts))
])
print(f"Generated answers are the same: {generated_same}")

speedup = round(duration_regular / duration_cached, 2)
print(f"Speed up of cached generation compared to the regular is: {speedup}")
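
Note that with a model as small as facebook/opt-125m and only four prompts, the measured speedup can be modest and noisy; the benefit of prefix caching grows with the length of the shared prefix and the number of requests that reuse it.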