Few-Shot Classification

The FewShotClassifier in OnPrem.LLM is a simple wrapper around the SetFit package that lets you train a text classifier from only a few labeled examples (e.g., 8 examples per class). It is useful when only a small number of labeled examples is available for training a model. We will supply the use_smaller=True argument to use the smaller version of the default model. You can also supply the Hugging Face model ID of an embedding model of your own choosing.

from onprem.pipelines import FewShotClassifier
clf = FewShotClassifier(use_smaller=True)
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.

The default model is sentence-transformers/paraphrase-mpnet-base-v2, but we’re using all-MiniLM-L6-v2 in this example:

clf.model_id_or_path
'sentence-transformers/all-MiniLM-L6-v2'

STEP 1: Construct a Tiny Dataset

In this example, we will classify a sample of the 20 Newsgroups dataset, restricted to four categories.

categories = ['comp.graphics', 'sci.med', 'sci.space', 'soc.religion.christian']
from sklearn.datasets import fetch_20newsgroups
train_b = fetch_20newsgroups(subset='train',
   categories=categories, shuffle=True, random_state=42)
test_b = fetch_20newsgroups(subset='test',
   categories=categories, shuffle=True, random_state=42)

X_train = train_b.data
y_train = [train_b.target_names[y] for y in train_b.target]
X_test = test_b.data
y_test = [test_b.target_names[y] for y in test_b.target]

# sample a small number of examples from full training set
X_sample, y_sample = clf.sample_examples(X_train, y_train, num_samples=8)

There are only 32 training examples!

len(X_sample)
32
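To illustrate why 8 samples per class across 4 categories gives 32 examples, here is a minimal sketch of per-class sampling in plain Python. The helper `sample_per_class` and the toy data are hypothetical and are not part of OnPrem.LLM; the actual `sample_examples` method may select examples differently.

```python
import random
from collections import Counter, defaultdict

def sample_per_class(texts, labels, num_samples=8, seed=42):
    """Pick up to `num_samples` texts per label.

    Hypothetical helper for illustration only; not the OnPrem.LLM implementation.
    """
    by_label = defaultdict(list)
    for text, label in zip(texts, labels):
        by_label[label].append(text)

    rng = random.Random(seed)
    sampled_texts, sampled_labels = [], []
    for label in sorted(by_label):
        pool = by_label[label]
        chosen = rng.sample(pool, min(num_samples, len(pool)))
        sampled_texts.extend(chosen)
        sampled_labels.extend([label] * len(chosen))
    return sampled_texts, sampled_labels

# Toy stand-in for the real dataset: 4 classes with 20 documents each.
toy_labels = ['comp.graphics', 'sci.med', 'sci.space', 'soc.religion.christian']
toy_texts = [f"{label} doc {i}" for label in toy_labels for i in range(20)]
toy_y = [label for label in toy_labels for _ in range(20)]

X_small, y_small = sample_per_class(toy_texts, toy_y, num_samples=8)
class_counts = Counter(y_small)
print(len(X_small))   # 4 classes x 8 samples each
print(class_counts)
```

Sampling the same number of examples from every class keeps the tiny training set balanced, which matters more than usual when each class has only a handful of examples.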

There are 1577 test examples.

len(X_test)
1577

STEP 2: Train on the Tiny Dataset

Let’s train:

clf.train(X_sample, y_sample, max_steps=50)
Applying column mapping to the training dataset
***** Running training *****
  Num unique pairs = 768
  Batch size = 32
  Num epochs = 10
[50/50 00:30, Epoch 2/3]
Step Training Loss
1 0.300400
50 0.110400
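The "Num unique pairs = 768" line in the log can be accounted for with a little arithmetic. The sketch below assumes that training pairs are built from all unordered pairs of distinct examples and that same-label (positive) pairs are oversampled until they match the number of different-label (negative) pairs. This is an assumption about the underlying pair-generation strategy, not something stated in the log itself.

```python
from math import comb

num_classes = 4
per_class = 8
n = num_classes * per_class                   # 32 training examples

all_pairs = comb(n, 2)                        # unordered pairs of distinct examples
positive = num_classes * comb(per_class, 2)   # pairs that share a label
negative = all_pairs - positive               # pairs with different labels

# Assumption: the smaller group is oversampled to match the larger one.
unique_pairs = 2 * max(positive, negative)

batch_size = 32
steps_per_epoch = unique_pairs // batch_size

print(all_pairs, positive, negative)          # 496 112 384
print(unique_pairs, steps_per_epoch)          # 768 24
```

Under these assumptions, one pass over the pairs takes 24 steps, so max_steps=50 would stop shortly after the second full pass, which is plausibly what the "Epoch 2/3" progress readout reflects.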

Note: If you encounter an error above, downgrading your version of transformers can sometimes resolve it (e.g., see this issue, which was open at the time of this writing).

STEP 3: Evaluate

After training, clf.model.labels stores the labels.

print(clf.evaluate(X_test, y_test, print_report=True))
                        precision    recall  f1-score   support

         comp.graphics       0.93      0.93      0.93       389
               sci.med       0.93      0.88      0.91       396
             sci.space       0.89      0.93      0.91       394
soc.religion.christian       0.94      0.96      0.95       398

              accuracy                           0.92      1577
             macro avg       0.92      0.92      0.92      1577
          weighted avg       0.92      0.92      0.92      1577

None

A 92% accuracy using only 32 examples!
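For readers unfamiliar with the columns in the report above, the following sketch shows how per-class precision, recall, F1, and support, along with the macro and weighted averages, can be computed from true and predicted labels. The function `per_class_report` and the toy labels are hypothetical and are used only for illustration; they are not the test data above.

```python
def per_class_report(y_true, y_pred):
    """Compute precision, recall, F1 and support for each label (illustrative helper)."""
    pairs = list(zip(y_true, y_pred))
    report = {}
    for label in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[label] = {'precision': precision, 'recall': recall,
                         'f1': f1, 'support': tp + fn}
    return report

# Toy labels for illustration only.
toy_true = ['sci.med', 'sci.med', 'sci.space', 'sci.space', 'sci.space']
toy_pred = ['sci.med', 'sci.space', 'sci.space', 'sci.space', 'sci.med']

report = per_class_report(toy_true, toy_pred)
accuracy = sum(t == p for t, p in zip(toy_true, toy_pred)) / len(toy_true)

# Macro average: unweighted mean over classes.
macro_f1 = sum(r['f1'] for r in report.values()) / len(report)
# Weighted average: mean over classes weighted by support.
total_support = sum(r['support'] for r in report.values())
weighted_f1 = sum(r['f1'] * r['support'] for r in report.values()) / total_support

print(report)
print(accuracy, macro_f1, weighted_f1)
```

In the report above the four classes have nearly equal support (389 to 398), which is why the macro and weighted averages agree.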

STEP 4: Make Predictions on New Data

Let’s make some predictions on new data:

clf.predict(['My graphics card sucks for machine learning.'])
array(['comp.graphics'], dtype='<U22')
clf.predict(['My mom likes going to church.'])
array(['soc.religion.christian'], dtype='<U22')
clf.predict(['SpaceX launches lots of satellites.'])
array(['sci.space'], dtype='<U22')

Show prediction probabilities:

clf.predict_proba(['SpaceX launches lots of satellites.'])
tensor([[0.3000, 0.1330, 0.4385, 0.1285]], dtype=torch.float64)
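The probabilities can be paired with class names to see how confident the prediction is. The sketch below hard-codes the values printed above and assumes the columns follow the alphabetical label order shown in the evaluation report; in practice, check clf.model.labels for the actual order.

```python
# Assumed column order (alphabetical, as in the evaluation report).
labels = ['comp.graphics', 'sci.med', 'sci.space', 'soc.religion.christian']
# Probabilities copied from the predict_proba output above.
probs = [0.3000, 0.1330, 0.4385, 0.1285]

# Rank classes from most to least probable.
ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
top_label, top_prob = ranked[0]
runner_up_label, runner_up_prob = ranked[1]
margin = top_prob - runner_up_prob

for label, prob in ranked:
    print(f"{label:<24} {prob:.4f}")
print(f"margin over runner-up: {margin:.4f}")
```

Under that assumed ordering, the top class is sci.space, matching the predict output, with comp.graphics as a fairly close runner-up. A small margin like this is a useful signal that a prediction deserves a second look.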

STEP 5: Inspect and Explain Predictions

Explain predictions:

clf.explain(['My graphics card sucks for machine learning.'])


[Interactive explanation plot. Selectable outputs: comp.graphics, sci.med, sci.space, soc.religion.christian. For comp.graphics, the base value is 0.298055 and the model output f(inputs) is 0.488892. Token contributions: graphics 0.097, machine 0.043, card 0.032, for 0.024, My 0.015, "." 0.007, learning 0.001, sucks -0.026.]

If you click on comp.graphics, you can see that the word “graphics” has the largest impact on the prediction, followed by “machine” and then “card”. The word “sucks” is the only token that pushes the prediction away from comp.graphics.
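The token contributions shown in the plot can also be ranked and sanity-checked in code. The sketch below copies the rounded values from the comp.graphics explanation above and assumes they are additive attributions, meaning the base value plus the sum of contributions should approximately recover the model output.

```python
# Values copied from the comp.graphics explanation above (rounded as displayed).
base_value = 0.298055
model_output = 0.488892
contributions = {
    'My': 0.015, 'graphics': 0.097, 'card': 0.032, 'sucks': -0.026,
    'for': 0.024, 'machine': 0.043, 'learning': 0.001, '.': 0.007,
}

# Rank tokens by the magnitude of their contribution.
ranked_tokens = sorted(contributions.items(), key=lambda kv: abs(kv[1]), reverse=True)
for token, value in ranked_tokens:
    print(f"{token:<10} {value:+.3f}")

# Assumption: attributions are additive, so this should be close to model_output.
reconstructed = base_value + sum(contributions.values())
print(f"reconstructed={reconstructed:.6f} reported={model_output:.6f}")
```

The small gap between the reconstructed and reported outputs is consistent with the contributions being displayed to only three decimal places.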

STEP 6: Save and/or Reload the Model

Save and reload the model:

clf.save('/tmp/my_fewshot_model')
clf = FewShotClassifier('/tmp/my_fewshot_model')
clf.predict(['Elon Musk likes launching satellites.'])
array(['sci.space'], dtype='<U22')

TIP: You can also train traditional text classification models using Hugging Face Transformers and scikit-learn. Please see the documentation for more details.