The FewShotClassifier in OnPrem.LLM is a simple wrapper around the SetFit package that lets you train a text classifier with only a few labeled examples (e.g., 8 examples per class). It is useful when only a small number of labeled examples are available for training a model. We will supply the use_smaller=True argument to use a smaller version of the default model. You can also supply the Hugging Face model ID of an embedding model of your own choosing.
from onprem.pipelines import FewShotClassifier

clf = FewShotClassifier(use_smaller=True)
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
The default model is sentence-transformers/paraphrase-mpnet-base-v2, but we’re using all-MiniLM-L6-v2 in this example:
clf.model_id_or_path
'sentence-transformers/all-MiniLM-L6-v2'
STEP 1: Construct a Tiny Dataset
In this example, we will classify a sample of the 20NewsGroup dataset.
from sklearn.datasets import fetch_20newsgroups

categories = ['comp.graphics', 'sci.med', 'sci.space', 'soc.religion.christian']
train_b = fetch_20newsgroups(subset='train', categories=categories, shuffle=True, random_state=42)
test_b = fetch_20newsgroups(subset='test', categories=categories, shuffle=True, random_state=42)

X_train = train_b.data
y_train = [train_b.target_names[y] for y in train_b.target]
X_test = test_b.data
y_test = [test_b.target_names[y] for y in test_b.target]

# sample a small number of examples from the full training set
X_sample, y_sample = clf.sample_examples(X_train, y_train, num_samples=8)
There are only 32 training examples!
len(X_sample)
32
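To see why we end up with 32 examples, here is a minimal sketch of per-class sampling on toy data. The sample_per_class function below is a hypothetical stand-in for illustration only, not the actual sample_examples implementation in OnPrem.LLM:

```python
import random
from collections import defaultdict

def sample_per_class(X, y, num_samples=8, seed=42):
    # Hypothetical sketch: draw up to `num_samples` examples per class.
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for text, label in zip(X, y):
        by_label[label].append(text)
    X_out, y_out = [], []
    for label, texts in by_label.items():
        for text in rng.sample(texts, min(num_samples, len(texts))):
            X_out.append(text)
            y_out.append(label)
    return X_out, y_out

# Toy data: 4 classes with 20 documents each
labels = ['comp.graphics', 'sci.med', 'sci.space', 'soc.religion.christian']
X = [f'doc {i}' for i in range(80)]
y = [labels[i % 4] for i in range(80)]

X_s, y_s = sample_per_class(X, y, num_samples=8)
print(len(X_s))  # 32  (8 examples x 4 classes)
```

With 8 samples drawn from each of the 4 categories, the tiny training set contains 8 × 4 = 32 examples.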
There are 1577 test examples.
len(X_test)
1577
STEP 2: Train on the Tiny Dataset
Let’s train:
clf.train(X_sample, y_sample, max_steps=50)
Applying column mapping to the training dataset
***** Running training *****
Num unique pairs = 768
Batch size = 32
Num epochs = 10
Total optimization steps = 50
[50/50 00:13, Epoch 2/0]
Note: If you encounter an error above, downgrading your version of transformers can sometimes resolve it (e.g., see this issue, which was open at the time of this writing).