Module ktrain.text.learner

Expand source code
from .. import utils as U
from ..core import ArrayLearner, GenLearner, _load_model
from ..imports import *
from .preprocessor import TransformersPreprocessor


class BERTTextClassLearner(ArrayLearner):
    """
    ```
    Main class used to tune and train Keras models for text classification using Array data.
    ```
    """

    def __init__(
        self,
        model,
        train_data=None,
        val_data=None,
        batch_size=U.DEFAULT_BS,
        eval_batch_size=U.DEFAULT_BS,
        workers=1,
        use_multiprocessing=False,
    ):
        # No BERT-specific state is kept here; all setup is delegated to ArrayLearner.
        super().__init__(
            model,
            train_data=train_data,
            val_data=val_data,
            batch_size=batch_size,
            eval_batch_size=eval_batch_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
        )
        return

    def view_top_losses(self, n=4, preproc=None, val_data=None):
        """
        ```
        Prints the validation observations with the highest losses.
        Args:
         n(int or tuple): a range to select in form of int or tuple
                          e.g., n=8 is treated as n=(0,8)
         preproc (Preprocessor): A TextPreprocessor or ImagePreprocessor.
                                 For some data like text data, a preprocessor
                                 is required to undo the pre-processing
                                 to correctly view raw data.
         val_data:  optional val_data to use instead of self.val_data
        Returns:
            None: results are printed to stdout
        ```
        """
        val = self._check_val(val_data)

        # get top losses and associated data:
        # each tuple is (example id, loss, ground truth, prediction)
        tups = self.top_losses(n=n, val_data=val, preproc=preproc)

        # iterate through losses
        for tup in tups:
            # get data
            idx = tup[0]
            loss = tup[1]
            truth = tup[2]
            pred = tup[3]

            # BERT-style tuple: val[0][0] holds the token-id arrays for each example
            join_char = " "
            obs = val[0][0][idx]
            if preproc is not None:
                # undo preprocessing to recover (approximate) raw text
                obs = preproc.undo(obs)
                if preproc.is_nospace_lang():
                    # languages like Chinese/Japanese are joined without spaces
                    join_char = ""
            if isinstance(obs, str):
                # truncate display to the first 512 tokens (BERT's max sequence length)
                obs = join_char.join(obs.split()[:512])
            print("----------")
            print(
                "id:%s | loss:%s | true:%s | pred:%s\n"
                % (idx, round(loss, 2), truth, pred)
            )
            print(obs)
        return


