Save Sharded State

Source vllm-project/vllm.

"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

python save_sharded_state.py \
    --model /path/to/load \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --output /path/to/save

Then, the model can be loaded with

llm = LLM(
    model="/path/to/save",
    load_format="sharded_state",
    quantization="deepspeedfp",
    tensor_parallel_size=8,
)
"""
  23. 23import dataclasses
  24. 24import os
  25. 25import shutil
  26. 26from pathlib import Path
  27. 27
  28. 28from vllm import LLM, EngineArgs
  29. 29from vllm.utils import FlexibleArgumentParser
  30. 30
  31. 31parser = FlexibleArgumentParser()
  32. 32EngineArgs.add_cli_args(parser)
  33. 33parser.add_argument("--output",
  34. 34 "-o",
  35. 35 required=True,
  36. 36 type=str,
  37. 37 help="path to output checkpoint")
  38. 38parser.add_argument("--file-pattern",
  39. 39 type=str,
  40. 40 help="string pattern of saved filenames")
  41. 41parser.add_argument("--max-file-size",
  42. 42 type=str,
  43. 43 default=5 * 1024**3,
  44. 44 help="max size (in bytes) of each safetensors file")
  45. 45
  46. 46
  47. 47def main(args):
  48. 48 engine_args = EngineArgs.from_cli_args(args)
  49. 49 if engine_args.enable_lora:
  50. 50 raise ValueError("Saving with enable_lora=True is not supported!")
  51. 51 model_path = engine_args.model
  52. 52 if not Path(model_path).is_dir():
  53. 53 raise ValueError("model path must be a local directory")
  54. 54 # Create LLM instance from arguments
  55. 55 llm = LLM(**dataclasses.asdict(engine_args))
  56. 56 # Prepare output directory
  57. 57 Path(args.output).mkdir(exist_ok=True)
  58. 58 # Dump worker states to output directory
  59. 59 model_executor = llm.llm_engine.model_executor
  60. 60 model_executor.save_sharded_state(path=args.output,
  61. 61 pattern=args.file_pattern,
  62. 62 max_size=args.max_file_size)
  63. 63 # Copy metadata files to output directory
  64. 64 for file in os.listdir(model_path):
  65. 65 if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
  66. 66 if os.path.isdir(os.path.join(model_path, file)):
  67. 67 shutil.copytree(os.path.join(model_path, file),
  68. 68 os.path.join(args.output, file))
  69. 69 else:
  70. 70 shutil.copy(os.path.join(model_path, file), args.output)
  71. 71
  72. 72
  73. 73if __name__ == "__main__":
  74. 74 args = parser.parse_args()
  75. 75 main(args)