Module ktrain.vision.object_detection.core

Expand source code
from transformers import pipeline

from ... import imports as I
from ...torch_base import TorchBase


class ObjectDetector(TorchBase):
    """
    interface to Image Captioner
    """

    def __init__(self, device=None, classification=False, threshold=0.9):
        """
        ```
        Object detection constructor

        Args:
          device(str): device to use (e.g., 'cuda', 'cpu')
          threshold(float):  threshold for object detection
          classification(bool): If True, simpy do image classification
        ```
        """
        if not I.PIL_INSTALLED:
            raise Exception(
                "PIL is not installed. Please install with: pip install pillow>=9.0.1"
            )

        super().__init__(
            device=device, quantize=False, min_transformers_version="4.12.3"
        )
        self.pipeline = pipeline(
            "image-classification" if classification else "object-detection",
            device=self.device_to_id(),
        )
        self.threshold = threshold
        self.classification = classification

    def detect(self, images, flatten=False, workers=0):
        """
        ```
        Performs object detection

        This method supports a single image or a list of images. If the input is an image, the return
        type is a string. If text is a list, a list of strings is returned
        Args:
            images: image|list
            flatten: flatten output to a list of objects
            workers: number of concurrent workers to use for processing data, defaults to None
        Returns:
            list of (label, score)

        ```
        """
        # Convert single element to list
        values = [images] if not isinstance(images, list) else images

        # Open images if file strings
        values = [
            I.Image.open(image) if isinstance(image, str) else image for image in values
        ]

        # Run pipeline
        results = (
            self.pipeline(values, num_workers=workers)
            if self.classification
            else self.pipeline(values, threshold=self.threshold, num_workers=workers)
        )

        # Build list of (id, score)
        outputs = []
        for result in results:
            # Convert to (label, score) tuples
            result = [
                (x["label"], x["score"]) for x in result if x["score"] > self.threshold
            ]

            # Sort by score descending
            result = sorted(result, key=lambda x: x[1], reverse=True)

            # Deduplicate labels
            unique = set()
            elements = []
            for label, score in result:
                if label not in unique:
                    elements.append(label if flatten else (label, score))
                    unique.add(label)

            outputs.append(elements)

        # Return single element if single element passed in
        return outputs[0] if not isinstance(images, list) else outputs

Classes

class ObjectDetector (device=None, classification=False, threshold=0.9)

interface to Image Captioner

Object detection constructor

Args:
  device(str): device to use (e.g., 'cuda', 'cpu')
  threshold(float):  threshold for object detection
  classification(bool): If True, simpy do image classification
Expand source code
class ObjectDetector(TorchBase):
    """
    interface to Image Captioner
    """

    def __init__(self, device=None, classification=False, threshold=0.9):
        """
        ```
        Object detection constructor

        Args:
          device(str): device to use (e.g., 'cuda', 'cpu')
          threshold(float):  threshold for object detection
          classification(bool): If True, simpy do image classification
        ```
        """
        if not I.PIL_INSTALLED:
            raise Exception(
                "PIL is not installed. Please install with: pip install pillow>=9.0.1"
            )

        super().__init__(
            device=device, quantize=False, min_transformers_version="4.12.3"
        )
        self.pipeline = pipeline(
            "image-classification" if classification else "object-detection",
            device=self.device_to_id(),
        )
        self.threshold = threshold
        self.classification = classification

    def detect(self, images, flatten=False, workers=0):
        """
        ```
        Performs object detection

        This method supports a single image or a list of images. If the input is an image, the return
        type is a string. If text is a list, a list of strings is returned
        Args:
            images: image|list
            flatten: flatten output to a list of objects
            workers: number of concurrent workers to use for processing data, defaults to None
        Returns:
            list of (label, score)

        ```
        """
        # Convert single element to list
        values = [images] if not isinstance(images, list) else images

        # Open images if file strings
        values = [
            I.Image.open(image) if isinstance(image, str) else image for image in values
        ]

        # Run pipeline
        results = (
            self.pipeline(values, num_workers=workers)
            if self.classification
            else self.pipeline(values, threshold=self.threshold, num_workers=workers)
        )

        # Build list of (id, score)
        outputs = []
        for result in results:
            # Convert to (label, score) tuples
            result = [
                (x["label"], x["score"]) for x in result if x["score"] > self.threshold
            ]

            # Sort by score descending
            result = sorted(result, key=lambda x: x[1], reverse=True)

            # Deduplicate labels
            unique = set()
            elements = []
            for label, score in result:
                if label not in unique:
                    elements.append(label if flatten else (label, score))
                    unique.add(label)

            outputs.append(elements)

        # Return single element if single element passed in
        return outputs[0] if not isinstance(images, list) else outputs

Ancestors

Methods

def detect(self, images, flatten=False, workers=0)
Performs object detection

This method supports a single image or a list of images. If the input is an image, the return
type is a string. If text is a list, a list of strings is returned
Args:
    images: image|list
    flatten: flatten output to a list of objects
    workers: number of concurrent workers to use for processing data, defaults to None
Returns:
    list of (label, score)

Expand source code
def detect(self, images, flatten=False, workers=0):
    """
    ```
    Performs object detection

    This method supports a single image or a list of images. If the input is an image, the return
    type is a string. If text is a list, a list of strings is returned
    Args:
        images: image|list
        flatten: flatten output to a list of objects
        workers: number of concurrent workers to use for processing data, defaults to None
    Returns:
        list of (label, score)

    ```
    """
    # Convert single element to list
    values = [images] if not isinstance(images, list) else images

    # Open images if file strings
    values = [
        I.Image.open(image) if isinstance(image, str) else image for image in values
    ]

    # Run pipeline
    results = (
        self.pipeline(values, num_workers=workers)
        if self.classification
        else self.pipeline(values, threshold=self.threshold, num_workers=workers)
    )

    # Build list of (id, score)
    outputs = []
    for result in results:
        # Convert to (label, score) tuples
        result = [
            (x["label"], x["score"]) for x in result if x["score"] > self.threshold
        ]

        # Sort by score descending
        result = sorted(result, key=lambda x: x[1], reverse=True)

        # Deduplicate labels
        unique = set()
        elements = []
        for label, score in result:
            if label not in unique:
                elements.append(label if flatten else (label, score))
                unique.add(label)

        outputs.append(elements)

    # Return single element if single element passed in
    return outputs[0] if not isinstance(images, list) else outputs

Inherited members