Module ktrain.text.ner.learner

Expand source code
from ... import utils as U
from ...core import GenLearner
from ...imports import *
from . import metrics


class NERLearner(GenLearner):
    """
    Learner for sequence taggers (named-entity recognition models).
    """

    def __init__(
        self,
        model,
        train_data=None,
        val_data=None,
        batch_size=U.DEFAULT_BS,
        eval_batch_size=U.DEFAULT_BS,
        workers=1,
        use_multiprocessing=False,
    ):
        super().__init__(
            model,
            train_data=train_data,
            val_data=val_data,
            batch_size=batch_size,
            eval_batch_size=eval_batch_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
        )
        return

    def validate(self, val_data=None, print_report=True, class_names=None):
        """
        Validate a text sequence tagger on a validation set.

        Args:
          val_data: optional validation data to use instead of self.val_data
          print_report(bool): if True, print the F1 score and a tag-level
                              classification report
          class_names: unused; retained for signature compatibility with
                       other Learner subclasses

        Returns:
          float: entity-level F1 score, or None if the model/data are not
                 for sequence tagging
        """
        val = self._check_val(val_data)

        # the NERSequence must be prepared (vectorized/padded) before batching
        if not val.prepare_called:
            val.prepare()

        if not U.is_ner(model=self.model, data=val):
            warnings.warn("learner.validate_ner is only for sequence taggers.")
            return

        label_true = []
        label_pred = []
        for i in range(len(val)):
            x_true, y_true = val[i]
            # original (unpadded) sentence lengths for this batch, used to
            # truncate predictions back to the real sequence lengths
            lengths = val.get_lengths(i)
            y_pred = self.model.predict_on_batch(x_true)

            # convert ids/one-hot vectors back to tag strings
            y_true = val.p.inverse_transform(y_true, lengths)
            y_pred = val.p.inverse_transform(y_pred, lengths)

            label_true.extend(y_true)
            label_pred.extend(y_pred)

        score = metrics.f1_score(label_true, label_pred)
        if print_report:
            print("   F1:  {:04.2f}".format(score * 100))
            print(metrics.classification_report(label_true, label_pred))

        return score

    def top_losses(self, n=4, val_data=None, preproc=None):
        """
        Computes losses on validation set sorted by examples with top losses.
        The "loss" for a sequence-tagging example is the number of tokens
        whose predicted tag differs from the true tag.

        Args:
          n(int or tuple): a range to select in form of int or tuple
                          e.g., n=8 is treated as n=(0,8)
          val_data:  optional val_data to use instead of self.val_data
          preproc: unused; retained for signature compatibility

        Returns:
            list of tuples of the form
            (example_index, error_count, true_tags, predicted_tags),
            restricted to examples with at least one error and sorted
            by error_count in descending order.
        """
        val = self._check_val(val_data)
        if isinstance(n, int):
            n = (0, n)

        # get predictions and ground truth
        y_pred = self.predict(val_data=val)
        y_true = self.ground_truth(val_data=val)

        # per-example error count: number of mis-tagged tokens
        losses = [
            sum(1 for t, p in zip(y_t, y_p) if t != p)
            for y_t, y_p in zip(y_true, y_pred)
        ]
        tups = [
            (i, err, y_true[i], y_pred[i]) for i, err in enumerate(losses) if err > 0
        ]
        tups.sort(key=operator.itemgetter(1), reverse=True)

        # prune by given range (n may be explicitly passed as None to disable)
        if n is not None:
            tups = tups[n[0] : n[1]]
        return tups

    def view_top_losses(self, n=4, preproc=None, val_data=None):
        """
        Views observations with top losses in validation set, printing each
        word alongside its true and predicted tags.

        Args:
         n(int or tuple): a range to select in form of int or tuple
                          e.g., n=8 is treated as n=(0,8)
         preproc (Preprocessor): unused for sequence taggers; retained for
                                 signature compatibility with other Learners
          val_data:  optional val_data to use instead of self.val_data

        Returns:
            None (results are printed to stdout)
        """
        # check validation data and arguments
        val = self._check_val(val_data)

        # iterate through the worst-scoring examples
        for idx, loss, truth, pred in self.top_losses(n=n, val_data=val):
            seq = val.x[idx]
            print("total incorrect: %s" % (loss))
            print("{:15} {:5}: ({})".format("Word", "True", "Pred"))
            print("=" * 30)
            for w, true_tag, pred_tag in zip(seq, truth, pred):
                print("{:15}:{:5} ({})".format(w, true_tag, pred_tag))
            print("\n")
        return

    def save_model(self, fpath):
        """
        A wrapper to model.save.

        Args:
          fpath(str): folder in which to save the model (the model file
                      name is U.MODEL_NAME)
        """
        self._make_model_folder(fpath)
        # CRF models must be recompiled with crf_loss before saving so the
        # custom loss is serialized correctly
        if U.is_crf(self.model):
            from .anago.layers import crf_loss

            self.model.compile(loss=crf_loss, optimizer=U.DEFAULT_OPT)
        self.model.save(os.path.join(fpath, U.MODEL_NAME), save_format="h5")
        return

    def predict(self, val_data=None):
        """
        Makes predictions on the validation set.

        Args:
          val_data: optional validation data to use instead of self.val_data

        Raises:
          ValueError: if no validation data is available

        Returns:
          list: predicted tag sequences (one list of tags per example)
        """
        val = val_data if val_data is not None else self.val_data
        if val is None:
            raise ValueError("val_data must be supplied to get_learner or predict")
        results = []
        for idx, (X, y) in enumerate(val):
            y_pred = self.model.predict_on_batch(X)
            # truncate predictions to the original (unpadded) lengths
            lengths = val.get_lengths(idx)
            y_pred = val.p.inverse_transform(y_pred, lengths)
            results.extend(y_pred)
        return results

    def _prepare(self, data, train=True):
        """
        Prepare an NERSequence for training or validation.

        Args:
          data: an NERSequence instance (or None)
          train(bool): only affects the progress message printed

        Returns:
          the prepared data (or None if data is None)
        """
        if data is None:
            return None
        mode = "training" if train else "validation"
        # prepare() is idempotent-guarded: only call it once per dataset
        if not data.prepare_called:
            print("preparing %s data ..." % (mode), end="")
            data.prepare()
            print("done.")
        return data

