API Client

Source: vllm-project/vllm.

  1. 1"""Example Python client for vllm.entrypoints.api_server"""
  2. 2
  3. 3import argparse
  4. 4import json
  5. 5from typing import Iterable, List
  6. 6
  7. 7import requests
  8. 8
  9. 9
  10. 10def clear_line(n: int = 1) -> None:
  11. 11 LINE_UP = '\033[1A'
  12. 12 LINE_CLEAR = '\x1b[2K'
  13. 13 for _ in range(n):
  14. 14 print(LINE_UP, end=LINE_CLEAR, flush=True)
  15. 15
  16. 16
  17. 17def post_http_request(prompt: str,
  18. 18 api_url: str,
  19. 19 n: int = 1,
  20. 20 stream: bool = False) -> requests.Response:
  21. 21 headers = {"User-Agent": "Test Client"}
  22. 22 pload = {
  23. 23 "prompt": prompt,
  24. 24 "n": n,
  25. 25 "use_beam_search": True,
  26. 26 "temperature": 0.0,
  27. 27 "max_tokens": 16,
  28. 28 "stream": stream,
  29. 29 }
  30. 30 response = requests.post(api_url, headers=headers, json=pload, stream=True)
  31. 31 return response
  32. 32
  33. 33
  34. 34def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
  35. 35 for chunk in response.iter_lines(chunk_size=8192,
  36. 36 decode_unicode=False,
  37. 37 delimiter=b"\0"):
  38. 38 if chunk:
  39. 39 data = json.loads(chunk.decode("utf-8"))
  40. 40 output = data["text"]
  41. 41 yield output
  42. 42
  43. 43
  44. 44def get_response(response: requests.Response) -> List[str]:
  45. 45 data = json.loads(response.content)
  46. 46 output = data["text"]
  47. 47 return output
  48. 48
  49. 49
  50. 50if __name__ == "__main__":
  51. 51 parser = argparse.ArgumentParser()
  52. 52 parser.add_argument("--host", type=str, default="localhost")
  53. 53 parser.add_argument("--port", type=int, default=8000)
  54. 54 parser.add_argument("--n", type=int, default=4)
  55. 55 parser.add_argument("--prompt", type=str, default="San Francisco is a")
  56. 56 parser.add_argument("--stream", action="store_true")
  57. 57 args = parser.parse_args()
  58. 58 prompt = args.prompt
  59. 59 api_url = f"http://{args.host}:{args.port}/generate"
  60. 60 n = args.n
  61. 61 stream = args.stream
  62. 62
  63. 63 print(f"Prompt: {prompt!r}\n", flush=True)
  64. 64 response = post_http_request(prompt, api_url, n, stream)
  65. 65
  66. 66 if stream:
  67. 67 num_printed_lines = 0
  68. 68 for h in get_streaming_response(response):
  69. 69 clear_line(num_printed_lines)
  70. 70 num_printed_lines = 0
  71. 71 for i, line in enumerate(h):
  72. 72 num_printed_lines += 1
  73. 73 print(f"Beam candidate {i}: {line!r}", flush=True)
  74. 74 else:
  75. 75 output = get_response(response)
  76. 76 for i, line in enumerate(output):
  77. 77 print(f"Beam candidate {i}: {line!r}", flush=True)