AQLM Example

Source: vllm-project/vllm.

  1. 1from vllm import LLM, SamplingParams
  2. 2from vllm.utils import FlexibleArgumentParser
  3. 3
  4. 4
  5. 5def main():
  6. 6
  7. 7 parser = FlexibleArgumentParser(description='AQLM examples')
  8. 8
  9. 9 parser.add_argument('--model',
  10. 10 '-m',
  11. 11 type=str,
  12. 12 default=None,
  13. 13 help='model path, as for HF')
  14. 14 parser.add_argument('--choice',
  15. 15 '-c',
  16. 16 type=int,
  17. 17 default=0,
  18. 18 help='known good models by index, [0-4]')
  19. 19 parser.add_argument('--tensor-parallel-size',
  20. 20 '-t',
  21. 21 type=int,
  22. 22 default=1,
  23. 23 help='tensor parallel size')
  24. 24
  25. 25 args = parser.parse_args()
  26. 26
  27. 27 models = [
  28. 28 "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf",
  29. 29 "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf",
  30. 30 "ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf",
  31. 31 "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf",
  32. 32 "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf",
  33. 33 ]
  34. 34
  35. 35 model = LLM(args.model if args.model is not None else models[args.choice],
  36. 36 tensor_parallel_size=args.tensor_parallel_size)
  37. 37
  38. 38 sampling_params = SamplingParams(max_tokens=100, temperature=0)
  39. 39 outputs = model.generate("Hello my name is",
  40. 40 sampling_params=sampling_params)
  41. 41 print(outputs[0].outputs[0].text)
  42. 42
  43. 43
  44. 44if __name__ == '__main__':
  45. 45 main()