Classes

class NERLearner (model, train_data=None, val_data=None, batch_size=32, eval_batch_size=32, workers=1, use_multiprocessing=False)

Learner for Sequence Taggers.

Expand source code
class NERLearner(GenLearner):
    """
    Learner for Sequence Taggers.
    """

    def __init__(
        self,
        model,
        train_data=None,
        val_data=None,
        batch_size=U.DEFAULT_BS,
        eval_batch_size=U.DEFAULT_BS,
        workers=1,
        use_multiprocessing=False,
    ):
        super().__init__(
            model,
            train_data=train_data,
            val_data=val_data,
            batch_size=batch_size,
            eval_batch_size=eval_batch_size,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
        )
        return

    def validate(self, val_data=None, print_report=True, class_names=[]):
        """
        Validate text sequence taggers
        """
        val = self._check_val(val_data)

        if not val.prepare_called:
            val.prepare()

        if not U.is_ner(model=self.model, data=val):
            warnings.warn("learner.validate_ner is only for sequence taggers.")
            return

        label_true = []
        label_pred = []
        for i in range(len(val)):
            x_true, y_true = val[i]
            # lengths = self.ner_lengths(y_true)
            lengths = val.get_lengths(i)
            y_pred = self.model.predict_on_batch(x_true)

            y_true = val.p.inverse_transform(y_true, lengths)
            y_pred = val.p.inverse_transform(y_pred, lengths)

            label_true.extend(y_true)
            label_pred.extend(y_pred)

        score = metrics.f1_score(label_true, label_pred)
        # acc = metrics.accuracy_score(label_true, label_pred)
        if print_report:
            print("   F1:  {:04.2f}".format(score * 100))
            # print('   ACC: {:04.2f}'.format(acc * 100))
            print(metrics.classification_report(label_true, label_pred))

        return score

    def top_losses(self, n=4, val_data=None, preproc=None):
        """
        Computes losses on validation set sorted by examples with top losses
        Args:
          n(int or tuple): a range to select in form of int or tuple
                          e.g., n=8 is treated as n=(0,8)
          val_data:  optional val_data to use instead of self.val_data
        Returns:
            list of n tuples where first element is either
            filepath or id of validation example and second element
            is loss.

        """
        val = self._check_val(val_data)
        if type(n) == type(42):
            n = (0, n)

        # get predictions and ground truth
        y_pred = self.predict(val_data=val)
        y_true = self.ground_truth(val_data=val)

        # compute losses and sort
        losses = []
        for idx, y_t in enumerate(y_true):
            y_p = y_pred[idx]
            # err = 1- sum(1 for x,y in zip(y_t,y_p) if x == y) / len(y_t)
            err = sum(1 for x, y in zip(y_t, y_p) if x != y)
            losses.append(err)
        tups = [(i, x, y_true[i], y_pred[i]) for i, x in enumerate(losses) if x > 0]
        tups.sort(key=operator.itemgetter(1), reverse=True)

        # prune by given range
        tups = tups[n[0] : n[1]] if n is not None else tups
        return tups

    def view_top_losses(self, n=4, preproc=None, val_data=None):
        """
        Views observations with top losses in validation set.
        Args:
         n(int or tuple): a range to select in form of int or tuple
                          e.g., n=8 is treated as n=(0,8)
         preproc (Preprocessor): A TextPreprocessor or ImagePreprocessor.
                                 For some data like text data, a preprocessor
                                 is required to undo the pre-processing
                                 to correctly view raw data.
          val_data:  optional val_data to use instead of self.val_data
        Returns:
            list of n tuples where first element is either
            filepath or id of validation example and second element
            is loss.

        """

        # check validation data and arguments
        val = self._check_val(val_data)

        tups = self.top_losses(n=n, val_data=val)

        # get multilabel status and class names
        classes = preproc.get_classes() if preproc is not None else None

        # iterate through losses
        for tup in tups:
            # get data
            idx = tup[0]
            loss = tup[1]
            truth = tup[2]
            pred = tup[3]

            seq = val.x[idx]
            print("total incorrect: %s" % (loss))
            print("{:15} {:5}: ({})".format("Word", "True", "Pred"))
            print("=" * 30)
            for w, true_tag, pred_tag in zip(seq, truth, pred):
                print("{:15}:{:5} ({})".format(w, true_tag, pred_tag))
            print("\n")
        return

    def save_model(self, fpath):
        """
        a wrapper to model.save
        """
        self._make_model_folder(fpath)
        if U.is_crf(self.model):
            from .anago.layers import crf_loss

            self.model.compile(loss=crf_loss, optimizer=U.DEFAULT_OPT)
        self.model.save(os.path.join(fpath, U.MODEL_NAME), save_format="h5")
        return

    def predict(self, val_data=None):
        """
        Makes predictions on validation set
        """
        if val_data is not None:
            val = val_data
        else:
            val = self.val_data
        if val is None:
            raise Exception("val_data must be supplied to get_learner or predict")
        steps = np.ceil(U.nsamples_from_data(val) / val.batch_size)
        results = []
        for idx, (X, y) in enumerate(val):
            y_pred = self.model.predict_on_batch(X)
            lengths = val.get_lengths(idx)
            y_pred = val.p.inverse_transform(y_pred, lengths)
            results.extend(y_pred)
        return results

    def _prepare(self, data, train=True):
        """
        prepare NERSequence for training
        """
        if data is None:
            return None
        mode = "training" if train else "validation"
        if not data.prepare_called:
            print("preparing %s data ..." % (mode), end="")
            data.prepare()
            print("done.")
        return data

