Source code for dds_cloudapi_sdk.tasks.gsam

"""
The Grounded SAM algorithm integrates the capabilities of both GroundingDINO and SAM algorithms, utilizing text prompts to
efficiently detect bounding boxes and segmentation masks.

This powerful algorithm is available in DDS CloudAPI SDK in two variants:

- the "tiny" model through TinyGSAMTask
- the "base" model through BaseGSAMTask
"""

import enum
from typing import List

import pydantic

from dds_cloudapi_sdk.tasks.base import BaseTask
from dds_cloudapi_sdk.tasks.prompt import TextPrompt


class ModelType(enum.Enum):
    """
    | This enumerator represents the models the GSAM algorithm uses.
    | The :class:`TinyGSAMTask <dds_cloudapi_sdk.tasks.gsam.TinyGSAMTask>` will use the Tiny model by default.
    | And the :class:`BaseGSAMTask <dds_cloudapi_sdk.tasks.gsam.BaseGSAMTask>` will use the Base model by default.
    """

    Tiny = "swint"  #: The tiny model
    Base = "swinb"  #: The base model


[docs] class GSAMObject(pydantic.BaseModel): """ The object detected by GSAM tasks. :param category: the category name of the object :param score: the predict score of the object :param bbox: the bbox of the object, [xmin, ymin, xmax, ymax] """ category: str #: the category name of the object score: float #: the predict score of the object bbox: List[float] #: the bbox of the object, [xmin, ymin, xmax, ymax]
[docs] class TaskResult(pydantic.BaseModel): """ The task result of the GSAM tasks. :param mask_url: an image url with all objects' mask drawn on :param objects: a list of detected objects of :class:`GSAMObject <dds_cloudapi_sdk.tasks.gsam.GSAMObject>` """ mask_url: str #: an image url with all objects' mask drawn on objects: List[GSAMObject] #: a list of detected objects of :class:`GSAMObject <dds_cloudapi_sdk.tasks.gsam.GSAMObject>`
class _GroundedSAMTask(BaseTask): def __init__(self, image_url: str, model_type: ModelType, prompts: List[TextPrompt] = None, box_threshold: float = 0.3, text_threshold: float = 0.2, nms_threshold: float = 0.8 ): """ Construct a task calls grounded_sam algorithm. """ for p in prompts: if not p.is_positive: raise ValueError(f"Only positive text prompt is permitted for GSAM task.") super().__init__() self.image_url = image_url self.model_type = model_type self.prompts = prompts self.box_threshold = box_threshold self.text_threshold = text_threshold self.nms_threshold = nms_threshold @property def api_path(self): return "grounded_sam" @property def api_body(self): data = { "image" : self.image_url, "model_type" : self.model_type.value, "prompts" : [prompt.dict() for prompt in self.prompts], "box_threshold" : self.box_threshold, "text_threshold": self.text_threshold, "nms_threshold" : self.nms_threshold } return data @property def result(self) -> TaskResult: """ Get the formatted :class:`TaskResult <dds_cloudapi_sdk.tasks.gsam.TaskResult>` object. """ return self._result def format_result(self, result: dict) -> TaskResult: """ Format the result of the task, return an instance of :class:`GSAM TaskResult <dds_cloudapi_sdk.tasks.gsam.TaskResult>`. An example of input result data:: { 'mask_url': 'https://image.png', 'objects': [ {'category': 'iron man', 'score': 0.4880097210407257, 'bbox': [653.0848388671875, 329.127685546875, 942.047119140625, 842.4909057617188]}, {'category': 'iron man', 'score': 0.3169572949409485, 'bbox': [481.4720153808594, 0.6030968427658081, 878.6416625976562, 650.6054077148438]} ] } :param result: the raw python dict returned by the API. """ return TaskResult(**result)
[docs] class TinyGSAMTask(_GroundedSAMTask): """ Trigger the Grounded-SegmentAnything algorithm with the **tiny** :class:`ModelType <dds_cloudapi_sdk.tasks.gsam.ModelType>`. :param image_url: the segmenting image url. :param prompts: a list of :class:`TextPrompt <dds_cloudapi_sdk.tasks.prompt.TextPrompt>` object. But for This task, only positive prompts are permitted. :param box_threshold: a threshold to filter out objects by bbox score, default to 0.3. :param text_threshold: a threshold to filter out objects by text score, default to 0.2. :param nms_threshold: a threshold for nms to filter out overlapping boxes, default to 0.8. """ def __init__(self, image_url: str, prompts: List[TextPrompt] = None, box_threshold: float = 0.3, text_threshold: float = 0.2, nms_threshold: float = 0.8 ): super().__init__(image_url, ModelType.Tiny, prompts, box_threshold, text_threshold, nms_threshold)
[docs] class BaseGSAMTask(_GroundedSAMTask): """ Trigger the Grounded-SegmentAnything algorithm with the **base** :class:`ModelType <dds_cloudapi_sdk.tasks.gsam.ModelType>`. :param image_url: the segmenting image url. :param prompts: a list of :class:`TextPrompt <dds_cloudapi_sdk.tasks.prompt.TextPrompt>` object. But for This task, only positive prompts are permitted. :param box_threshold: a threshold to filter out objects by bbox score, default to 0.3. :param text_threshold: a threshold to filter out objects by text score, default to 0.2. :param nms_threshold: a threshold for nms to filter out overlapping boxes, default to 0.8. """ def __init__(self, image_url: str, prompts: List[TextPrompt] = None, box_threshold: float = 0.3, text_threshold: float = 0.2, nms_threshold: float = 0.8 ): """ Construct a task calls grounded_sam algorithm with base model. """ super().__init__(image_url, ModelType.Base, prompts, box_threshold, text_threshold, nms_threshold)
def test(): """ python -m dds_cloudapi_sdk.tasks.gsam """ import os test_token = os.environ["DDS_CLOUDAPI_TEST_TOKEN"] import logging logging.basicConfig(level=logging.INFO) from dds_cloudapi_sdk import Config from dds_cloudapi_sdk import Client from dds_cloudapi_sdk import TextPrompt config = Config(test_token) client = Client(config) task = TinyGSAMTask( "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/grounded_sam/iron_man.jpg", [TextPrompt(text="iron man")] ) client.run_task(task) print(task.result) task = BaseGSAMTask( "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/grounded_sam/iron_man.jpg", [TextPrompt(text="iron man")] ) client.run_task(task) print(task.result) if __name__ == "__main__": test()