class TransformerTextClassLearner(GenLearner):
    """
    ```
    Main class used to tune and train Keras models (from the Hugging Face
    transformers library) for text classification using sequence data
    converted on-the-fly to tf.data.Dataset objects.
    ```
    """

    def __init__(
        self,
        model,
        train_data=None,
        val_data=None,
        batch_size=U.DEFAULT_BS,
        eval_batch_size=U.DEFAULT_BS,
        workers=1,
        use_multiprocessing=False,
    ):
        # No transformer-specific state is kept here; all setup is delegated to GenLearner.
        super().__init__(
            model,
            train_data=train_data,
            val_data=val_data,
            batch_size=batch_size,
            eval_batch_size=eval_batch_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
        )
        return

    def view_top_losses(self, n=4, preproc=None, val_data=None):
        """
        ```
        Prints the ids/losses of the validation observations with the highest losses.
        Args:
         n(int or tuple): a range to select in form of int or tuple
                          e.g., n=8 is treated as n=(0,8)
         preproc (Preprocessor): A TextPreprocessor or ImagePreprocessor.
                                 For some data like text data, a preprocessor
                                 is required to undo the pre-processing
                                 to correctly view raw data.
         val_data:  optional val_data to use instead of self.val_data
        Returns:
            None: results are printed to stdout
        ```
        """
        val = self._check_val(val_data)

        # get top losses and associated data:
        # each tuple is (example id, loss, ground truth, prediction)
        tups = self.top_losses(n=n, val_data=val, preproc=preproc)

        # iterate through losses
        for tup in tups:
            # get data
            idx = tup[0]
            loss = tup[1]
            truth = tup[2]
            pred = tup[3]

            # NOTE: unlike BERTTextClassLearner, the raw observation is not
            # printed here (only id/loss/true/pred).
            # obs = val.x[idx][0]
            print("----------")
            print(
                "id:%s | loss:%s | true:%s | pred:%s\n"
                % (idx, round(loss, 2), truth, pred)
            )
        return

    def _prepare(self, data, train=True):
        """
        ```
        prepare data as tf.Dataset
        ```
        """
        # HF_EXCEPTION
        # convert arrays to TF dataset (iterator) on-the-fly
        # to work around issues with transformers and tf.Datasets
        if data is None:
            return None
        return data.to_tfdataset(train=train)

    def predict(self, val_data=None):
        """
        ```
        Makes predictions on the validation set (or val_data, if supplied).
        Returns class probabilities for classifiers (sigmoid for multilabel,
        softmax otherwise) and raw model outputs for regression models.
        ```
        """
        val = val_data if val_data is not None else self.val_data
        if val is None:
            raise Exception("val_data must be supplied to get_learner or predict")
        if hasattr(val, "reset"):
            # rewind generator-style data so predictions align with examples
            val.reset()
        classification, multilabel = U.is_classifier(self.model)
        preds = self.model.predict(self._prepare(val, train=False))
        if hasattr(
            preds, "logits"
        ):  # dep_fix: breaking change in transformers==4.0.0 - also needed for Longformer
            # REFERENCE: https://discuss.huggingface.co/t/new-model-output-types/195
            preds = preds.logits

        # dep_fix: transformers in TF 2.2.0 returns a tuple instead of NumPy array for some reason
        if isinstance(preds, tuple) and len(preds) == 1:
            preds = preds[0]

        if classification:
            if multilabel:
                # independent per-class probabilities
                return keras.activations.sigmoid(tf.convert_to_tensor(preds)).numpy()
            else:
                # mutually-exclusive class probabilities
                return keras.activations.softmax(tf.convert_to_tensor(preds)).numpy()
        else:
            return preds

    def save_model(self, fpath):
        """
        ```
        save Transformers model
        Args:
          fpath(str): path to folder in which the model will be saved
                      (created if it does not exist)
        ```
        """
        self._make_model_folder(fpath)
        # use the transformers-native serialization format
        self.model.save_pretrained(fpath)
        return

    # 2020-07-07: load_model removed, as core.Learner.load_model calls
    # TransformerPreprocessor.load_model_and_configure

Classes

class BERTTextClassLearner (model, train_data=None, val_data=None, batch_size=32, eval_batch_size=32, workers=1, use_multiprocessing=False)
Main class used to tune and train Keras models for text classification using Array data.
Expand source code
class BERTTextClassLearner(ArrayLearner):
    """
    ```
    Main class used to tune and train Keras models for text classification using Array data.
    ```
    """

    def __init__(
        self,
        model,
        train_data=None,
        val_data=None,
        batch_size=U.DEFAULT_BS,
        eval_batch_size=U.DEFAULT_BS,
        workers=1,
        use_multiprocessing=False,
    ):
        super().__init__(
            model,
            train_data=train_data,
            val_data=val_data,
            batch_size=batch_size,
            eval_batch_size=eval_batch_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
        )
        return

    def view_top_losses(self, n=4, preproc=None, val_data=None):
        """
        ```
        Views observations with top losses in validation set.
        Args:
         n(int or tuple): a range to select in form of int or tuple
                          e.g., n=8 is treated as n=(0,8)
         preproc (Preprocessor): A TextPreprocessor or ImagePreprocessor.
                                 For some data like text data, a preprocessor
                                 is required to undo the pre-processing
                                 to correctly view raw data.
          val_data:  optional val_data to use instead of self.val_data
        Returns:
            list of n tuples where first element is either
            filepath or id of validation example and second element
            is loss.
        ```
        """
        val = self._check_val(val_data)

        # get top losses and associated data
        tups = self.top_losses(n=n, val_data=val, preproc=preproc)

        # get multilabel status and class names
        classes = preproc.get_classes() if preproc is not None else None

        # iterate through losses
        for tup in tups:
            # get data
            idx = tup[0]
            loss = tup[1]
            truth = tup[2]
            pred = tup[3]

            # BERT-style tuple
            join_char = " "
            obs = val[0][0][idx]
            if preproc is not None:
                obs = preproc.undo(obs)
                if preproc.is_nospace_lang():
                    join_char = ""
            if type(obs) == str:
                obs = join_char.join(obs.split()[:512])
            print("----------")
            print(
                "id:%s | loss:%s | true:%s | pred:%s)\n"
                % (idx, round(loss, 2), truth, pred)
            )
            print(obs)
        return

