Few-Shot Classification

The FewShotClassifier in OnPrem.LLM is a simple wrapper around the SetFit package that lets you train a text classifier on only a few labeled examples (e.g., 8 examples per class). It is useful when only a small number of labeled examples is available for training a model. We will supply the use_smaller=True argument to use the smaller version of the default model. You can also supply the Hugging Face model ID of an embedding model of your own choosing.

from onprem.pipelines import FewShotClassifier
clf = FewShotClassifier(use_smaller=True)
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.

The default model is sentence-transformers/paraphrase-mpnet-base-v2, but because we set use_smaller=True, this example uses sentence-transformers/all-MiniLM-L6-v2:

clf.model_id_or_path
'sentence-transformers/all-MiniLM-L6-v2'

STEP 1: Construct a Tiny Dataset

In this example, we will classify a sample of the 20 Newsgroups dataset, restricted to four categories.

categories = ['comp.graphics', 'sci.med', 'sci.space', 'soc.religion.christian']
from sklearn.datasets import fetch_20newsgroups
train_b = fetch_20newsgroups(subset='train',
   categories=categories, shuffle=True, random_state=42)
test_b = fetch_20newsgroups(subset='test',
   categories=categories, shuffle=True, random_state=42)

X_train = train_b.data
y_train = [train_b.target_names[y] for y in train_b.target]
X_test = test_b.data
y_test = [test_b.target_names[y] for y in test_b.target]

# sample a small number of examples from full training set
X_sample, y_sample = clf.sample_examples(X_train, y_train, num_samples=8)

There are only 32 training examples (8 for each of the 4 categories)!

len(X_sample)
32
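To make the sampling step concrete, here is a minimal pure-Python sketch of drawing a fixed number of examples per class. The helper name sample_per_class, its seed parameter, and the toy data are hypothetical and exist only for illustration; this is not necessarily how clf.sample_examples is implemented.

```python
import random
from collections import defaultdict


def sample_per_class(texts, labels, num_samples=8, seed=42):
    """Hypothetical helper: draw up to `num_samples` texts for each distinct label."""
    by_label = defaultdict(list)
    for text, label in zip(texts, labels):
        by_label[label].append(text)

    rng = random.Random(seed)  # fixed seed so the draw is repeatable
    sampled_texts, sampled_labels = [], []
    for label in sorted(by_label):
        pool = by_label[label]
        chosen = rng.sample(pool, min(num_samples, len(pool)))
        sampled_texts.extend(chosen)
        sampled_labels.extend([label] * len(chosen))
    return sampled_texts, sampled_labels


# Toy stand-in for the newsgroup data: 4 classes with 20 documents each
categories = ["comp.graphics", "sci.med", "sci.space", "soc.religion.christian"]
toy_labels = [c for c in categories for _ in range(20)]
toy_texts = [f"document {i} about {label}" for i, label in enumerate(toy_labels)]

X_small, y_small = sample_per_class(toy_texts, toy_labels, num_samples=8)
print(len(X_small))  # 4 classes x 8 examples = 32
```

Sampling per class rather than from the pooled data keeps the tiny training set balanced, which matters when each class has only a handful of examples.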

There are 1577 test examples.

len(X_test)
1577

STEP 2: Train on the Tiny Dataset

Let’s train:

clf.train(X_sample, y_sample, max_steps=50)
Applying column mapping to the training dataset
***** Running training *****
  Num unique pairs = 768
  Batch size = 32
  Num epochs = 10
  Total optimization steps = 50
[50/50 00:13, Epoch 2/0]
Step Training Loss
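The "Num unique pairs = 768" line in the log can be reproduced with a little arithmetic. The sketch below assumes that SetFit builds contrastive pairs from every unordered pair of training examples and balances same-class and different-class pairs by oversampling the smaller group; that sampling behavior is an assumption about the underlying library, not something stated on this page.

```python
from itertools import combinations

# 4 classes x 8 examples each = 32 training examples
labels = [c for c in range(4) for _ in range(8)]

positive = negative = 0
for a, b in combinations(labels, 2):  # every unordered pair of examples
    if a == b:
        positive += 1  # same class
    else:
        negative += 1  # different classes

# Assumed oversampling strategy: the smaller group is repeated until both match
num_pairs = 2 * max(positive, negative)
steps_per_epoch = num_pairs // 32  # batch size taken from the log above

print(positive, negative, num_pairs, steps_per_epoch)  # → 112 384 768 24
```

Under these assumptions one epoch is 24 steps, so max_steps=50 stops training shortly after two full passes over the pairs, which is consistent with the "Epoch 2" shown in the progress line.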

STEP 3: Evaluate

After training, the class labels are stored in model.labels. Evaluate the classifier on the held-out test set:

print(clf.evaluate(X_test, y_test, print_report=True))
                        precision    recall  f1-score   support

         comp.graphics       0.93      0.90      0.92       389
               sci.med       0.94      0.89      0.91       396
             sci.space       0.88      0.90      0.89       394
soc.religion.christian       0.92      0.97      0.94       398

              accuracy                           0.92      1577
             macro avg       0.92      0.92      0.92      1577
          weighted avg       0.92      0.92      0.92      1577

That is 92% accuracy on 1577 test examples after training on only 32 examples!
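To show how the summary rows of the report relate to the per-class rows, the sketch below recomputes the macro and weighted precision from the rounded values printed above. Because the inputs are already rounded to two decimals, the results only approximate what the evaluation computes internally.

```python
# (precision, support) per class, copied from the report above
report = {
    "comp.graphics": (0.93, 389),
    "sci.med": (0.94, 396),
    "sci.space": (0.88, 394),
    "soc.religion.christian": (0.92, 398),
}

total_support = sum(support for _, support in report.values())

# Macro average: plain mean over classes, every class counts equally
macro_precision = sum(p for p, _ in report.values()) / len(report)

# Weighted average: mean over classes weighted by the number of test examples
weighted_precision = sum(p * s for p, s in report.values()) / total_support

print(total_support)  # → 1577
print(f"{macro_precision:.4f} {weighted_precision:.4f}")
```

The two averages nearly coincide here because the four classes have almost the same number of test examples; on an imbalanced test set they can differ noticeably.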

STEP 4: Make Predictions on New Data

Let’s make some predictions on new data:

clf.predict(['My graphics card sucks for machine learning.'])
array(['comp.graphics'], dtype='<U22')
clf.predict(['My mom likes going to church.'])
array(['soc.religion.christian'], dtype='<U22')
clf.predict(['SpaceX launches lots of satellites.'])
array(['sci.space'], dtype='<U22')

Show prediction probabilities:

clf.predict_proba(['SpaceX launches lots of satellites.'])
tensor([[0.2190, 0.1312, 0.5243, 0.1256]], dtype=torch.float64)
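The probability vector can be mapped back to label names. The sketch below assumes the columns follow the alphabetical label order used in the evaluation report; that ordering is an assumption, although it agrees with the 'sci.space' prediction shown earlier.

```python
# Assumed column order: alphabetical, as in the evaluation report
labels = ["comp.graphics", "sci.med", "sci.space", "soc.religion.christian"]

# Probabilities copied from the predict_proba output above
probs = [0.2190, 0.1312, 0.5243, 0.1256]

# Pair each label with its probability and sort from most to least likely
ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
top_label, top_prob = ranked[0]

for label, p in ranked:
    print(f"{label:<24}{p:.4f}")
```

Looking at the full ranking, and not only the top label, is useful for spotting low-confidence predictions: here the winning class holds only about half of the probability mass.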

STEP 5: Inspect and Explain Predictions

Explain predictions:

clf.explain(['My graphics card sucks for machine learning.'])


[Interactive token-attribution plot for input 0. Output classes: comp.graphics, sci.med, sci.space, soc.religion.christian. For comp.graphics: base value 0.296877, f(inputs) 0.529032. Token contributions: My 0.016, graphics 0.121, card 0.029, sucks -0.009, for 0.026, machine 0.037, learning 0.017, . -0.005]

If you click on comp.graphics, you can see that the word “graphics” has the largest impact on the prediction (0.121), followed by “machine” (0.037) and then “card” (0.029).
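The numbers in the plot can be checked by hand. The sketch below ranks the tokens by the size of their contribution and verifies that the base value plus all token contributions approximately reproduces the model output for comp.graphics. Treating the plot as an additive attribution of this kind is an assumption based on its "base value" and "f(inputs)" annotations.

```python
# Values copied from the comp.graphics view of the plot above
base_value = 0.296877
model_output = 0.529032
contributions = [
    ("My", 0.016), ("graphics", 0.121), ("card", 0.029), ("sucks", -0.009),
    ("for", 0.026), ("machine", 0.037), ("learning", 0.017), (".", -0.005),
]

# Rank tokens by absolute contribution, largest first
ranked = sorted(contributions, key=lambda item: abs(item[1]), reverse=True)
top_three = [token for token, _ in ranked[:3]]
print(top_three)  # → ['graphics', 'machine', 'card']

# Additivity check: base value + contributions should be close to the output
reconstructed = base_value + sum(value for _, value in contributions)
print(f"{reconstructed:.6f} vs {model_output:.6f}")
```

The small gap between the two printed numbers comes from the contributions being rounded to three decimals in the plot.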

STEP 6: Save and/or Reload the Model

Save and reload the model:

clf.save('/tmp/my_fewshot_model')
clf = FewShotClassifier('/tmp/my_fewshot_model')
clf.predict(['Elon Musk likes launching satellites.'])
array(['sci.space'], dtype='<U22')