clf = FewShotClassifier()
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
FewShotClassifier (model_id_or_path:str='sentence-transformers/paraphrase-mpnet-base-v2', use_smaller:bool=False, **kwargs)
Helper class that provides a standard way to create an ABC using inheritance.
ClassifierBase ()
Helper class that provides a standard way to create an ABC using inheritance.
ClassifierBase.arrays2dataset (X:List[str], y:Union[List[int],List[str]], text_key:str='text', label_key:str='label')
Convert train or test examples to HF dataset
ClassifierBase.dataset2arrays (dataset, text_key:str='text', label_key:str='label')
Convert a Hugging Face dataset to X, y arrays
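The two conversions above are inverses of each other. A minimal sketch of the round trip, using a plain dict of columns as a stand-in for a Hugging Face dataset. The helper names `to_columns` and `from_columns` are hypothetical and are not part of the library:

```python
from typing import Dict, List, Tuple, Union

Label = Union[int, str]


def to_columns(X: List[str], y: List[Label],
               text_key: str = "text", label_key: str = "label") -> Dict[str, list]:
    """Pack parallel lists into a column-oriented mapping (hypothetical helper)."""
    if len(X) != len(y):
        raise ValueError("X and y must have the same length")
    return {text_key: list(X), label_key: list(y)}


def from_columns(columns: Dict[str, list],
                 text_key: str = "text", label_key: str = "label") -> Tuple[list, list]:
    """Unpack the text and label columns back into X, y lists (hypothetical helper)."""
    return list(columns[text_key]), list(columns[label_key])


X = ["great movie", "terrible plot"]
y = ["positive", "negative"]
columns = to_columns(X, y)
X_back, y_back = from_columns(columns)
```

The `text_key` and `label_key` arguments only name the columns, so the same pair of keys has to be used in both directions.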
ClassifierBase.evaluate (X_eval:list, y_eval:list, print_report:bool=False, labels:List[str]=[])
Evaluates labeled data using the trained model. If print_report is True, prints a classification report and returns nothing. Otherwise, returns and prints a dictionary of the results.
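Judging by the results printed later on this page, the dictionary holds precision, recall, F1 and support for each label. The sketch below shows how such per-class numbers can be derived from true and predicted labels. It is an illustration in plain Python, not the library's implementation, and `per_class_report` is a hypothetical name:

```python
from collections import Counter
from typing import Dict, List


def per_class_report(y_true: List[str], y_pred: List[str]) -> Dict[str, Dict[str, float]]:
    """Compute precision, recall, F1 and support for each label (illustrative only)."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for truth, pred in zip(y_true, y_pred):
        if truth == pred:
            tp[truth] += 1
        else:
            fp[pred] += 1   # the predicted label was wrong
            fn[truth] += 1  # the true label was missed
    report = {}
    for label in sorted(set(y_true) | set(y_pred)):
        predicted = tp[label] + fp[label]
        actual = tp[label] + fn[label]
        precision = tp[label] / predicted if predicted else 0.0
        recall = tp[label] / actual if actual else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        report[label] = {"precision": precision, "recall": recall,
                         "f1-score": f1, "support": float(actual)}
    return report


y_true = ["positive", "positive", "negative", "negative"]
y_pred = ["positive", "negative", "negative", "negative"]
report = per_class_report(y_true, y_pred)
```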
ClassifierBase.explain (X:list, labels:List[str]=[])
Explain the predictions on given examples in X
. (Requires shap
and matplotlib
to be installed.)
ClassifierBase.get_trainer ()
Retrieves last trainer
ClassifierBase.sample_examples (X:list, y:list, num_samples:int=8, text_key:str='text', label_key:str='label')
Sample a dataset with num_samples per class
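Per-class sampling keeps the few-shot training set balanced across labels. A minimal sketch of the idea in plain Python follows. The function name `sample_per_class` and its `seed` parameter are assumptions made for illustration and do not describe how the library draws its samples:

```python
import random
from collections import defaultdict
from typing import List, Tuple


def sample_per_class(X: List[str], y: list, num_samples: int = 8,
                     seed: int = 42) -> Tuple[List[str], list]:
    """Pick up to num_samples texts for each distinct label (hypothetical helper)."""
    by_label = defaultdict(list)
    for text, label in zip(X, y):
        by_label[label].append(text)
    rng = random.Random(seed)  # fixed seed so the draw is reproducible
    X_out, y_out = [], []
    for label, texts in by_label.items():
        chosen = rng.sample(texts, min(num_samples, len(texts)))
        X_out.extend(chosen)
        y_out.extend([label] * len(chosen))
    return X_out, y_out


X = [f"review {i}" for i in range(20)]
y = [i % 2 for i in range(20)]  # two classes, ten examples each
X_small, y_small = sample_per_class(X, y, num_samples=3)
```

With two classes and `num_samples=3`, the result holds six examples, three per label.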
FewShotClassifier.save (save_path:str)
Save model to specified folder path, save_path
FewShotClassifier.train (X:List[str], y:Union[List[int],List[str]], num_epochs:int=10, batch_size:int=32, metric='accuracy', callbacks=None, **kwargs)
Trains the classifier on a list of texts (X) and a list of labels (y). Additional keyword arguments are passed directly to SetFit.TrainingArguments.
Args:
Returns:
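from datasets import load_dataset

# Load SST-2 and convert the train/test splits to parallel lists of texts and labels
dataset = load_dataset("SetFit/sst2")
X_train, y_train = clf.dataset2arrays(dataset["train"], text_key="text", label_key="label")
X_test, y_test = clf.dataset2arrays(dataset["test"], text_key="text", label_key="label")
# Draw 8 examples per class to use as the few-shot training set
X_sample, y_sample = clf.sample_examples(X_train, y_train, label_key="label", num_samples=8)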
/home/amaiya/mambaforge/envs/llm/lib/python3.9/site-packages/huggingface_hub/repocard.py:105: UserWarning: Repo card metadata block was not found. Setting CardData to empty.
warnings.warn("Repo card metadata block was not found. Setting CardData to empty.")
Applying column mapping to the training dataset
***** Running training *****
Num unique pairs = 144
Batch size = 32
Num epochs = 10
Total optimization steps = 50
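The logged values are consistent with one optimizer step per batch: 144 pairs at a batch size of 32 give 5 batches per epoch (the last one partial), and 10 epochs give 50 steps. A small check, assuming the number of steps per epoch is the pair count divided by the batch size, rounded up. This relationship is inferred from the log, not taken from the trainer's source:

```python
import math

# Values copied from the training log
num_unique_pairs = 144
batch_size = 32
num_epochs = 10

# Assumption: one optimizer step per batch, with a final partial batch counted
steps_per_epoch = math.ceil(num_unique_pairs / batch_size)
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)
```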
accuracy: 0.9060955518945635
macro avg:
  f1-score: 0.9060502208366357
  precision: 0.9067880482593942
  recall: 0.9060618232875919
  support: 1821.0
negative:
  f1-score: 0.9081139172487909
  precision: 0.8904109589041096
  recall: 0.9265350877192983
  support: 912.0
positive:
  f1-score: 0.9039865244244806
  precision: 0.9231651376146789
  recall: 0.8855885588558856
  support: 909.0
weighted avg:
  f1-score: 0.9060536206659804
  precision: 0.9067610678815438
  recall: 0.9060955518945635
  support: 1821.0
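The summary rows can be recomputed from the per-class rows. Each F1 score is the harmonic mean of precision and recall, the macro average is the unweighted mean across classes, and the weighted average weights each class by its support. The check below uses the numbers printed above; the variable names are only for illustration:

```python
# Per-class values copied from the evaluation results
negative = {"precision": 0.8904109589041096, "recall": 0.9265350877192983,
            "f1-score": 0.9081139172487909, "support": 912.0}
positive = {"precision": 0.9231651376146789, "recall": 0.8855885588558856,
            "f1-score": 0.9039865244244806, "support": 909.0}


def f1(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


def weighted(metric: str) -> float:
    """Support-weighted mean of a metric over the two classes."""
    total = negative["support"] + positive["support"]
    return (negative[metric] * negative["support"]
            + positive[metric] * positive["support"]) / total


negative_f1 = f1(negative["precision"], negative["recall"])
macro_f1 = (negative["f1-score"] + positive["f1-score"]) / 2
weighted_f1 = weighted("f1-score")
weighted_recall = weighted("recall")  # coincides with overall accuracy
```

The weighted recall equals the accuracy because both count correctly classified examples over all 1821 test examples.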
tensor([[0.1028, 0.8972],
        [0.9357, 0.0643]], dtype=torch.float64)
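Each row of this output is a probability distribution over the two classes and sums to 1. To turn the rows into labels, take the column with the highest probability. The sketch below uses plain lists instead of a tensor and assumes that column 0 is negative and column 1 is positive, since the output itself does not state the label order:

```python
# Probabilities copied from the tensor output
probs = [[0.1028, 0.8972],
         [0.9357, 0.0643]]
labels = ["negative", "positive"]  # assumed column order

# Index of the largest probability in each row, mapped to its label
predictions = [labels[max(range(len(row)), key=row.__getitem__)] for row in probs]
print(predictions)
```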