HFTrainer


Trains a new Hugging Face Transformer model using the Trainer framework.

Example

The following shows a simple example using this pipeline.

    import pandas as pd

    from datasets import load_dataset

    from txtai.pipeline import HFTrainer

    trainer = HFTrainer()

    # Pandas DataFrame
    df = pd.read_csv("training.csv")
    model, tokenizer = trainer("bert-base-uncased", df)

    # Hugging Face dataset
    ds = load_dataset("glue", "sst2")
    model, tokenizer = trainer("bert-base-uncased", ds["train"], columns=("sentence", "label"))

    # List of dicts
    dt = [{"text": "sentence 1", "label": 0}, {"text": "sentence 2", "label": 1}]
    model, tokenizer = trainer("bert-base-uncased", dt)

    # Support additional TrainingArguments
    model, tokenizer = trainer("bert-base-uncased", dt,
                               learning_rate=3e-5, num_train_epochs=5)

All TrainingArguments are supported as function arguments to the trainer call.
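For instance, common training controls such as the output directory, batch size and number of epochs can be forwarded directly. The sketch below reuses the `dt` list and `trainer` instance from the example above; the specific values are illustrative, not recommendations.

    # Minimal sketch: any keyword accepted by transformers.TrainingArguments can be
    # passed through the trainer call. Values below are placeholders.
    model, tokenizer = trainer(
        "bert-base-uncased",
        dt,
        output_dir="models/labeler",
        per_device_train_batch_size=16,
        num_train_epochs=3,
        weight_decay=0.01,
        logging_steps=50
    )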

See the notebooks below for more detailed examples.

| Notebook | Description |
|---|---|
| Train a text labeler | Build text sequence classification models |
| Train without labels | Use zero-shot classifiers to train new models |
| Train a QA model | Build and fine-tune question-answering models |
| Train a language model from scratch | Build new language models |

Training tasks

The HFTrainer pipeline builds and/or fine-tunes models for the following training tasks. A sketch of selecting a task is shown after the table.

| Task | Description |
|---|---|
| language-generation | Causal language model for text generation (e.g. GPT) |
| language-modeling | Masked language model for general tasks (e.g. BERT) |
| question-answering | Extractive question-answering model, typically with the SQuAD dataset |
| sequence-sequence | Sequence-Sequence model (e.g. T5) |
| text-classification | Classify text with a set of labels |
| token-detection | ELECTRA-style pre-training with replaced token detection |
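The task parameter determines which model head is built on top of the base model. As an illustrative sketch, the call below fine-tunes a masked language model on unlabeled text; the placeholder dataset is an assumption, not taken from this page.

    # Illustrative only: train a masked language model instead of a classifier.
    # For language modeling, only a text column is needed; no labels are required.
    unlabeled = [{"text": "txtai is an all-in-one embeddings database"},
                 {"text": "Semantic search finds results by meaning"}]

    model, tokenizer = trainer("bert-base-uncased", unlabeled, task="language-modeling")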

PEFT

Parameter-Efficient Fine-Tuning (PEFT) is supported through Hugging Face’s PEFT library. Quantization is provided through bitsandbytes. See the examples below.

    from txtai.pipeline import HFTrainer

    trainer = HFTrainer()
    trainer(..., quantize=True, lora=True)

When these parameters are set to True, default configurations are used. The configurations can also be customized.

    quantize = {
        "load_in_4bit": True,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": "bfloat16"
    }

    lora = {
        "r": 16,
        "lora_alpha": 8,
        "target_modules": "all-linear",
        "lora_dropout": 0.05,
        "bias": "none"
    }

    trainer(..., quantize=quantize, lora=lora)

The parameters also accept transformers.BitsAndBytesConfig and peft.LoraConfig instances.
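The same settings can be expressed with the library configuration objects directly. A minimal sketch, assuming transformers, peft and torch are installed:

    import torch

    from peft import LoraConfig
    from transformers import BitsAndBytesConfig

    # Equivalent configuration using library objects instead of dicts
    quantize = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    lora = LoraConfig(
        r=16,
        lora_alpha=8,
        target_modules="all-linear",
        lora_dropout=0.05,
        bias="none"
    )

    trainer(..., quantize=quantize, lora=lora)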

See the PEFT documentation for more information.

Methods

Python documentation for the pipeline.

Builds a new model using arguments.

Parameters:

| Name | Description | Default |
|---|---|---|
| base | path to base model, accepts Hugging Face model hub id, local path or (model, tokenizer) tuple | required |
| train | training data | required |
| validation | validation data | None |
| columns | tuple of columns to use for text/label, defaults to (text, None, label) | None |
| maxlength | maximum sequence length, defaults to tokenizer.model_max_length | None |
| stride | chunk size for splitting data for QA tasks | 128 |
| task | optional model task or category, determines the model type, defaults to "text-classification" | "text-classification" |
| prefix | optional source prefix | None |
| metrics | optional function that computes and returns a dict of evaluation metrics | None |
| tokenizers | optional number of concurrent tokenizers, defaults to None | None |
| checkpoint | optional resume from checkpoint flag or path to checkpoint directory, defaults to None | None |
| quantize | quantization configuration to pass to base model | None |
| lora | lora configuration to pass to PEFT model | None |
| args | training arguments | {} |

Returns:

| Type | Description |
|---|---|
| tuple | (model, tokenizer) |
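The metrics callable is passed straight through to the underlying Hugging Face Trainer as compute_metrics, so it receives an EvalPrediction and returns a dict of metric values. A minimal sketch, assuming a labeled classification dataset split into train and validation lists of {"text", "label"} records:

    import numpy as np

    # Illustrative accuracy metric following the Hugging Face compute_metrics convention
    def accuracy(eval_prediction):
        predictions = np.argmax(eval_prediction.predictions, axis=1)
        return {"accuracy": (predictions == eval_prediction.label_ids).mean()}

    # train and validation are placeholder datasets for this sketch
    model, tokenizer = trainer(
        "bert-base-uncased",
        train,
        validation=validation,
        metrics=accuracy
    )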

Source code in txtai/pipeline/train/hftrainer.py

    def __call__(
        self,
        base,
        train,
        validation=None,
        columns=None,
        maxlength=None,
        stride=128,
        task="text-classification",
        prefix=None,
        metrics=None,
        tokenizers=None,
        checkpoint=None,
        quantize=None,
        lora=None,
        **args
    ):
        """
        Builds a new model using arguments.

        Args:
            base: path to base model, accepts Hugging Face model hub id, local path or (model, tokenizer) tuple
            train: training data
            validation: validation data
            columns: tuple of columns to use for text/label, defaults to (text, None, label)
            maxlength: maximum sequence length, defaults to tokenizer.model_max_length
            stride: chunk size for splitting data for QA tasks
            task: optional model task or category, determines the model type, defaults to text-classification
            prefix: optional source prefix
            metrics: optional function that computes and returns a dict of evaluation metrics
            tokenizers: optional number of concurrent tokenizers, defaults to None
            checkpoint: optional resume from checkpoint flag or path to checkpoint directory, defaults to None
            quantize: quantization configuration to pass to base model
            lora: lora configuration to pass to PEFT model
            args: training arguments

        Returns:
            (model, tokenizer)
        """

        # Quantization / LoRA support
        if (quantize or lora) and not PEFT:
            raise ImportError('PEFT is not available - install pipeline extra to enable')

        # Parse TrainingArguments
        args = self.parse(args)

        # Set seed for model reproducibility
        set_seed(args.seed)

        # Load model configuration, tokenizer and max sequence length
        config, tokenizer, maxlength = self.load(base, maxlength)

        # Default tokenizer pad token if it's not set
        tokenizer.pad_token = tokenizer.pad_token if tokenizer.pad_token is not None else tokenizer.eos_token

        # Prepare parameters
        process, collator, labels = self.prepare(task, train, tokenizer, columns, maxlength, stride, prefix, args)

        # Tokenize training and validation data
        train, validation = process(train, validation, os.cpu_count() if tokenizers and isinstance(tokenizers, bool) else tokenizers)

        # Create model to train
        model = self.model(task, base, config, labels, tokenizer, quantize)

        # Default config pad token if it's not set
        model.config.pad_token_id = model.config.pad_token_id if model.config.pad_token_id is not None else model.config.eos_token_id

        # Load as PEFT model, if necessary
        model = self.peft(task, lora, model)

        # Add model to collator
        if collator:
            collator.model = model

        # Build trainer
        trainer = Trainer(
            model=model,
            tokenizer=tokenizer,
            data_collator=collator,
            args=args,
            train_dataset=train,
            eval_dataset=validation if validation else None,
            compute_metrics=metrics,
        )

        # Run training
        trainer.train(resume_from_checkpoint=checkpoint)

        # Run evaluation
        if validation:
            trainer.evaluate()

        # Save model outputs
        if args.should_save:
            trainer.save_model()
            trainer.save_state()

        # Put model in eval mode to disable weight updates and return (model, tokenizer)
        return (model.eval(), tokenizer)