AQLM Example

Source: vllm-project/vllm.

  1. 1import argparse
  2. 2
  3. 3from vllm import LLM, SamplingParams
  4. 4
  5. 5
  6. 6def main():
  7. 7
  8. 8 parser = argparse.ArgumentParser(description='AQLM examples')
  9. 9
  10. 10 parser.add_argument('--model',
  11. 11 '-m',
  12. 12 type=str,
  13. 13 default=None,
  14. 14 help='model path, as for HF')
  15. 15 parser.add_argument('--choice',
  16. 16 '-c',
  17. 17 type=int,
  18. 18 default=0,
  19. 19 help='known good models by index, [0-4]')
  20. 20 parser.add_argument('--tensor_parallel_size',
  21. 21 '-t',
  22. 22 type=int,
  23. 23 default=1,
  24. 24 help='tensor parallel size')
  25. 25
  26. 26 args = parser.parse_args()
  27. 27
  28. 28 models = [
  29. 29 "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf",
  30. 30 "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf",
  31. 31 "ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf",
  32. 32 "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf",
  33. 33 "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf",
  34. 34 ]
  35. 35
  36. 36 model = LLM(args.model if args.model is not None else models[args.choice],
  37. 37 tensor_parallel_size=args.tensor_parallel_size)
  38. 38
  39. 39 sampling_params = SamplingParams(max_tokens=100, temperature=0)
  40. 40 outputs = model.generate("Hello my name is",
  41. 41 sampling_params=sampling_params)
  42. 42 print(outputs[0].outputs[0].text)
  43. 43
  44. 44
  45. 45if __name__ == '__main__':
  46. 46 main()