UDF¶
User-defined functions (UDFs) process the rows of a chain in batches to generate new values for the chain. A UDF takes fields from one or more rows of the data and outputs new fields. UDFs can run at scale across multiple workers and processes.
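As a rough mental model (not DataChain's actual executor), a UDF that reads one row and adds one field behaves like the pure-Python loop below. `run_map` and `name_length` are hypothetical names; the `params` and `output` arguments only loosely mirror the ones passed to `DataChain.map()` in the example further down.

```python
from typing import Any, Callable, Iterable


def run_map(
    rows: Iterable[dict[str, Any]],
    func: Callable[..., Any],
    params: list[str],
    output: str,
) -> list[dict[str, Any]]:
    """Hypothetical driver: call `func` once per row and store the result as a new field."""
    result = []
    for row in rows:
        args = [row[name] for name in params]  # pick input fields in `params` order
        new_row = dict(row)                    # keep the original fields
        new_row[output] = func(*args)          # add the generated field
        result.append(new_row)
    return result


def name_length(name: str) -> int:
    """A plain function is already a valid (stateless) UDF."""
    return len(name)


rows = [{"name": "a.txt"}, {"name": "photo.jpeg"}]
print(run_map(rows, name_length, params=["name"], output="name_len"))
# [{'name': 'a.txt', 'name_len': 5}, {'name': 'photo.jpeg', 'name_len': 10}]
```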
A UDF can be any Python function. The classes below are useful for implementing a "stateful" UDF where a plain function is insufficient, for example when additional setup() or teardown() steps must run before or after the processing function.
UDFBase
Bases: AbstractUDF
Base class for stateful user-defined functions.
Any class that inherits from it must have a process() method that takes input params from one or more rows in the chain and produces the expected output.
Optionally, the class may include these methods:
- setup() to run code on each worker before process() is called.
- teardown() to run code on each worker after process() completes.
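The hook order described above can be sketched in plain Python. `LookupUDF` and `run_on_worker` are hypothetical stand-ins, not DataChain's worker code. The try/finally that calls teardown() even when process() raises is a choice made in this sketch, not documented library behavior.

```python
from typing import Any, Iterable, Optional


class LookupUDF:
    """Hypothetical stand-in for a UDFBase subclass that records hook order."""

    def __init__(self) -> None:
        self.calls: list[str] = []
        self.table: Optional[dict[str, int]] = None

    def setup(self) -> None:
        # One-time, per-worker work (e.g. loading a model) belongs here.
        self.calls.append("setup")
        self.table = {"cat": 0, "dog": 1}

    def process(self, label: str) -> int:
        # Called once per input row.
        self.calls.append("process")
        assert self.table is not None, "setup() must run first"
        return self.table[label]

    def teardown(self) -> None:
        # Release per-worker resources.
        self.calls.append("teardown")
        self.table = None


def run_on_worker(udf: LookupUDF, values: Iterable[Any]) -> list[Any]:
    """Hypothetical worker loop: setup once, process every row, then teardown."""
    udf.setup()
    try:
        return [udf.process(v) for v in values]
    finally:
        udf.teardown()  # in this sketch, runs even if process() raises


udf = LookupUDF()
print(run_on_worker(udf, ["dog", "cat", "dog"]))  # [1, 0, 1]
print(udf.calls)  # ['setup', 'process', 'process', 'process', 'teardown']
```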
Example
from datachain import DataChain, Mapper
import open_clip

class ImageEncoder(Mapper):
    def __init__(self, model_name: str, pretrained: str):
        # Store only lightweight configuration; the model is built in setup().
        self.model_name = model_name
        self.pretrained = pretrained

    def setup(self):
        # Runs once per worker: load the CLIP model and its preprocessing transform.
        self.model, _, self.preprocess = (
            open_clip.create_model_and_transforms(
                self.model_name, self.pretrained
            )
        )

    def process(self, file) -> list[float]:
        # Runs once per row: preprocess the image and return its embedding.
        img = file.get_value()
        img = self.preprocess(img).unsqueeze(0)
        emb = self.model.encode_image(img)
        return emb[0].tolist()

(
    DataChain.from_storage(
        "gs://datachain-demo/fashion-product-images/images", type="image"
    )
    .limit(5)
    .map(
        ImageEncoder("ViT-B-32", "laion2b_s34b_b79k"),
        params=["file"],
        output={"emb": list[float]},
    )
    .show()
)
Source code in datachain/lib/udf.py
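One plausible reason the example keeps __init__() light and builds the model in setup(): the UDF object may be serialized (for instance with pickle) and shipped to worker processes, which is an assumption here, and heavy or unpicklable resources would break that. The sketch below uses hypothetical `FakeModel` and `Encoder` classes, with a thread lock standing in for an unpicklable model.

```python
import pickle
import threading
from typing import Optional


class FakeModel:
    """Hypothetical heavy resource; it holds a lock, which pickle cannot serialize."""

    def __init__(self) -> None:
        self._lock = threading.Lock()

    def encode(self, text: str) -> list[float]:
        with self._lock:
            return [float(len(text))]


class Encoder:
    def __init__(self, model_name: str) -> None:
        self.model_name = model_name            # cheap, picklable config only
        self.model: Optional[FakeModel] = None  # heavy state is created later

    def setup(self) -> None:
        self.model = FakeModel()                # built on the worker, after transfer

    def process(self, text: str) -> list[float]:
        assert self.model is not None, "setup() must run first"
        return self.model.encode(text)


enc = Encoder("tiny")
worker_copy = pickle.loads(pickle.dumps(enc))  # works: only config is stored
worker_copy.setup()
print(worker_copy.process("hello"))  # [5.0]

enc.setup()
try:
    pickle.dumps(enc)  # fails once the lock-holding model exists
except TypeError as exc:
    print("not picklable:", exc)
```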
process
Processing function that must be defined by the user.
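A single process() call can also fill several output fields. The sketch below assumes that, when more than one output field is declared, process() returns a tuple whose items are matched to the output names in order; `SplitName` and `apply` are hypothetical names used only for illustration.

```python
from typing import Any


class SplitName:
    """Hypothetical UDF: one input field in, two output fields out."""

    def process(self, name: str) -> tuple[str, str]:
        stem, dot, ext = name.rpartition(".")
        if not dot:  # no extension present
            return name, ""
        return stem, ext


def apply(udf: SplitName, row: dict[str, Any], params: list[str], output: list[str]) -> dict[str, Any]:
    """Hypothetical helper: map the returned tuple onto output names, in order."""
    values = udf.process(*(row[p] for p in params))
    return {**row, **dict(zip(output, values))}


print(apply(SplitName(), {"name": "cat.jpg"}, ["name"], ["stem", "ext"]))
# {'name': 'cat.jpg', 'stem': 'cat', 'ext': 'jpg'}
```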
setup
Initialization step executed on each worker before processing begins. Use it for tasks such as pre-loading ML models before scoring.
teardown
Cleanup step executed on each process or worker after processing ends. Use it for tasks such as closing connections to endpoints.
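A minimal sketch of pairing the two hooks, assuming a per-worker connection: `DbLookup` is a hypothetical class, and an in-memory SQLite database stands in for a remote endpoint. The connection is opened once in setup(), reused by every process() call, and closed in teardown().

```python
import sqlite3
from typing import Optional


class DbLookup:
    """Hypothetical UDF that opens a connection in setup() and closes it in teardown()."""

    def __init__(self) -> None:
        self.conn: Optional[sqlite3.Connection] = None

    def setup(self) -> None:
        self.conn = sqlite3.connect(":memory:")  # stands in for a remote endpoint
        self.conn.execute("CREATE TABLE price (item TEXT PRIMARY KEY, cents INTEGER)")
        self.conn.executemany("INSERT INTO price VALUES (?, ?)", [("apple", 50), ("pear", 75)])

    def process(self, item: str) -> int:
        assert self.conn is not None, "setup() must run first"
        row = self.conn.execute("SELECT cents FROM price WHERE item = ?", (item,)).fetchone()
        return row[0] if row else -1  # -1 marks an unknown item in this sketch

    def teardown(self) -> None:
        if self.conn is not None:
            self.conn.close()  # release the connection once the worker is done
            self.conn = None


udf = DbLookup()
udf.setup()
print([udf.process(i) for i in ["pear", "apple", "kiwi"]])  # [75, 50, -1]
udf.teardown()
```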
Aggregator
Bases: UDFBase
Inherit from this class to pass to DataChain.agg().
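An aggregator consumes many rows and emits summary rows. The sketch below assumes that process() receives, for each group of rows, one list of values per parameter and yields output rows; `TotalSize` and `run_agg` are hypothetical, and the grouping key plays the role of a partition column.

```python
from collections import defaultdict
from typing import Any, Iterator


class TotalSize:
    """Hypothetical aggregator: one list of sizes per group in, one summary row out."""

    def process(self, sizes: list[int]) -> Iterator[tuple[int, int]]:
        yield len(sizes), sum(sizes)


def run_agg(
    rows: list[dict[str, Any]], udf: TotalSize, param: str, partition_by: str
) -> dict[Any, list[tuple[int, int]]]:
    """Hypothetical driver: group rows, then pass each group's values as one list."""
    groups: dict[Any, list[Any]] = defaultdict(list)
    for row in rows:
        groups[row[partition_by]].append(row[param])
    return {key: list(udf.process(values)) for key, values in groups.items()}


rows = [
    {"dir": "cats", "size": 10},
    {"dir": "dogs", "size": 7},
    {"dir": "cats", "size": 5},
]
print(run_agg(rows, TotalSize(), param="size", partition_by="dir"))
# {'cats': [(2, 15)], 'dogs': [(1, 7)]}
```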
BatchMapper
Bases: UDFBase
Inherit from this class to pass to DataChain.batch_map().
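Batching lets one expensive call (for example, a single model forward pass) serve many rows. The sketch below assumes process() receives a list of input values and produces exactly one output per input; `UpperCaser` and `run_batches` are hypothetical, and the `batch` argument of the driver is illustrative.

```python
from typing import Iterator


class UpperCaser:
    """Hypothetical batch UDF: a list of values in, one output per value out."""

    def __init__(self) -> None:
        self.calls = 0  # counts how many times the "expensive" backend is hit

    def process(self, texts: list[str]) -> Iterator[str]:
        self.calls += 1  # one call per batch, e.g. a single forward pass
        for text in texts:
            yield text.upper()


def run_batches(values: list[str], udf: UpperCaser, batch: int) -> list[str]:
    """Hypothetical driver: slice inputs into batches of at most `batch` rows."""
    out: list[str] = []
    for start in range(0, len(values), batch):
        chunk = values[start:start + batch]
        results = list(udf.process(chunk))
        assert len(results) == len(chunk), "expected one output per input row"
        out.extend(results)
    return out


udf = UpperCaser()
print(run_batches(["a", "b", "c", "d", "e"], udf, batch=2))  # ['A', 'B', 'C', 'D', 'E']
print(udf.calls)  # 3
```

Because each output depends only on its own input, the results are the same for any batch size; only the number of backend calls changes.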
Generator
Bases: UDFBase
Inherit from this class to pass to DataChain.gen().
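A generator-style UDF can turn one input row into zero, one, or many output rows. The sketch below assumes process() yields one item per output row; `SplitLines` and `run_gen` are hypothetical names, with splitting text into non-empty lines as the example task.

```python
from typing import Any, Iterator


class SplitLines:
    """Hypothetical generator UDF: one text row in, one output row per non-empty line."""

    def process(self, text: str) -> Iterator[tuple[int, str]]:
        for number, line in enumerate(text.splitlines(), start=1):
            if line.strip():  # skip blank lines
                yield number, line


def run_gen(rows: list[dict[str, Any]], udf: SplitLines, param: str) -> list[tuple[int, str]]:
    """Hypothetical driver: flatten every row's yielded outputs into one list."""
    return [out for row in rows for out in udf.process(row[param])]


docs = [{"text": "alpha\n\nbeta"}, {"text": ""}, {"text": "gamma"}]
print(run_gen(docs, SplitLines(), param="text"))
# [(1, 'alpha'), (3, 'beta'), (1, 'gamma')]
```

Here the empty document yields no rows at all, while the first document yields two.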