Ancestors

Methods

def view_top_losses(self, n=4, preproc=None, val_data=None)
Views observations with top losses in validation set.
Args:
 n(int or tuple): a range to select in form of int or tuple
                  e.g., n=8 is treated as n=(0,8)
 preproc (Preprocessor): A TextPreprocessor or ImagePreprocessor.
                         For some data like text data, a preprocessor
                         is required to undo the pre-processing
                         to correctly view raw data.
  val_data:  optional val_data to use instead of self.val_data
Returns:
    list of n tuples where first element is either
    filepath or id of validation example and second element
    is loss.
Expand source code
def view_top_losses(self, n=4, preproc=None, val_data=None):
    """
    ```
    Views observations with top losses in validation set.
    Args:
     n(int or tuple): a range to select in form of int or tuple
                      e.g., n=8 is treated as n=(0,8)
     preproc (Preprocessor): A TextPreprocessor or ImagePreprocessor.
                             For some data like text data, a preprocessor
                             is required to undo the pre-processing
                             to correctly view raw data.
      val_data:  optional val_data to use instead of self.val_data
    Returns:
        list of n tuples where first element is either
        filepath or id of validation example and second element
        is loss.
    ```
    """
    val = self._check_val(val_data)

    # get top losses and associated data
    tups = self.top_losses(n=n, val_data=val, preproc=preproc)

    # get multilabel status and class names
    classes = preproc.get_classes() if preproc is not None else None

    # iterate through losses
    for tup in tups:
        # get data
        idx = tup[0]
        loss = tup[1]
        truth = tup[2]
        pred = tup[3]

        # BERT-style tuple
        join_char = " "
        obs = val[0][0][idx]
        if preproc is not None:
            obs = preproc.undo(obs)
            if preproc.is_nospace_lang():
                join_char = ""
        if type(obs) == str:
            obs = join_char.join(obs.split()[:512])
        print("----------")
        print(
            "id:%s | loss:%s | true:%s | pred:%s)\n"
            % (idx, round(loss, 2), truth, pred)
        )
        print(obs)
    return

Inherited members

