Save Sharded State
Source: vllm-project/vllm.
r"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

python save_sharded_state.py \
    --model /path/to/load \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --output /path/to/save

Then, the model can be loaded with

llm = LLM(
    model="/path/to/save",
    load_format="sharded_state",
    quantization="deepspeedfp",
    tensor_parallel_size=8,
)
"""
# NOTE: the docstring above is raw (r"""...""") so the trailing backslashes in
# the example command are kept verbatim instead of being consumed as
# backslash-newline escapes.
import dataclasses
import os
import shutil
from pathlib import Path

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument("--output",
                    "-o",
                    required=True,
                    type=str,
                    help="path to output checkpoint")
parser.add_argument("--file-pattern",
                    type=str,
                    help="string pattern of saved filenames")
parser.add_argument("--max-file-size",
                    # The default is an int byte count; parse user-supplied
                    # values as int too so max_size has the same type either
                    # way (type=str would forward a string instead).
                    type=int,
                    default=5 * 1024**3,
                    help="max size (in bytes) of each safetensors file")


def main(args):
    """Load a model with vLLM and save each worker's shard to ``args.output``.

    After the sharded weights are written, every non-weight entry of the
    source model directory (anything not ending in ``.bin``, ``.pt`` or
    ``.safetensors``) is copied alongside them.

    Args:
        args: Namespace produced by ``parser``: the engine CLI arguments plus
            ``output``, ``file_pattern`` and ``max_file_size``.

    Raises:
        ValueError: If LoRA is enabled, or if the model path is not a local
            directory.
    """
    engine_args = EngineArgs.from_cli_args(args)
    if engine_args.enable_lora:
        raise ValueError("Saving with enable_lora=True is not supported!")
    model_path = engine_args.model
    # Metadata is copied from model_path via os.listdir below, so it has to be
    # a directory on the local filesystem.
    if not Path(model_path).is_dir():
        raise ValueError("model path must be a local directory")
    # Create LLM instance from arguments
    llm = LLM(**dataclasses.asdict(engine_args))
    # Prepare output directory, creating missing parents; an existing
    # directory is accepted.
    Path(args.output).mkdir(parents=True, exist_ok=True)
    # Dump worker states to output directory
    model_executor = llm.llm_engine.model_executor
    model_executor.save_sharded_state(path=args.output,
                                      pattern=args.file_pattern,
                                      max_size=args.max_file_size)
    # Copy metadata files (everything except weight files) to output directory
    for file in os.listdir(model_path):
        if os.path.splitext(file)[1] in (".bin", ".pt", ".safetensors"):
            continue
        src = os.path.join(model_path, file)
        if os.path.isdir(src):
            # dirs_exist_ok matches the exist_ok above, so re-running into an
            # existing output directory does not fail here.
            shutil.copytree(src,
                            os.path.join(args.output, file),
                            dirs_exist_ok=True)
        else:
            shutil.copy(src, args.output)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)