Module ktrain.text.summarization.core
Expand source code
from ...torch_base import TorchBase
class TransformerSummarizer(TorchBase):
    """
    interface to Transformer-based text summarization
    """

    def __init__(self, model_name="facebook/bart-large-cnn", device=None):
        """
        ```
        interface to BART-based text summarization using transformers library

        Args:
          model_name(str): name of BART model for summarization
          device(str): device to use (e.g., 'cuda', 'cpu')
        Raises:
          ValueError: if "bart" does not occur (case-insensitively) in model_name
        ```
        """
        # Case-insensitive so that names such as "org/BART-large" are accepted too.
        if "bart" not in model_name.lower():
            raise ValueError("TransformerSummarizer currently only accepts BART models")
        super().__init__(device=device)
        # Imported lazily so that transformers is only loaded when a summarizer is built.
        from transformers import BartForConditionalGeneration, BartTokenizer

        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        # NOTE(review): self.torch_device is presumably set by TorchBase.__init__ -- confirm.
        self.model = BartForConditionalGeneration.from_pretrained(model_name).to(
            self.torch_device
        )

    def summarize(
        self,
        doc,
        max_length=150,
        min_length=56,
        no_repeat_ngram_size=3,
        length_penalty=2.0,
        num_beams=4,
        **kwargs,
    ):
        """
        ```
        Summarize document text.  Extra arguments are fed to generate method

        Args:
          doc(str): text of document
          max_length(int): max_length argument passed to generate
          min_length(int): min_length argument passed to generate
          no_repeat_ngram_size(int): no_repeat_ngram_size argument passed to generate
          length_penalty(float): length_penalty argument passed to generate
          num_beams(int): num_beams argument passed to generate
        Returns:
          str: summary text
        ```
        """
        import torch

        with torch.no_grad():
            # The document is truncated to at most 1024 tokens before generation.
            encoded = self.tokenizer.batch_encode_plus(
                [doc], return_tensors="pt", truncation=True, max_length=1024
            )
            input_ids = encoded["input_ids"].to(self.torch_device)
            summary_ids = self.model.generate(
                input_ids,
                num_beams=num_beams,
                length_penalty=length_penalty,
                max_length=max_length,
                min_length=min_length,
                no_repeat_ngram_size=no_repeat_ngram_size,
                **kwargs,
            )
            # Take the first returned sequence explicitly.  squeeze() only worked
            # when exactly one sequence of more than one token came back; it left
            # a 2-D tensor when num_return_sequences > 1 was supplied via kwargs.
            exec_sum = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return exec_sum
class LexRankSummarizer:
    """
    interface to Lexrank-based text summarization
    """

    def __init__(self, language="english"):
        """
        ```
        interface to Lexrank-based text summarization using sumy library

        Args:
          language(str): default is "english"
        Raises:
          ImportError: if the sumy package cannot be imported
        ```
        """
        try:
            from sumy.nlp.stemmers import Stemmer
            from sumy.summarizers.lex_rank import LexRankSummarizer as Summarizer
            from sumy.utils import get_stop_words
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "Please install the sumy package: pip install sumy"
            ) from err

        self.language = language
        stemmer = Stemmer(self.language)
        self.summarizer = Summarizer(stemmer)
        self.summarizer.stop_words = get_stop_words(self.language)

    def summarize(
        self,
        doc,
        num_sentences=3,
        maximum_length=2000,
        minimum_length=40,
        join_sentences=True,
        num_candidate_sentences=100,
        **kwargs,
    ):
        """
        ```
        summarize document text

        Args:
          doc(str): text of document
          num_sentences(int): Number of sentences for summary
          maximum_length(int): Maximum length of sentence in summary
          minimum_length(int): Minimum length of sentence in summary
          join_sentences(bool): If True, summary is a single string instead of a list of sentences.
          num_candidate_sentences(int): Number of candidate sentences from which to select final summary.
        Returns:
          str: summary text if join_sentences is True, otherwise a list of sentence strings
        ```
        """
        from sumy.nlp.tokenizers import Tokenizer
        from sumy.parsers.plaintext import PlaintextParser

        parser = PlaintextParser.from_string(doc, Tokenizer(self.language))
        results = []
        for sentence in self.summarizer(parser.document, num_candidate_sentences):
            text = sentence._text
            # Skip empty sentences explicitly: with minimum_length <= 0 the
            # length filter lets them through and indexing text[0] would raise.
            if (
                not text
                or len(text) > maximum_length
                or len(text) < minimum_length
                or text[0].isdigit()
            ):
                continue
            # Ensure every kept sentence ends with terminal punctuation.
            if text.endswith((".", "?", "!", ";")):
                results.append(text)
            else:
                results.append(text + ".")
        selected = results[:num_sentences]
        return " ".join(selected) if join_sentences else selected
