Source code for dds_cloudapi_sdk.tasks.trex_embd_infer

"""
The TRex Embedding Inference algorithm enables users to run inference on an image and get the boxes and scores on that same image, using the embd files
they trained with the :class:`Embedding Customization <dds_cloudapi_sdk.tasks.trex_embd_customize.TRexEmbdCustomize>` task.

This algorithm supports batch inference for multiple images, and for every image, multiple embeddings are supported.
"""

from typing import List

import pydantic

from dds_cloudapi_sdk.tasks.base import BaseTask
from dds_cloudapi_sdk.tasks.prompt import BatchEmbdInfer
from dds_cloudapi_sdk.tasks.prompt import BatchEmbdPrompt
from dds_cloudapi_sdk.tasks.prompt import BatchRectPrompt
from dds_cloudapi_sdk.tasks.prompt import PromptType
from dds_cloudapi_sdk.tasks.trex_embd_customize import TRexEmbdCustomize


class TRexObject(pydantic.BaseModel):
    """
    A single detection produced by the TRexEmbdInfer task.

    :param score: the prediction score
    :param bbox: the bounding box, [upper_left_x, upper_left_y, lower_right_x, lower_right_y]
    :param category_id: the category id of the object
    """

    #: the prediction score
    score: float

    #: the bounding box, [upper_left_x, upper_left_y, lower_right_x, lower_right_y]
    bbox: List[float]

    #: the category id of the object
    category_id: int
class TaskResult(pydantic.BaseModel):
    """
    The task result of the TRexEmbdInfer task.

    :param object_batches: a 2D list of detected objects of
        :class:`TRexObject <dds_cloudapi_sdk.tasks.trex_embd_infer.TRexObject>`,
        each inner list is the detected objects of one image
    """

    #: a 2D list of detected objects of
    #: :class:`TRexObject <dds_cloudapi_sdk.tasks.trex_embd_infer.TRexObject>`,
    #: each inner list is the detected objects of one image
    object_batches: List[List[TRexObject]]
class TRexEmbdInfer(BaseTask):
    """
    Trigger the TRex Embedding Inference algorithm.

    This task can process prompts from multiple images, and each image can have several embedding prompts.

    :param batch_infers: list of :class:`BatchEmbdInfer <dds_cloudapi_sdk.tasks.prompt.BatchEmbdInfer>` objects,
        one per image to run inference on.
    """

    def __init__(self,
                 batch_infers: List[BatchEmbdInfer],
                 ):
        super().__init__()
        self.batch_infers = batch_infers

    @property
    def api_path(self):
        """The API path segment used for this task."""
        return "trex_embd_infer"

    @property
    def api_body(self):
        """
        Build the request body for this task.

        Each entry of ``batch_infers`` is serialized with ``.dict()`` and tagged with
        ``prompt_type`` set to ``PromptType.Embd.value``.

        :return: a dict of the form ``{"batch_infers": [...]}``
        """
        # The "prompt_type" key is written last so it overrides any value
        # already present in the serialized infer data.
        batch_infers = [
            {**infer.dict(), "prompt_type": PromptType.Embd.value}
            for infer in self.batch_infers
        ]
        return {"batch_infers": batch_infers}

    @property
    def result(self) -> TaskResult:
        """
        Get the formatted :class:`TaskResult <dds_cloudapi_sdk.tasks.trex_embd_infer.TaskResult>` object.
        """
        # NOTE(review): ``_result`` is not assigned in this class; presumably the
        # base task stores the output of ``format_result`` there -- confirm in BaseTask.
        return self._result

    def format_result(self, result: dict) -> TaskResult:
        """
        Convert the raw result dict into a :class:`TaskResult <dds_cloudapi_sdk.tasks.trex_embd_infer.TaskResult>`.

        :param result: the raw result, whose keys are passed as keyword arguments to ``TaskResult``
        :return: the constructed ``TaskResult``
        """
        return TaskResult(**result)
def test():
    """
    Manual smoke test: customize an embedding from one rect prompt, then run
    embedding inference on the same image and print the first detected object.

    Reads the API token from the ``DDS_CLOUDAPI_TEST_TOKEN`` environment variable.

    python -m dds_cloudapi_sdk.tasks.trex_embd_infer
    """
    import os
    test_token = os.environ["DDS_CLOUDAPI_TEST_TOKEN"]

    import logging
    logging.basicConfig(level=logging.INFO)

    from dds_cloudapi_sdk import Config
    from dds_cloudapi_sdk import Client

    client = Client(Config(test_token))

    image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/ivp/04_b.jpg"

    # Step 1: train an embedding from a single rectangle prompt on the image.
    rect_prompt = BatchRectPrompt(
        image=image_url,
        rects=[[475.18413597733706, 550.1983002832861, 548.1019830028329, 599.915014164306]],
    )
    customize_task = TRexEmbdCustomize(batch_prompts=[rect_prompt])
    client.run_task(customize_task)

    embd_url = customize_task.result.embd
    print(embd_url)

    # Step 2: run inference on the same image with the freshly trained embedding.
    embd_prompt = BatchEmbdPrompt(embd=embd_url, category_id=1)
    infer_1 = BatchEmbdInfer(image=image_url, prompts=[embd_prompt])
    infer_task = TRexEmbdInfer([infer_1])
    client.run_task(infer_task)

    # Only the first object of the first image is printed (if there is one).
    for image_objects in infer_task.result.object_batches[:1]:
        for obj in image_objects[:1]:
            print(obj.score)
            print(obj.bbox)
            print(obj.category_id)


if __name__ == "__main__":
    test()