Tensorize vLLM Model

Source: vllm-project/vllm.

import argparse
import dataclasses
import json
import os
import uuid
from functools import partial

from tensorizer import stream_io

from vllm import LLM
from vllm.distributed import (init_distributed_environment,
                              initialize_model_parallel)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
                                                         TensorizerConfig,
                                                         serialize_vllm_model)

# yapf conflicts with isort for this docstring
# yapf: disable
  21. 21"""
  22. 22tensorize_vllm_model.py is a script that can be used to serialize and
  23. 23deserialize vLLM models. These models can be loaded using tensorizer
  24. 24to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
  25. 25or locally. Tensor encryption and decryption is also supported, although
  26. 26libsodium must be installed to use it. Install vllm with tensorizer support
  27. 27using `pip install vllm[tensorizer]`. To learn more about tensorizer, visit
  28. 28https://github.com/coreweave/tensorizer
  29. 29
  30. 30To serialize a model, install vLLM from source, then run something
  31. 31like this from the root level of this repository:
  32. 32
  33. 33python -m examples.tensorize_vllm_model \
  34. 34 --model facebook/opt-125m \
  35. 35 serialize \
  36. 36 --serialized-directory s3://my-bucket \
  37. 37 --suffix v1
  38. 38
  39. 39Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
  40. 40and saves it to your S3 bucket. A local directory can also be used. This
  41. 41assumes your S3 credentials are specified as environment variables
  42. 42in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and
  43. 43`S3_ENDPOINT_URL`. To provide S3 credentials directly, you can provide
  44. 44`--s3-access-key-id` and `--s3-secret-access-key`, as well as `--s3-endpoint`
  45. 45as CLI args to this script.
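
For example, in a shell like bash (the values below are placeholders for
your own credentials and S3-compatible endpoint):

export S3_ACCESS_KEY_ID=<your-access-key-id>
export S3_SECRET_ACCESS_KEY=<your-secret-access-key>
export S3_ENDPOINT_URL=<your-s3-endpoint-url>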

You can also encrypt the model weights with a randomly-generated key by
providing a `--keyfile` argument.
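
For example, appending `--keyfile` to the serialize command above encrypts
the weights and writes the generated key to the given local path
(`model.key` is an illustrative path, not a required name):

python -m examples.tensorize_vllm_model \
    --model facebook/opt-125m \
    serialize \
    --serialized-directory s3://my-bucket \
    --suffix v1 \
    --keyfile model.key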

To deserialize a model, you can run something like this from the root
level of this repository:

python -m examples.tensorize_vllm_model \
    --model EleutherAI/gpt-j-6B \
    --dtype float16 \
    deserialize \
    --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors

This downloads the model tensors from your S3 bucket and deserializes them.

You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.
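
For example, assuming the key was saved during serialization (again, the
key path is illustrative):

python -m examples.tensorize_vllm_model \
    --model EleutherAI/gpt-j-6B \
    --dtype float16 \
    deserialize \
    --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors \
    --keyfile model.key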

For more information on the available arguments for serializing, run
`python -m examples.tensorize_vllm_model serialize --help`.

Or for deserializing:

`python -m examples.tensorize_vllm_model deserialize --help`.

Once a model is serialized, tensorizer can be invoked with the `LLM` class
directly to load models:

    llm = LLM(model="facebook/opt-125m",
              load_format="tensorizer",
              model_loader_extra_config=TensorizerConfig(
                  tensorizer_uri=path_to_tensors,
                  num_readers=3,
              ))

A serialized model can be used during model loading for the vLLM OpenAI
inference server. `model_loader_extra_config` is exposed as the CLI arg
`--model-loader-extra-config`, and accepts a JSON string literal of the
desired TensorizerConfig arguments.
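
For example, a sketch of launching the server with a serialized model
(adjust the tensor URI and reader count to your setup):

python -m vllm.entrypoints.openai.api_server \
    --model EleutherAI/gpt-j-6B \
    --load-format tensorizer \
    --model-loader-extra-config \
    '{"tensorizer_uri": "s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors", "num_readers": 3}'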

To see all of the arguments that can be passed to `TensorizerConfig` to
configure loading with tensorizer, run

`python -m examples.tensorize_vllm_model deserialize --help`

and look under the `tensorizer options` section. These can also be used for
deserialization in this example script, although `--tensorizer-uri` and
`--path-to-tensors` are functionally the same in this case.
"""


def parse_args():
    parser = argparse.ArgumentParser(
        description="An example script that can be used to serialize and "
        "deserialize vLLM models. These models "
        "can be loaded using tensorizer directly to the GPU "
        "extremely quickly. Tensor encryption and decryption are "
        "also supported, although libsodium must be installed to "
        "use them.")
    parser = EngineArgs.add_cli_args(parser)
    subparsers = parser.add_subparsers(dest='command')

    serialize_parser = subparsers.add_parser(
        'serialize', help="Serialize a model to `--serialized-directory`")

    serialize_parser.add_argument(
        "--suffix",
        type=str,
        required=False,
        help=(
            "The suffix to append to the serialized model directory, which "
            "is used to construct the location of the serialized model "
            "tensors, e.g. if `--serialized-directory` is `s3://my-bucket/` "
            "and `--suffix` is `v1`, the serialized model tensors will be "
            "saved to "
            "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
            "If none is provided, a random UUID will be used."))
    serialize_parser.add_argument(
        "--serialized-directory",
        type=str,
        required=True,
        help="The directory to serialize the model to. "
        "This can be a local directory or an S3 URI. The path to where the "
        "tensors are saved is a combination of the supplied `dir` and model "
        "reference ID. For instance, if `dir` is the serialized directory, "
        "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
        "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
        "where `suffix` is given by `--suffix` or a random UUID if not "
        "provided.")

    serialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Encrypt the model weights with a randomly-generated binary "
              "key, and save the key at this path"))

    deserialize_parser = subparsers.add_parser(
        'deserialize',
        help=("Deserialize a model from `--path-to-tensors`"
              " to verify it can be loaded and used."))

    deserialize_parser.add_argument(
        "--path-to-tensors",
        type=str,
        required=True,
        help="The local path or S3 URI to the model tensors to deserialize.")

    deserialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Path to a binary key to use to decrypt the model weights,"
              " if the model was serialized with encryption"))

    TensorizerArgs.add_cli_args(deserialize_parser)

    return parser.parse_args()


def deserialize():
    # Load the serialized tensors through vLLM's tensorizer load path;
    # `tensorizer_config` is built below from the CLI and S3 settings.
    llm = LLM(model=args.model,
              load_format="tensorizer",
              model_loader_extra_config=tensorizer_config)
    return llm


args = parse_args()

# S3 credentials can come either from CLI args or from the environment.
s3_access_key_id = (getattr(args, 's3_access_key_id', None)
                    or os.environ.get("S3_ACCESS_KEY_ID", None))
s3_secret_access_key = (getattr(args, 's3_secret_access_key', None)
                        or os.environ.get("S3_SECRET_ACCESS_KEY", None))
s3_endpoint = (getattr(args, 's3_endpoint', None)
               or os.environ.get("S3_ENDPOINT_URL", None))

credentials = {
    "s3_access_key_id": s3_access_key_id,
    "s3_secret_access_key": s3_secret_access_key,
    "s3_endpoint": s3_endpoint
}

# Reader/writer helpers for tensor streams, preconfigured with the
# resolved S3 credentials.
_read_stream, _write_stream = (partial(
    stream_io.open_stream,
    mode=mode,
    s3_access_key_id=s3_access_key_id,
    s3_secret_access_key=s3_secret_access_key,
    s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))

model_ref = args.model

# Bare model name, e.g. "gpt-j-6B" for "EleutherAI/gpt-j-6B".
model_name = model_ref.split("/")[-1]

# vLLM expects a distributed environment even for a single-GPU,
# single-process run like this one.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"

init_distributed_environment(world_size=1, rank=0, local_rank=0)
initialize_model_parallel()

keyfile = args.keyfile if args.keyfile else None


# If a TensorizerConfig was supplied as a JSON string via
# --model-loader-extra-config, build the deserialization config from it,
# overriding the tensor location with --path-to-tensors.
if args.model_loader_extra_config:
    config = json.loads(args.model_loader_extra_config)
    tensorizer_config = TensorizerConfig(**config)
    tensorizer_config.tensorizer_uri = args.path_to_tensors
else:
    tensorizer_config = None

if args.command == "serialize":
    eng_args_dict = {f.name: getattr(args, f.name) for f in
                     dataclasses.fields(EngineArgs)}

    engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
    engine = LLMEngine.from_engine_args(engine_args)

    base_dir = args.serialized_directory.rstrip('/')
    suffix = args.suffix if args.suffix else uuid.uuid4().hex
    base_path = f"{base_dir}/vllm/{model_ref}/{suffix}"
    model_path = f"{base_path}/model.tensors"
    tensorizer_config = TensorizerConfig(
        tensorizer_uri=model_path,
        **credentials)
    serialize_vllm_model(engine, tensorizer_config, keyfile)
elif args.command == "deserialize":
    if tensorizer_config is None:
        tensorizer_config = TensorizerConfig(
            tensorizer_uri=args.path_to_tensors,
            encryption_keyfile=keyfile,
            **credentials)
    deserialize()
else:
    raise ValueError("Either serialize or deserialize must be specified.")