# Source code for flytekitplugins.pandera.schema
# (header retained from the Sphinx "viewcode" page this file was extracted from)
import typing
from typing import Type
from flytekit import FlyteContext, lazy_module
from flytekit.extend import TypeEngine, TypeTransformer
from flytekit.models.literals import Literal, Scalar, Schema
from flytekit.models.types import LiteralType, SchemaType
from flytekit.types.schema import FlyteSchema, SchemaFormat, SchemaOpenMode
from flytekit.types.schema.types import FlyteSchemaTransformer
from flytekit.types.schema.types_pandas import PandasSchemaWriter
# Heavy third-party dependencies are loaded lazily: the module object is only
# materialized on first attribute access (see flytekit.lazy_module).
pandas = lazy_module("pandas")
pandera = lazy_module("pandera")
# Generic value type variable used by assert_type below.
T = typing.TypeVar("T")
# [docs]  -- Sphinx viewcode anchor left over from documentation extraction
class PanderaTransformer(TypeTransformer[pandera.typing.DataFrame]):
    """Flytekit type transformer for ``pandera.typing.DataFrame``.

    Serializes pandas DataFrames to Flyte ``Schema`` literals (parquet on
    disk) and validates values against the pandera schema model carried in
    the type annotation, on both the to-literal and to-python paths.
    """

    # Reuse the stock FlyteSchema column-type mapping so both transformers
    # agree on which pandas column types are representable in a SchemaType.
    _SUPPORTED_TYPES: typing.Dict[
        type, SchemaType.SchemaColumn.SchemaColumnType
    ] = FlyteSchemaTransformer._SUPPORTED_TYPES

    def __init__(self):
        super().__init__("Pandera Transformer", pandera.typing.DataFrame)  # type: ignore

    def _pandera_schema(self, t: Type[pandera.typing.DataFrame]):
        """Extract the pandera ``DataFrameSchema`` from the annotation ``t``.

        ``t`` is expected to look like ``pandera.typing.DataFrame[Model]``;
        when no schema-model type argument is present, an empty (permissive)
        schema is returned so validation becomes a no-op.
        """
        try:
            type_args = typing.get_args(t)
        except AttributeError:
            # typing.get_args is only available from python 3.8 onwards
            type_args = getattr(t, "__args__", None)
        if type_args:
            schema_model, *_ = type_args
            schema = schema_model.to_schema()
        else:
            schema = pandera.DataFrameSchema()  # type: ignore
        return schema

    @staticmethod
    def _get_pandas_type(pandera_dtype: pandera.dtypes.DataType):
        """Map a pandera dtype to its underlying numpy/pandas scalar type."""
        return pandera_dtype.type.type

    def _get_col_dtypes(self, t: Type[pandera.typing.DataFrame]):
        """Return a column-name -> pandas scalar type mapping for ``t``'s schema."""
        return {k: self._get_pandas_type(v.dtype) for k, v in self._pandera_schema(t).columns.items()}

    def _get_schema_type(self, t: Type[pandera.typing.DataFrame]) -> SchemaType:
        """Build the Flyte ``SchemaType`` describing ``t``'s columns.

        Raises:
            AssertionError: if a column's pandas type has no Flyte equivalent.
        """
        converted_cols: typing.List[SchemaType.SchemaColumn] = []
        for k, col in self._pandera_schema(t).columns.items():
            pandas_type = self._get_pandas_type(col.dtype)
            if pandas_type not in self._SUPPORTED_TYPES:
                raise AssertionError(f"type {pandas_type} is currently not supported by the flytekit-pandera plugin")
            converted_cols.append(SchemaType.SchemaColumn(name=k, type=self._SUPPORTED_TYPES[pandas_type]))
        return SchemaType(columns=converted_cols)

    def get_literal_type(self, t: Type[pandera.typing.DataFrame]) -> LiteralType:
        """Return the Flyte literal type (a schema literal) for ``t``."""
        return LiteralType(schema=self._get_schema_type(t))

    def assert_type(self, t: Type[T], v: T):
        """Check ``v`` against ``t``; parameterized generics are not checked.

        Types carrying ``__origin__`` (e.g. ``DataFrame[Model]``) cannot be
        used with ``isinstance``, so those are accepted without inspection.
        """
        if not hasattr(t, "__origin__") and not isinstance(v, (t, pandas.DataFrame)):
            raise TypeError(f"Type of Val '{v}' is not an instance of {t}")

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: pandas.DataFrame,
        python_type: Type[pandera.typing.DataFrame],
        expected: LiteralType,
    ) -> Literal:
        """Validate ``python_val`` with the pandera schema, then upload it as
        a parquet-backed Flyte ``Schema`` literal.

        Raises:
            AssertionError: if ``python_val`` is not a pandas DataFrame.
        """
        if isinstance(python_val, pandas.DataFrame):
            local_dir = ctx.file_access.get_random_local_directory()
            w = PandasSchemaWriter(
                local_dir=local_dir, cols=self._get_col_dtypes(python_type), fmt=SchemaFormat.PARQUET
            )
            # Calling the schema on the frame runs pandera validation; its
            # (possibly coerced) result is what gets written out.
            w.write(self._pandera_schema(python_type)(python_val))
            remote_path = ctx.file_access.put_raw_data(local_dir)
            return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type(python_type))))
        else:
            raise AssertionError(
                f"Only Pandas Dataframe object can be returned from a task, returned object type {type(python_val)}"
            )

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[pandera.typing.DataFrame]
    ) -> pandera.typing.DataFrame:
        """Download the schema literal ``lv`` and validate it with pandera.

        Raises:
            AssertionError: if ``lv`` is not a scalar schema literal.
        """
        if not (lv and lv.scalar and lv.scalar.schema):
            raise AssertionError("Can only convert a literal schema to a pandera schema")

        def downloader(x, y):
            # Schema data is a directory of files, hence is_multipart=True.
            ctx.file_access.get_data(x, y, is_multipart=True)

        df = FlyteSchema(
            local_path=ctx.file_access.get_random_local_directory(),
            remote_path=lv.scalar.schema.uri,
            downloader=downloader,
            supported_mode=SchemaOpenMode.READ,
        )
        # Reading .all() materializes the frame; the schema call validates it.
        return self._pandera_schema(expected_python_type)(df.open().all())
# Register at import time so the Flyte type engine picks this transformer
# whenever a task is annotated with pandera.typing.DataFrame.
TypeEngine.register(PanderaTransformer())