Source code for flytekitplugins.awssagemaker_inference.agent

import json
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, Optional

import cloudpickle

from flytekit.extend.backend.base_agent import (
    AgentRegistry,
    AsyncAgentBase,
    Resource,
    ResourceMeta,
)
from flytekit.extend.backend.utils import convert_to_flyte_phase
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

from .boto3_mixin import Boto3AgentMixin


@dataclass
class SageMakerEndpointMetadata(ResourceMeta):
    config: Dict[str, Any]
    region: Optional[str] = None
    inputs: Optional[LiteralMap] = None

    def encode(self) -> bytes:
        return cloudpickle.dumps(self)

    @classmethod
    def decode(cls, data: bytes) -> "SageMakerEndpointMetadata":
        return cloudpickle.loads(data)


states = {
    "Creating": "Running",
    "InService": "Success",
    "Failed": "Failed",
}


class DateTimeEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, datetime):
            return o.isoformat()

        return json.JSONEncoder.default(self, o)


[docs] class SageMakerEndpointAgent(Boto3AgentMixin, AsyncAgentBase): """This agent creates an endpoint.""" name = "SageMaker Endpoint Agent" def __init__(self): super().__init__( service="sagemaker", task_type_name="sagemaker-endpoint", metadata_type=SageMakerEndpointMetadata, )
[docs] async def create( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs ) -> SageMakerEndpointMetadata: custom = task_template.custom config = custom.get("config") region = custom.get("region") await self._call( method="create_endpoint", config=config, inputs=inputs, region=region, ) return SageMakerEndpointMetadata(config=config, region=region, inputs=inputs)
[docs] async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resource: endpoint_status = await self._call( method="describe_endpoint", config={"EndpointName": resource_meta.config.get("EndpointName")}, inputs=resource_meta.inputs, region=resource_meta.region, ) current_state = endpoint_status.get("EndpointStatus") flyte_phase = convert_to_flyte_phase(states[current_state]) message = None if current_state == "Failed": message = endpoint_status.get("FailureReason") res = None if current_state == "InService": res = {"result": json.dumps(endpoint_status, cls=DateTimeEncoder)} return Resource(phase=flyte_phase, outputs=res, message=message)
[docs] async def delete(self, resource_meta: SageMakerEndpointMetadata, **kwargs): await self._call( "delete_endpoint", config={"EndpointName": resource_meta.config.get("EndpointName")}, region=resource_meta.region, inputs=resource_meta.inputs, )
AgentRegistry.register(SageMakerEndpointAgent())