Classes
class LexRankSummarizer (language='english')
-
interface to Lexrank-based text summarization
interface to Lexrank-based text summarization using sumy library Args: language(str): default is "english"
Expand source code
class LexRankSummarizer:
    """
    interface to Lexrank-based text summarization
    """

    def __init__(self, language="english"):
        """
        ```
        interface to Lexrank-based text summarization using sumy library

        Args:
          language(str): default is "english"
        Raises:
          ImportError: if the sumy package cannot be imported
        ```
        """
        try:
            from sumy.nlp.stemmers import Stemmer
            from sumy.summarizers.lex_rank import LexRankSummarizer as Summarizer
            from sumy.utils import get_stop_words
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "Please install the sumy package: pip install sumy"
            ) from err

        self.language = language
        stemmer = Stemmer(self.language)
        self.summarizer = Summarizer(stemmer)
        self.summarizer.stop_words = get_stop_words(self.language)

    def summarize(
        self,
        doc,
        num_sentences=3,
        maximum_length=2000,
        minimum_length=40,
        join_sentences=True,
        num_candidate_sentences=100,
        **kwargs,
    ):
        """
        ```
        summarize document text

        Args:
          doc(str): text of document
          num_sentences(int): Number of sentences for summary
          maximum_length(int): Maximum length of sentence in summary
          minimum_length(int): Minimum length of sentence in summary
          join_sentences(bool): If True, summary is a single string instead of a list of sentences.
          num_candidate_sentences(int): Number of candidate sentences from which to select final summary.
        Returns:
          str: summary text if join_sentences is True, otherwise a list of sentence strings
        ```
        """
        from sumy.nlp.tokenizers import Tokenizer
        from sumy.parsers.plaintext import PlaintextParser

        parser = PlaintextParser.from_string(doc, Tokenizer(self.language))
        results = []
        for sentence in self.summarizer(parser.document, num_candidate_sentences):
            text = sentence._text
            # Skip empty sentences explicitly: with minimum_length <= 0 the
            # length filter lets them through and indexing text[0] would raise.
            if (
                not text
                or len(text) > maximum_length
                or len(text) < minimum_length
                or text[0].isdigit()
            ):
                continue
            # Ensure every kept sentence ends with terminal punctuation.
            if text.endswith((".", "?", "!", ";")):
                results.append(text)
            else:
                results.append(text + ".")
        selected = results[:num_sentences]
        return " ".join(selected) if join_sentences else selected
Methods
def summarize(self, doc, num_sentences=3, maximum_length=2000, minimum_length=40, join_sentences=True, num_candidate_sentences=100, **kwargs)
-
summarize document text Args: doc(str): text of document num_sentences(int): Number of sentences for summary maximum_length(int): Maximum length of sentence in summary minimum_length(int): Minimum length of sentence in summary join_sentences(bool): If True, summary is a single string instead of a list of sentences. num_candidate_sentences(int): Number of candidate sentences from which to select final summary. Returns: str: summary text (a list of sentence strings when join_sentences is False)
Expand source code
def summarize(
    self,
    doc,
    num_sentences=3,
    maximum_length=2000,
    minimum_length=40,
    join_sentences=True,
    num_candidate_sentences=100,
    **kwargs,
):
    """
    ```
    summarize document text

    Args:
      doc(str): text of document
      num_sentences(int): Number of sentences for summary
      maximum_length(int): Maximum length of sentence in summary
      minimum_length(int): Minimum length of sentence in summary
      join_sentences(bool): If True, summary is a single string instead of a list of sentences.
      num_candidate_sentences(int): Number of candidate sentences from which to select final summary.
    Returns:
      str: summary text if join_sentences is True, otherwise a list of sentence strings
    ```
    """
    from sumy.nlp.tokenizers import Tokenizer
    from sumy.parsers.plaintext import PlaintextParser

    parser = PlaintextParser.from_string(doc, Tokenizer(self.language))
    results = []
    for sentence in self.summarizer(parser.document, num_candidate_sentences):
        text = sentence._text
        # Skip empty sentences explicitly: with minimum_length <= 0 the
        # length filter lets them through and indexing text[0] would raise.
        if (
            not text
            or len(text) > maximum_length
            or len(text) < minimum_length
            or text[0].isdigit()
        ):
            continue
        # Ensure every kept sentence ends with terminal punctuation.
        if text.endswith((".", "?", "!", ";")):
            results.append(text)
        else:
            results.append(text + ".")
    selected = results[:num_sentences]
    return " ".join(selected) if join_sentences else selected