class TransformerTextClassLearner (model, train_data=None, val_data=None, batch_size=32, eval_batch_size=32, workers=1, use_multiprocessing=False)
Main class used to tune and train Keras models (from the Hugging Face transformers library) for text classification using sequence data converted on-the-fly to tf.data.Dataset objects.
Expand source code
class TransformerTextClassLearner(GenLearner):
    """
    ```
    Main class used to tune and train Keras models for text classification using Array data.
    ```
    """

    def __init__(
        self,
        model,
        train_data=None,
        val_data=None,
        batch_size=U.DEFAULT_BS,
        eval_batch_size=U.DEFAULT_BS,
        workers=1,
        use_multiprocessing=False,
    ):
        super().__init__(
            model,
            train_data=train_data,
            val_data=val_data,
            batch_size=batch_size,
            eval_batch_size=eval_batch_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
        )
        return

    def view_top_losses(self, n=4, preproc=None, val_data=None):
        """
        ```
        Views observations with top losses in validation set.
        Args:
         n(int or tuple): a range to select in form of int or tuple
                          e.g., n=8 is treated as n=(0,8)
         preproc (Preprocessor): A TextPreprocessor or ImagePreprocessor.
                                 For some data like text data, a preprocessor
                                 is required to undo the pre-processing
                                 to correctly view raw data.
          val_data:  optional val_data to use instead of self.val_data
        Returns:
            list of n tuples where first element is either
            filepath or id of validation example and second element
            is loss.
        ```
        """
        val = self._check_val(val_data)

        # get top losses and associated data
        tups = self.top_losses(n=n, val_data=val, preproc=preproc)

        # get multilabel status and class names
        classes = preproc.get_classes() if preproc is not None else None

        # iterate through losses
        for tup in tups:
            # get data
            idx = tup[0]
            loss = tup[1]
            truth = tup[2]
            pred = tup[3]

            join_char = " "
            # obs = val.x[idx][0]
            print("----------")
            print(
                "id:%s | loss:%s | true:%s | pred:%s)\n"
                % (idx, round(loss, 2), truth, pred)
            )
        return

    def _prepare(self, data, train=True):
        """
        ```
        prepare data as tf.Dataset
        ```
        """
        # HF_EXCEPTION
        # convert arrays to TF dataset (iterator) on-the-fly
        # to work around issues with transformers and tf.Datasets
        if data is None:
            return None
        return data.to_tfdataset(train=train)

    def predict(self, val_data=None):
        """
        ```
        Makes predictions on validation set
        ```
        """
        if val_data is not None:
            val = val_data
        else:
            val = self.val_data
        if val is None:
            raise Exception("val_data must be supplied to get_learner or predict")
        if hasattr(val, "reset"):
            val.reset()
        classification, multilabel = U.is_classifier(self.model)
        preds = self.model.predict(self._prepare(val, train=False))
        if hasattr(
            preds, "logits"
        ):  # dep_fix: breaking change in transformers==4.0.0 - also needed for Longformer
            # if type(preds).__name__ == 'TFSequenceClassifierOutput': # dep_fix: undocumented breaking change in transformers==4.0.0
            # REFERENCE: https://discuss.huggingface.co/t/new-model-output-types/195
            preds = preds.logits

        # dep_fix: transformers in TF 2.2.0 returns a tuple instead of NumPy array for some reason
        if isinstance(preds, tuple) and len(preds) == 1:
            preds = preds[0]

        if classification:
            if multilabel:
                return keras.activations.sigmoid(tf.convert_to_tensor(preds)).numpy()
            else:
                return keras.activations.softmax(tf.convert_to_tensor(preds)).numpy()
        else:
            return preds

    def save_model(self, fpath):
        """
        ```
        save Transformers model
        ```
        """
        self._make_model_folder(fpath)
        self.model.save_pretrained(fpath)
        return

    # 2020-07-07: removed, as core.Learner.load_model calls TransformerPreprocessor.load_model_and_configure
    # def load_model(self, fpath, preproc=None):
    #    """
    #    load Transformers model
    #    Args:
    #      fpath(str): path to folder containing model files
    #      preproc(TransformerPreprocessor): a TransformerPreprocessor instance.
    #    """
    #    if preproc is None or not isinstance(preproc, TransformersPreprocessor):
    #        raise ValueError('preproc arg is required to load Transformer models from disk. ' +\
    #                          'Supply a TransformersPreprocessor instance. This is ' +\
    #                          'either the third return value from texts_from* function or '+\
    #                          'the result of calling ktrain.text.Transformer')

    #    self.model = _load_model(fpath, preproc=preproc)
    #    return

Ancestors

Methods

def save_model(self, fpath)
save Transformers model
Expand source code
def save_model(self, fpath):
    """
    ```
    save Transformers model
    ```
    """
    self._make_model_folder(fpath)
    self.model.save_pretrained(fpath)
    return
def view_top_losses(self, n=4, preproc=None, val_data=None)
Views observations with top losses in validation set.
Args:
 n(int or tuple): a range to select in form of int or tuple
                  e.g., n=8 is treated as n=(0,8)
 preproc (Preprocessor): A TextPreprocessor or ImagePreprocessor.
                         For some data like text data, a preprocessor
                         is required to undo the pre-processing
                         to correctly view raw data.
  val_data:  optional val_data to use instead of self.val_data
Returns:
    list of n tuples where first element is either
    filepath or id of validation example and second element
    is loss.
Expand source code
def view_top_losses(self, n=4, preproc=None, val_data=None):
    """
    ```
    Views observations with top losses in validation set.
    Args:
     n(int or tuple): a range to select in form of int or tuple
                      e.g., n=8 is treated as n=(0,8)
     preproc (Preprocessor): A TextPreprocessor or ImagePreprocessor.
                             For some data like text data, a preprocessor
                             is required to undo the pre-processing
                             to correctly view raw data.
      val_data:  optional val_data to use instead of self.val_data
    Returns:
        list of n tuples where first element is either
        filepath or id of validation example and second element
        is loss.
    ```
    """
    val = self._check_val(val_data)

    # get top losses and associated data
    tups = self.top_losses(n=n, val_data=val, preproc=preproc)

    # get multilabel status and class names
    classes = preproc.get_classes() if preproc is not None else None

    # iterate through losses
    for tup in tups:
        # get data
        idx = tup[0]
        loss = tup[1]
        truth = tup[2]
        pred = tup[3]

        join_char = " "
        # obs = val.x[idx][0]
        print("----------")
        print(
            "id:%s | loss:%s | true:%s | pred:%s)\n"
            % (idx, round(loss, 2), truth, pred)
        )
    return

Inherited members