Module ktrain.text.zsl.core
Source code:
import math
import warnings
import numpy as np
from ... import utils as U
from ...torch_base import TorchBase
list2chunks = U.list2chunks
class ZeroShotClassifier(TorchBase):
"""
interface to Zero Shot Topic Classifier
"""
def __init__(
self, model_name="facebook/bart-large-mnli", device=None, quantize=False
):
"""
```
ZeroShotClassifier constructor
Args:
model_name(str): name of a BART NLI model
device(str): device to use (e.g., 'cuda', 'cpu')
    quantize(bool): If True, quantization will be used to speed up inference
```
"""
if "mnli" not in model_name and "xnli" not in model_name:
raise ValueError("ZeroShotClasifier requires an MNLI or XNLI model")
super().__init__(device=device, quantize=quantize)
from transformers import AutoModelForSequenceClassification, AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(
self.torch_device
)
if quantize:
self.model = self.quantize_model(self.model)
def predict(
self,
docs,
labels=[],
include_labels=False,
multilabel=True,
max_length=512,
batch_size=8,
nli_template="This text is about {}.",
topic_strings=[],
):
"""
```
This method performs zero-shot text classification using Natural Language Inference (NLI).
Args:
docs(list|str): text of document or list of texts
labels(list): a list of strings representing topics of your choice
Example:
labels=['political science', 'sports', 'science']
include_labels(bool): If True, will return topic labels along with topic probabilities
    multilabel(bool): If True, labels are considered independent, and multiple labels
                      can each be predicted as true for a document with probabilities close to 1.
                      If False, scores are normalized such that probabilities sum to 1.
max_length(int): truncate long documents to this many tokens
    batch_size(int): batch size to use (default: 8).
                     Increase this value to speed up predictions, especially
                     when the number of labels is large.
nli_template(str): labels are inserted into this template for use as hypotheses in natural language inference
    topic_strings(list): alias for the labels parameter (retained for backwards compatibility)
Returns:
    inferred probabilities, or a list of inferred probabilities if docs is a list
```
"""
# error checks
is_str_input = False
if not isinstance(docs, (list, np.ndarray)):
docs = [docs]
is_str_input = True
if not isinstance(docs[0], str):
raise ValueError(
"docs must be string or a list of strings representing document(s)"
)
if len(labels) > 0 and len(topic_strings) > 0:
raise ValueError("labels and topic_strings are mutually exclusive")
if not labels and not topic_strings:
raise ValueError("labels must be a list of strings")
if topic_strings:
labels = topic_strings
# convert to sequences
sequence_pairs = []
for premise in docs:
sequence_pairs.extend(
[[premise, nli_template.format(label)] for label in labels]
)
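        # Each document (premise) is paired with every hypothesis, e.g.:
        # docs=["I love soccer"], labels=["sports", "politics"] yields
        # [["I love soccer", "This text is about sports."],
        #  ["I love soccer", "This text is about politics."]]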
if batch_size > len(sequence_pairs):
batch_size = len(sequence_pairs)
if len(sequence_pairs) >= 100 and batch_size == 8:
warnings.warn(
"TIP: Try increasing batch_size to speedup ZeroShotClassifier predictions"
)
num_chunks = math.ceil(len(sequence_pairs) / batch_size)
sequence_chunks = list2chunks(sequence_pairs, n=num_chunks)
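        # each of the num_chunks chunks now holds roughly batch_size sequence pairs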
# inference
import torch
with torch.no_grad():
outputs = []
for sequences in sequence_chunks:
batch = self.tokenizer.batch_encode_plus(
sequences,
return_tensors="pt",
max_length=max_length,
truncation="only_first",
padding=True,
).to(self.torch_device)
logits = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
return_dict=False,
)[0]
outputs.extend(logits.cpu().detach().numpy())
# entail_contradiction_logits = logits[:,[0,2]]
# probs = entail_contradiction_logits.softmax(dim=1)
# true_probs = list(probs[:,1].cpu().detach().numpy())
# result.extend(true_probs)
outputs = np.array(outputs)
outputs = outputs.reshape((len(docs), len(labels), -1))
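        # outputs now has shape (num_docs, num_labels, num_nli_classes);
        # MNLI-style models typically order the classes as
        # [contradiction, neutral, entailment]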
# process outputs
# 2020-08-24: modified based on transformers pipeline implementation
if multilabel:
# softmax over the entailment vs. contradiction dim for each label independently
entail_contr_logits = outputs[..., [0, -1]]
scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(
-1, keepdims=True
)
scores = scores[..., 1]
else:
# softmax the "entailment" logits over all candidate labels
entail_logits = outputs[..., -1]
scores = np.exp(entail_logits) / np.exp(entail_logits).sum(
-1, keepdims=True
)
scores = scores.tolist()
if include_labels:
scores = [list(zip(labels, s)) for s in scores]
if is_str_input:
scores = scores[0]
return scores
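A minimal usage sketch (the document text, labels, and printed values below are illustrative):

```
from ktrain.text.zsl import ZeroShotClassifier

zsl = ZeroShotClassifier()  # loads facebook/bart-large-mnli by default
doc = "I am extremely dissatisfied with the President and will vote in 2020."
preds = zsl.predict(doc, labels=["politics", "elections", "sports"], include_labels=True)
print(preds)  # example output: [('politics', 0.98), ('elections', 0.97), ('sports', 0.00)]
```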
Classes
class ZeroShotClassifier (model_name='facebook/bart-large-mnli', device=None, quantize=False)
interface to Zero Shot Topic Classifier
ZeroShotClassifier constructor

Args:
    model_name(str): name of a BART NLI model
    device(str): device to use (e.g., 'cuda', 'cpu')
    quantize(bool): If True, quantization will be used to speed up inference
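For example, a hedged construction sketch (the device and quantize settings shown are one possible configuration, not a requirement):

```
from ktrain.text.zsl import ZeroShotClassifier

# pin inference to CPU and apply quantization for faster CPU inference
zsl = ZeroShotClassifier(device="cpu", quantize=True)
```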
Ancestors
- ktrain.torch_base.TorchBase
Methods
def predict(self, docs, labels=[], include_labels=False, multilabel=True, max_length=512, batch_size=8, nli_template='This text is about {}.', topic_strings=[])
This method performs zero-shot text classification using Natural Language Inference (NLI).

Args:
    docs(list|str): text of document or list of texts
    labels(list): a list of strings representing topics of your choice
                  Example: labels=['political science', 'sports', 'science']
    include_labels(bool): If True, will return topic labels along with topic probabilities
    multilabel(bool): If True, labels are considered independent, and multiple labels
                      can each be predicted as true for a document with probabilities close to 1.
                      If False, scores are normalized such that probabilities sum to 1.
    max_length(int): truncate long documents to this many tokens
    batch_size(int): batch size to use (default: 8).
                     Increase this value to speed up predictions, especially
                     when the number of labels is large.
    nli_template(str): labels are inserted into this template for use as hypotheses in natural language inference
    topic_strings(list): alias for the labels parameter (retained for backwards compatibility)

Returns:
    inferred probabilities, or a list of inferred probabilities if docs is a list
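A short sketch contrasting the two scoring modes (zsl is a ZeroShotClassifier instance; the labels and documents are illustrative):

```
labels = ["politics", "sports", "science"]
docs = ["The election results are in.", "The team won the championship."]

# multilabel=True (default): each label is scored independently in [0, 1]
scores = zsl.predict(docs, labels=labels, multilabel=True)

# multilabel=False: scores are normalized to sum to 1 across labels per document
scores = zsl.predict(docs, labels=labels, multilabel=False, include_labels=True)
```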
Inherited members