Tensorize vLLM Model

Source: vllm-project/vllm.

import argparse
import dataclasses
import os
import time
import uuid
from functools import partial
from typing import Type

import torch
import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
                        TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig

from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
from vllm.model_executor.models import ModelRegistry

# yapf conflicts with isort for this docstring
# yapf: disable
  24. 24"""
  25. 25tensorize_vllm_model.py is a script that can be used to serialize and
  26. 26deserialize vLLM models. These models can be loaded using tensorizer
  27. 27to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
  28. 28or locally. Tensor encryption and decryption is also supported, although
  29. 29libsodium must be installed to use it. Install vllm with tensorizer support
  30. 30using `pip install vllm[tensorizer]`.
  31. 31
  32. 32To serialize a model, install vLLM from source, then run something
  33. 33like this from the root level of this repository:
  34. 34
  35. 35python -m examples.tensorize_vllm_model \
  36. 36 --model EleutherAI/gpt-j-6B \
  37. 37 --dtype float16 \
  38. 38 serialize \
  39. 39 --serialized-directory s3://my-bucket/ \
  40. 40 --suffix vllm
  41. 41
  42. 42Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
  43. 43and saves it to your S3 bucket. A local directory can also be used. This
  44. 44assumes your S3 credentials are specified as environment variables
  45. 45in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
  46. 46To provide S3 credentials directly, you can provide `--s3-access-key-id` and
  47. 47`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this
  48. 48script.
  49. 49
  50. 50You can also encrypt the model weights with a randomly-generated key by
  51. 51providing a `--keyfile` argument.
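
For example, to serialize with encryption (illustrative; `gpt-j.key` is an
arbitrary local path where the generated key will be saved):

python -m examples.tensorize_vllm_model \
   --model EleutherAI/gpt-j-6B \
   --dtype float16 \
   serialize \
   --serialized-directory s3://my-bucket/ \
   --keyfile gpt-j.key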

To deserialize a model, you can run something like this from the root
level of this repository:

python -m examples.tensorize_vllm_model \
   --model EleutherAI/gpt-j-6B \
   --dtype float16 \
   deserialize \
   --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors

This downloads the model tensors from your S3 bucket and deserializes them.

You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.
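
For instance, to decrypt with the key written during serialization (again
illustrative; the path must match the one passed to serialize above):

python -m examples.tensorize_vllm_model \
   --model EleutherAI/gpt-j-6B \
   --dtype float16 \
   deserialize \
   --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors \
   --keyfile gpt-j.key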

For more information on the available arguments for serializing, run
`python -m examples.tensorize_vllm_model serialize --help`.

Or for deserializing:

`python -m examples.tensorize_vllm_model deserialize --help`.

Once a model is serialized, it can be used to load the model when running the
OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing
the `--tensorizer-uri` CLI argument, which is functionally the same as the
`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to
signify that the model to be deserialized is a vLLM model rather than a
HuggingFace `PreTrainedModel`. The latter can also be deserialized using
tensorizer in the same inference server, albeit without the speed
optimizations. To deserialize an encrypted file, the `--encryption-keyfile`
argument can be used to provide the path to the keyfile used to encrypt the
model weights. For information on all the arguments that can be used to
configure tensorizer's deserialization, check out the tensorizer options
argument group in the `vllm/entrypoints/openai/api_server.py` script with
`--help`.
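
A sketch of such an invocation (assuming the server selects tensorizer via
`--load-format tensorizer`, mirroring the `LLM` example below, and that the
URI matches the path produced during serialization):

python -m vllm.entrypoints.openai.api_server \
   --model EleutherAI/gpt-j-6B \
   --load-format tensorizer \
   --tensorizer-uri s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors \
   --vllm-tensorized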

Tensorizer can also be invoked with the `LLM` class directly to load models:

    llm = LLM(model="facebook/opt-125m",
              load_format="tensorizer",
              tensorizer_uri=path_to_opt_tensors,
              num_readers=3,
              vllm_tensorized=True)
"""


def parse_args():
    parser = argparse.ArgumentParser(
        description="An example script that can be used to serialize and "
        "deserialize vLLM models. These models "
        "can be loaded using tensorizer directly to the GPU "
        "extremely quickly. Tensor encryption and decryption are "
        "also supported, although libsodium must be installed to "
        "use them.")
    parser = EngineArgs.add_cli_args(parser)
    subparsers = parser.add_subparsers(dest='command')

    serialize_parser = subparsers.add_parser(
        'serialize', help="Serialize a model to `--serialized-directory`")

    serialize_parser.add_argument(
        "--suffix",
        type=str,
        required=False,
        help=(
            "The suffix to append to the serialized model directory, which is "
            "used to construct the location of the serialized model tensors, "
            "e.g. if `--serialized-directory` is `s3://my-bucket/` and "
            "`--suffix` is `v1`, the serialized model tensors will be "
            "saved to "
            "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
            "If none is provided, a random UUID will be used."))
    serialize_parser.add_argument(
        "--serialized-directory",
        type=str,
        required=True,
        help="The directory to serialize the model to. "
        "This can be a local directory or S3 URI. The path to where the "
        "tensors are saved is a combination of the supplied `dir` and model "
        "reference ID. For instance, if `dir` is the serialized directory, "
        "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
        "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
        "where `suffix` is given by `--suffix` or a random UUID if not "
        "provided.")

    serialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Encrypt the model weights with a randomly-generated binary "
              "key, and save the key at this path"))

    deserialize_parser = subparsers.add_parser(
        'deserialize',
        help=("Deserialize a model from `--path-to-tensors`"
              " to verify it can be loaded and used."))

    deserialize_parser.add_argument(
        "--path-to-tensors",
        type=str,
        required=True,
        help="The local path or S3 URI to the model tensors to deserialize.")

    deserialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Path to a binary key to use to decrypt the model weights,"
              " if the model was serialized with encryption"))

    return parser.parse_args()


def make_model_contiguous(model):
    # Ensure tensors are saved in memory contiguously
    for param in model.parameters():
        param.data = param.data.contiguous()


def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


def serialize():
    # Rebuild EngineArgs from the parsed CLI namespace; constructing the
    # engine loads the model onto the GPU.
    eng_args_dict = {f.name: getattr(args, f.name) for f in
                     dataclasses.fields(EngineArgs)}
    engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
    engine = LLMEngine.from_engine_args(engine_args)

    model = (engine.model_executor.driver_worker.
             model_runner.model)

    # If a keyfile was requested, generate a random encryption key and
    # persist it so the weights can be decrypted later.
    encryption_params = EncryptionParams.random() if keyfile else None
    if keyfile:
        with _write_stream(keyfile) as stream:
            stream.write(encryption_params.key)

    with _write_stream(model_path) as stream:
        serializer = TensorSerializer(stream, encryption=encryption_params)
        serializer.write_module(model)
        serializer.close()

    print("Serialization complete. Model tensors saved to", model_path)
    if keyfile:
        print("Key saved to", keyfile)


def deserialize():
    config = AutoConfig.from_pretrained(model_ref)

    # Instantiate the model architecture without allocating weights or
    # running initializers; the deserializer fills the tensors in.
    with no_init_or_tensor():
        model_class = _get_vllm_model_architecture(config)
        model = model_class(config)

    before_mem = get_mem_usage()
    start = time.time()

    # If a keyfile was provided, read the binary key and pass decryption
    # parameters through to the deserializer.
    if keyfile:
        with _read_stream(keyfile) as stream:
            key = stream.read()
            decryption_params = DecryptionParams.from_key(key)
            tensorizer_args.deserializer_params['encryption'] = \
                decryption_params

    with _read_stream(model_path) as stream, TensorDeserializer(
            stream, **tensorizer_args.deserializer_params) as deserializer:
        deserializer.load_into_module(model)
        end = time.time()

    # Brag about how fast we are.
    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
    duration = end - start
    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
    after_mem = get_mem_usage()
    print(
        f"Deserialized {total_bytes_str} in {duration:0.2f}s, {per_second}/s")
    print(f"Memory usage before: {before_mem}")
    print(f"Memory usage after: {after_mem}")

    return model


args = parse_args()

s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
                    or None)
s3_secret_access_key = (args.s3_secret_access_key
                        or os.environ.get("S3_SECRET_ACCESS_KEY") or None)

s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)

# Pre-bind the S3 credentials to tensorizer's stream opener for reads
# and writes.
_read_stream, _write_stream = (partial(
    stream_io.open_stream,
    mode=mode,
    s3_access_key_id=s3_access_key_id,
    s3_secret_access_key=s3_secret_access_key,
    s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))

model_ref = args.model

model_name = model_ref.split("/")[1]

# Set up a single-process distributed environment so that vLLM's
# model-parallel machinery can initialize with world size 1.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"

torch.distributed.init_process_group(world_size=1, rank=0)
initialize_model_parallel()

keyfile = args.keyfile if args.keyfile else None

if args.command == "serialize":
    input_dir = args.serialized_directory.rstrip('/')
    suffix = args.suffix if args.suffix else uuid.uuid4().hex
    base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
    model_path = f"{base_path}/model.tensors"
    serialize()
elif args.command == "deserialize":
    tensorizer_args = TensorizerArgs.from_cli_args(args)
    model_path = args.path_to_tensors
    deserialize()
else:
    raise ValueError("Either serialize or deserialize must be specified.")