Ancestors

Methods

def predict(self, val_data=None)

Makes predictions on validation set

Expand source code
def predict(self, val_data=None):
    """
    Makes predictions on validation set
    """
    if val_data is not None:
        val = val_data
    else:
        val = self.val_data
    if val is None:
        raise Exception("val_data must be supplied to get_learner or predict")
    steps = np.ceil(U.nsamples_from_data(val) / val.batch_size)
    results = []
    for idx, (X, y) in enumerate(val):
        y_pred = self.model.predict_on_batch(X)
        lengths = val.get_lengths(idx)
        y_pred = val.p.inverse_transform(y_pred, lengths)
        results.extend(y_pred)
    return results
def save_model(self, fpath)

a wrapper to model.save

Expand source code
def save_model(self, fpath):
    """
    a wrapper to model.save
    """
    self._make_model_folder(fpath)
    if U.is_crf(self.model):
        from .anago.layers import crf_loss

        self.model.compile(loss=crf_loss, optimizer=U.DEFAULT_OPT)
    self.model.save(os.path.join(fpath, U.MODEL_NAME), save_format="h5")
    return
def top_losses(self, n=4, val_data=None, preproc=None)

Computes losses on validation set sorted by examples with top losses

Args

n(int or tuple): a range to select in form of int or tuple
e.g., n=8 is treated as n=(0,8)
val_data
optional val_data to use instead of self.val_data

