Gradio OpenAI Chatbot Webserver

Source: vllm-project/vllm.
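
To use this example, first start a vLLM OpenAI-compatible API server for the script to talk to, then point the script at it. A minimal invocation might look like the following, with `Qwen/Qwen2.5-1.5B-Instruct` standing in for whatever model you serve and the script saved as `gradio_openai_chatbot_webserver.py`:

```bash
# Serve a model with vLLM's OpenAI-compatible API (listens on port 8000
# by default, matching the script's --model-url default)
vllm serve Qwen/Qwen2.5-1.5B-Instruct

# In another shell, launch the Gradio chatbot against that server
python gradio_openai_chatbot_webserver.py -m Qwen/Qwen2.5-1.5B-Instruct
```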

```python
import argparse

import gradio as gr
from openai import OpenAI

# Argument parser setup
parser = argparse.ArgumentParser(
    description='Chatbot Interface with Customizable Parameters')
parser.add_argument('--model-url',
                    type=str,
                    default='http://localhost:8000/v1',
                    help='Model URL')
parser.add_argument('-m',
                    '--model',
                    type=str,
                    required=True,
                    help='Model name for the chatbot')
parser.add_argument('--temp',
                    type=float,
                    default=0.8,
                    help='Temperature for text generation')
parser.add_argument('--stop-token-ids',
                    type=str,
                    default='',
                    help='Comma-separated stop token IDs')
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)

# Parse the arguments
args = parser.parse_args()

# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = args.model_url

# Create an OpenAI client to interact with the API server
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)


def predict(message, history):
    # Convert chat history to OpenAI format
    history_openai_format = [{
        "role": "system",
        "content": "You are a great AI assistant."
    }]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({
            "role": "assistant",
            "content": assistant
        })
    history_openai_format.append({"role": "user", "content": message})

    # Create a chat completion request and send it to the API server
    stream = client.chat.completions.create(
        model=args.model,  # Model name to use
        messages=history_openai_format,  # Chat history
        temperature=args.temp,  # Temperature for text generation
        stream=True,  # Stream response
        extra_body={
            'repetition_penalty': 1,
            'stop_token_ids': [
                int(token_id.strip())
                for token_id in args.stop_token_ids.split(',')
                if token_id.strip()
            ] if args.stop_token_ids else []
        })

    # Read and return generated text from the response stream
    partial_message = ""
    for chunk in stream:
        partial_message += (chunk.choices[0].delta.content or "")
        yield partial_message


# Create and launch a chat interface with Gradio
gr.ChatInterface(predict).queue().launch(server_name=args.host,
                                         server_port=args.port,
                                         share=True)
```
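
The `extra_body` argument is how the OpenAI client forwards vLLM-specific sampling parameters that the standard Chat Completions API does not define, such as the `repetition_penalty` and `stop_token_ids` used above. A minimal non-streaming sketch reusing the `client` and `args` from the script, and assuming the server also accepts `top_k` as such an extension:

```python
# Sketch: forwarding extra vLLM sampling parameters via extra_body.
# repetition_penalty and stop_token_ids appear in the script above;
# top_k is assumed here to be another server-side extension.
response = client.chat.completions.create(
    model=args.model,
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    temperature=args.temp,
    extra_body={"repetition_penalty": 1.1, "top_k": 50},
)
print(response.choices[0].message.content)
```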
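
Note that `predict` unpacks `history` as `(user, assistant)` tuples, which matches Gradio's classic tuple-style chat history. Recent Gradio releases also offer an OpenAI-style history via `gr.ChatInterface(..., type="messages")`; the following is a sketch of how `predict` might be adapted for that mode (the name `predict_messages` is ours, and the dict-shaped history is an assumption about the newer API):

```python
def predict_messages(message, history):
    # With type="messages", history is assumed to arrive as a list of
    # {"role": ..., "content": ...} dicts, so no tuple conversion is needed.
    messages = [{"role": "system", "content": "You are a great AI assistant."}]
    messages += [{"role": m["role"], "content": m["content"]} for m in history]
    messages.append({"role": "user", "content": message})

    stream = client.chat.completions.create(model=args.model,
                                            messages=messages,
                                            temperature=args.temp,
                                            stream=True)
    partial_message = ""
    for chunk in stream:
        partial_message += chunk.choices[0].delta.content or ""
        yield partial_message


gr.ChatInterface(predict_messages, type="messages").queue().launch(
    server_name=args.host, server_port=args.port, share=True)
```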