class TransformerSummarizer (model_name='facebook/bart-large-cnn', device=None)
-
interface to Transformer-based text summarization
interface to BART-based text summarization using transformers library Args: model_name(str): name of BART model for summarization device(str): device to use (e.g., 'cuda', 'cpu')
Expand source code
class TransformerSummarizer(TorchBase):
    """
    interface to Transformer-based text summarization
    """

    def __init__(self, model_name="facebook/bart-large-cnn", device=None):
        """
        ```
        interface to BART-based text summarization using transformers library

        Args:
          model_name(str): name of BART model for summarization
          device(str): device to use (e.g., 'cuda', 'cpu')
        Raises:
          ValueError: if "bart" does not occur (case-insensitively) in model_name
        ```
        """
        # Case-insensitive so that names such as "org/BART-large" are accepted too.
        if "bart" not in model_name.lower():
            raise ValueError("TransformerSummarizer currently only accepts BART models")
        super().__init__(device=device)
        # Imported lazily so that transformers is only loaded when a summarizer is built.
        from transformers import BartForConditionalGeneration, BartTokenizer

        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        # NOTE(review): self.torch_device is presumably set by TorchBase.__init__ -- confirm.
        self.model = BartForConditionalGeneration.from_pretrained(model_name).to(
            self.torch_device
        )

    def summarize(
        self,
        doc,
        max_length=150,
        min_length=56,
        no_repeat_ngram_size=3,
        length_penalty=2.0,
        num_beams=4,
        **kwargs,
    ):
        """
        ```
        Summarize document text.  Extra arguments are fed to generate method

        Args:
          doc(str): text of document
          max_length(int): max_length argument passed to generate
          min_length(int): min_length argument passed to generate
          no_repeat_ngram_size(int): no_repeat_ngram_size argument passed to generate
          length_penalty(float): length_penalty argument passed to generate
          num_beams(int): num_beams argument passed to generate
        Returns:
          str: summary text
        ```
        """
        import torch

        with torch.no_grad():
            # The document is truncated to at most 1024 tokens before generation.
            encoded = self.tokenizer.batch_encode_plus(
                [doc], return_tensors="pt", truncation=True, max_length=1024
            )
            input_ids = encoded["input_ids"].to(self.torch_device)
            summary_ids = self.model.generate(
                input_ids,
                num_beams=num_beams,
                length_penalty=length_penalty,
                max_length=max_length,
                min_length=min_length,
                no_repeat_ngram_size=no_repeat_ngram_size,
                **kwargs,
            )
            # Take the first returned sequence explicitly.  squeeze() only worked
            # when exactly one sequence of more than one token came back; it left
            # a 2-D tensor when num_return_sequences > 1 was supplied via kwargs.
            exec_sum = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return exec_sum
Ancestors
Methods
def summarize(self, doc, max_length=150, min_length=56, no_repeat_ngram_size=3, length_penalty=2.0, num_beams=4, **kwargs)
-
Summarize document text. Extra arguments are fed to generate method Args: doc(str): text of document Returns: str: summary text
Expand source code
def summarize(
    self,
    doc,
    max_length=150,
    min_length=56,
    no_repeat_ngram_size=3,
    length_penalty=2.0,
    num_beams=4,
    **kwargs,
):
    """
    ```
    Summarize document text.  Extra arguments are fed to generate method

    Args:
      doc(str): text of document
      max_length(int): max_length argument passed to generate
      min_length(int): min_length argument passed to generate
      no_repeat_ngram_size(int): no_repeat_ngram_size argument passed to generate
      length_penalty(float): length_penalty argument passed to generate
      num_beams(int): num_beams argument passed to generate
    Returns:
      str: summary text
    ```
    """
    import torch

    with torch.no_grad():
        # The document is truncated to at most 1024 tokens before generation.
        encoded = self.tokenizer.batch_encode_plus(
            [doc], return_tensors="pt", truncation=True, max_length=1024
        )
        input_ids = encoded["input_ids"].to(self.torch_device)
        summary_ids = self.model.generate(
            input_ids,
            num_beams=num_beams,
            length_penalty=length_penalty,
            max_length=max_length,
            min_length=min_length,
            no_repeat_ngram_size=no_repeat_ngram_size,
            **kwargs,
        )
        # Take the first returned sequence explicitly.  squeeze() only worked
        # when exactly one sequence of more than one token came back; it left
        # a 2-D tensor when num_return_sequences > 1 was supplied via kwargs.
        exec_sum = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return exec_sum
Inherited members