Returns

list of n tuples where first element is either filepath or id of validation example and second element is loss.

Expand source code
def top_losses(self, n=4, val_data=None, preproc=None):
    """
    Computes losses on validation set sorted by examples with top losses
    Args:
      n(int or tuple): a range to select in form of int or tuple
                      e.g., n=8 is treated as n=(0,8)
      val_data:  optional val_data to use instead of self.val_data
    Returns:
        list of n tuples where first element is either
        filepath or id of validation example and second element
        is loss.

    """
    val = self._check_val(val_data)
    if type(n) == type(42):
        n = (0, n)

    # get predictions and ground truth
    y_pred = self.predict(val_data=val)
    y_true = self.ground_truth(val_data=val)

    # compute losses and sort
    losses = []
    for idx, y_t in enumerate(y_true):
        y_p = y_pred[idx]
        # err = 1- sum(1 for x,y in zip(y_t,y_p) if x == y) / len(y_t)
        err = sum(1 for x, y in zip(y_t, y_p) if x != y)
        losses.append(err)
    tups = [(i, x, y_true[i], y_pred[i]) for i, x in enumerate(losses) if x > 0]
    tups.sort(key=operator.itemgetter(1), reverse=True)

    # prune by given range
    tups = tups[n[0] : n[1]] if n is not None else tups
    return tups
def validate(self, val_data=None, print_report=True, class_names=[])

Validate text sequence taggers

Expand source code
def validate(self, val_data=None, print_report=True, class_names=[]):
    """
    Validate text sequence taggers
    """
    val = self._check_val(val_data)

    if not val.prepare_called:
        val.prepare()

    if not U.is_ner(model=self.model, data=val):
        warnings.warn("learner.validate_ner is only for sequence taggers.")
        return

    label_true = []
    label_pred = []
    for i in range(len(val)):
        x_true, y_true = val[i]
        # lengths = self.ner_lengths(y_true)
        lengths = val.get_lengths(i)
        y_pred = self.model.predict_on_batch(x_true)

        y_true = val.p.inverse_transform(y_true, lengths)
        y_pred = val.p.inverse_transform(y_pred, lengths)

        label_true.extend(y_true)
        label_pred.extend(y_pred)

    score = metrics.f1_score(label_true, label_pred)
    # acc = metrics.accuracy_score(label_true, label_pred)
    if print_report:
        print("   F1:  {:04.2f}".format(score * 100))
        # print('   ACC: {:04.2f}'.format(acc * 100))
        print(metrics.classification_report(label_true, label_pred))

    return score
def view_top_losses(self, n=4, preproc=None, val_data=None)

Views observations with top losses in validation set. Args: n(int or tuple): a range to select in form of int or tuple e.g., n=8 is treated as n=(0,8) preproc (Preprocessor): A TextPreprocessor or ImagePreprocessor. For some data like text data, a preprocessor is required to undo the pre-processing to correctly view raw data. val_data: optional val_data to use instead of self.val_data

Returns

list of n tuples where first element is either filepath or id of validation example and second element is loss.

Expand source code
def view_top_losses(self, n=4, preproc=None, val_data=None):
    """
    Views observations with top losses in validation set.
    Args:
     n(int or tuple): a range to select in form of int or tuple
                      e.g., n=8 is treated as n=(0,8)
     preproc (Preprocessor): A TextPreprocessor or ImagePreprocessor.
                             For some data like text data, a preprocessor
                             is required to undo the pre-processing
                             to correctly view raw data.
      val_data:  optional val_data to use instead of self.val_data
    Returns:
        list of n tuples where first element is either
        filepath or id of validation example and second element
        is loss.

    """

    # check validation data and arguments
    val = self._check_val(val_data)

    tups = self.top_losses(n=n, val_data=val)

    # get multilabel status and class names
    classes = preproc.get_classes() if preproc is not None else None

    # iterate through losses
    for tup in tups:
        # get data
        idx = tup[0]
        loss = tup[1]
        truth = tup[2]
        pred = tup[3]

        seq = val.x[idx]
        print("total incorrect: %s" % (loss))
        print("{:15} {:5}: ({})".format("Word", "True", "Pred"))
        print("=" * 30)
        for w, true_tag, pred_tag in zip(seq, truth, pred):
            print("{:15}:{:5} ({})".format(w, true_tag, pred_tag))
        print("\n")
    return

Inherited members