Source code for dds_cloudapi_sdk.tasks.trex_generic

"""
The TRex Generic Inference algorithm lets users provide prompts on multiple images and get the boxes and scores on one target image.

This algorithm assumes that there is only one category per batch image, and it does not support batch inference.

"""

from typing import List
from typing import Union

import pydantic

from dds_cloudapi_sdk.tasks.base import BaseTask
from dds_cloudapi_sdk.tasks.prompt import BatchPointPrompt
from dds_cloudapi_sdk.tasks.prompt import BatchRectPrompt
from dds_cloudapi_sdk.tasks.prompt import PromptType


class TRexObject(pydantic.BaseModel):
    """
    A single object detected by the TRexGenericInfer task.

    :param score: the prediction score
    :param bbox: the bounding box, [upper_left_x, upper_left_y, lower_right_x, lower_right_y]
    """

    #: the prediction score
    score: float
    #: the bounding box, [upper_left_x, upper_left_y, lower_right_x, lower_right_y]
    bbox: List[float]
class TaskResult(pydantic.BaseModel):
    """
    The result returned by the TRexGenericInfer task.

    :param objects: a list of detected objects of :class:`TRexObject <dds_cloudapi_sdk.tasks.trex_generic.TRexObject>`
    """

    #: a list of detected objects of :class:`TRexObject <dds_cloudapi_sdk.tasks.trex_generic.TRexObject>`
    objects: List[TRexObject]
class TRexGenericInfer(BaseTask):
    """
    Trigger the TRex Generic Inference algorithm.

    A task may carry prompts taken from multiple images, and each image may
    contribute several prompts. All prompts of one task are expected to be of
    a single kind, either point or rect.

    :param image_url: the image to be inferred on.
    :param batch_prompts: list of :class:`BatchRectPrompt <dds_cloudapi_sdk.tasks.prompt.BatchRectPrompt>` objects or :class:`BatchPointPrompt <dds_cloudapi_sdk.tasks.prompt.BatchPointPrompt>`.
    """

    def __init__(
        self,
        image_url: str,
        batch_prompts: Union[List[BatchRectPrompt], List[BatchPointPrompt]],
    ):
        super().__init__()
        self.image_url = image_url
        self.batch_prompts = batch_prompts

    @property
    def api_path(self):
        """Return the API path of this task: ``trex_generic_infer``."""
        return "trex_generic_infer"

    @property
    def api_body(self):
        """
        Build the request body: the target image plus the serialized prompts.

        The prompt ``type`` is derived from the class of the first prompt only.
        """
        serialized = [item.dict() for item in self.batch_prompts]
        prompt_section = {"prompts": serialized}

        # NOTE(review): an empty ``batch_prompts`` raises IndexError here, and a
        # first prompt that is neither a point nor a rect prompt leaves "type"
        # unset. Mixed prompt kinds are not checked either.
        first = self.batch_prompts[0]
        if isinstance(first, BatchPointPrompt):
            prompt_section["type"] = PromptType.Point.value
        elif isinstance(first, BatchRectPrompt):
            prompt_section["type"] = PromptType.Rect.value

        return {
            "image": self.image_url,
            "batch_prompts": prompt_section,
        }

    @property
    def result(self) -> TaskResult:
        """
        Get the formatted :class:`TaskResult <dds_cloudapi_sdk.tasks.trex_generic.TaskResult>` object.

        The value is read from ``self._result``, which is not assigned anywhere
        in this class.
        """
        return self._result

    def format_result(self, result: dict) -> TaskResult:
        """
        Convert a raw result dict into a :class:`TaskResult <dds_cloudapi_sdk.tasks.trex_generic.TaskResult>`.

        :param result: mapping whose keys are passed as keyword arguments to ``TaskResult``.
        """
        return TaskResult(**result)
def test():
    """
    Run one TRexGenericInfer task against the cloud API and print the first
    detected object.

    Invoke with ``python -m dds_cloudapi_sdk.tasks.trex_generic``. The API
    token is read from the ``DDS_CLOUDAPI_TEST_TOKEN`` environment variable.
    """
    import logging
    import os

    test_token = os.environ["DDS_CLOUDAPI_TEST_TOKEN"]

    logging.basicConfig(level=logging.INFO)

    from dds_cloudapi_sdk import Client
    from dds_cloudapi_sdk import Config

    client = Client(Config(test_token))

    # The same image serves as both the prompt image and the inference target.
    sample_image = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/ivp/04_b.jpg"
    rect_prompt = BatchRectPrompt(
        image=sample_image,
        rects=[[475.18413597733706, 550.1983002832861, 548.1019830028329, 599.915014164306]],
    )
    task = TRexGenericInfer(
        image_url=sample_image,
        batch_prompts=[rect_prompt],
    )

    client.run_task(task)

    # Only the first detected object (if any) is printed.
    for detected in task.result.objects[:1]:
        print(detected)


if __name__ == "__main__":
    test()