
Module: base_inference_workflow

Module implementing a base inference workflow.

This class is not meant to be subclassed directly; instead, subclass one of [TGIClientInferenceWorkflow, CSSInferenceWorkflow, BaseClassicInferenceWorkflow]

BaseInferenceWorkflow

Base class for an inference workflow

Source code in src/infernet_ml/workflows/inference/base_inference_workflow.py
import abc
import logging
from typing import Any, Iterator, final


class BaseInferenceWorkflow(metaclass=abc.ABCMeta):
    """
    Base class for an inference workflow
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Constructor. Keeps track of the arguments passed in.
        """
        super().__init__()
        self.args: list[Any] = list(args)
        self.kwargs: dict[Any, Any] = kwargs

        self.is_setup = False
        self.__inference_count: int = 0
        self.__proof_count: int = 0

    def setup(self) -> Any:
        """
        Marks the workflow as set up and calls do_setup.
        """
        self.is_setup = True
        return self.do_setup()

    @abc.abstractmethod
    def do_setup(self) -> Any:
        """set up your workflow here.
        For LLMs, this may be parameters like top_k, temperature
        for classical LLM, this may be model hyperparams.

        Returns: Any
        """

    def stream(self, input_data: Any) -> Iterator[Any]:
        """
        Stream data for inference. Subclasses should implement do_stream.
        """
        logging.info("preprocessing input_data %s", input_data)
        preprocessed_data = self.do_preprocessing(input_data)
        yield from self.do_stream(preprocessed_data)

    def inference(self, input_data: Any) -> Any:
        """performs inference. Checks that model is set up before
        performing inference.
        Subclasses should implement do_inference.

        Args:
            input_data (typing.Any): input from user

        Raises:
            ValueError: if setup not called beforehand

        Returns:
            Any: result of inference
        """
        if not self.is_setup:
            raise ValueError("setup not called before inference")

        logging.info("preprocessing input_data %s", input_data)
        preprocessed_data = self.do_preprocessing(input_data)

        logging.info("querying model with %s", preprocessed_data)
        model_output = self.do_run_model(preprocessed_data)

        logging.info("postprocessing model_output %s", model_output)
        self.__inference_count += 1
        return self.do_postprocessing(input_data, model_output)

    @abc.abstractmethod
    def do_run_model(self, preprocessed_data: Any) -> Any:
        """run model here. preprocessed_data type is
            left generic since depending on model type
        Args:
            preprocessed_data (typing.Any): preprocessed input into model

        Returns:
            typing.Any: result of running model
        """
        pass

    def do_preprocessing(self, input_data: Any) -> Any:
        """
        Implement any preprocessing of the raw user input. For example, you may need to
        apply feature engineering on the input before it is suitable for model inference.
        By default, this method returns the input data as is.

        Args:
            input_data (Any): raw input from user

        Returns:
            Any: preprocessed input
        """
        return input_data

    @abc.abstractmethod
    def do_stream(self, preprocessed_input: Any) -> Iterator[Any]:
        """
        Implement any streaming logic here. For example, you may
        want to stream data from OpenAI or any LLM. This method
        should return the data to be streamed.

        Returns:
            typing.Any: data to be streamed
        """

    def do_postprocessing(self, input_data: Any, output_data: Any) -> Any:
        """
        Implement any postprocessing of the model output. By default, this method
        returns the output data as is.

        Args:
            input_data (Any): raw input from user
            output_data (Any): model output

        Returns:
            Any: postprocessed output
        """
        return output_data

    @final
    def generate_proof(self) -> None:
        """
        Generates a proof. Warns if at least as many proofs as inferences
        have been generated, since that suggests a duplicate proof.
        """
        if self.__inference_count <= self.__proof_count:
            logging.warning(
                "generated %s inferences only but "
                + "already generated %s. Possibly duplicate proof.",
                self.__inference_count,
                self.__proof_count,
            )

        self.do_generate_proof()
        self.__proof_count += 1

    def do_generate_proof(self) -> Any:
        """
        Generates a proof, which may vary based on the proving system. We currently
        only support EZKL-based proving, which does not require proof
        generation to be defined as part of the inference workflow if
        the inference from the circuit is directly used for proving. Indeed,
        that is the case for the classic infernet_ml proof service. However, this
        may change with the adoption of optimistic and other eager or lazy
        proof systems, which we intend to support in the future. By default,
        this method raises NotImplementedError. Override in a subclass as needed.
        """
        raise NotImplementedError
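
To make the contract concrete, here is a minimal, hypothetical subclass. The class name EchoInferenceWorkflow and its toy upper-casing "model" are illustrative only and not part of infernet_ml; the import path follows the source location shown above.

import logging
from typing import Any, Iterator

from infernet_ml.workflows.inference.base_inference_workflow import (
    BaseInferenceWorkflow,
)


class EchoInferenceWorkflow(BaseInferenceWorkflow):
    """Toy workflow that upper-cases its input; a stand-in for a real model."""

    def do_setup(self) -> Any:
        # A real workflow would load weights or open client connections here.
        logging.info("setup with args=%s kwargs=%s", self.args, self.kwargs)
        return True

    def do_run_model(self, preprocessed_data: Any) -> Any:
        # A real workflow would query its model here.
        return str(preprocessed_data).upper()

    def do_stream(self, preprocessed_input: Any) -> Iterator[Any]:
        # Yield output piecewise; an LLM workflow might stream tokens instead.
        for token in str(preprocessed_input).split():
            yield token.upper()

Only the three abstract hooks (do_setup, do_run_model, do_stream) must be implemented; do_preprocessing, do_postprocessing, and do_generate_proof have default implementations.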

__init__(*args, **kwargs)

Constructor. Keeps track of the arguments passed in.

Source code in src/infernet_ml/workflows/inference/base_inference_workflow.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """
    Constructor. Keeps track of the arguments passed in.
    """
    super().__init__()
    self.args: list[Any] = list(args)
    self.kwargs: dict[Any, Any] = kwargs

    self.is_setup = False
    self.__inference_count: int = 0
    self.__proof_count: int = 0

do_generate_proof()

Generates a proof, which may vary based on the proving system. We currently only support EZKL-based proving, which does not require proof generation to be defined as part of the inference workflow if the inference from the circuit is directly used for proving. Indeed, that is the case for the classic infernet_ml proof service. However, this may change with the adoption of optimistic and other eager or lazy proof systems, which we intend to support in the future. By default, this method raises NotImplementedError. Override in a subclass as needed.

Source code in src/infernet_ml/workflows/inference/base_inference_workflow.py
def do_generate_proof(self) -> Any:
    """
    Generates a proof, which may vary based on the proving system. We currently
    only support EZKL-based proving, which does not require proof
    generation to be defined as part of the inference workflow if
    the inference from the circuit is directly used for proving. Indeed,
    that is the case for the classic infernet_ml proof service. However, this
    may change with the adoption of optimistic and other eager or lazy
    proof systems, which we intend to support in the future. By default,
    this method raises NotImplementedError. Override in a subclass as needed.
    """
    raise NotImplementedError
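
If a workflow does need in-workflow proof generation, this is the hook to override. A purely illustrative sketch, continuing the hypothetical EchoInferenceWorkflow above; the returned payload is a placeholder, not an infernet_ml or EZKL structure.

from typing import Any


class ProvableEchoWorkflow(EchoInferenceWorkflow):
    def do_generate_proof(self) -> Any:
        # Placeholder: a real workflow would invoke its proving backend here
        # and return whatever artifact that backend produces.
        return {"proof": b"", "public_inputs": []}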

do_postprocessing(input_data, output_data)

Implement any postprocessing of the model output. By default, this method returns the output data as is.

Parameters:

    input_data (Any): raw input from user. Required.
    output_data (Any): model output. Required.

Returns:

    Any: postprocessed output

Source code in src/infernet_ml/workflows/inference/base_inference_workflow.py
def do_postprocessing(self, input_data: Any, output_data: Any) -> Any:
    """
    Implement any postprocessing of the model output. By default, this method
    returns the output data as is.

    Args:
        input_data (Any): raw input from user
        output_data (Any): model output

    Returns:
        Any: postprocessed output
    """
    return output_data
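
A hypothetical override, extending the EchoInferenceWorkflow sketched earlier, that packages the model output together with the original input for the caller:

from typing import Any


class EnvelopingEchoWorkflow(EchoInferenceWorkflow):
    def do_postprocessing(self, input_data: Any, output_data: Any) -> Any:
        # Return both the raw input and the model output as one record.
        return {"input": input_data, "output": output_data}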

do_preprocessing(input_data)

Implement any preprocessing of the raw user input. For example, you may need to apply feature engineering on the input before it is suitable for model inference. By default, this method returns the input data as is.

Parameters:

    input_data (Any): raw input from user. Required.

Returns:

    Any: preprocessed input

Source code in src/infernet_ml/workflows/inference/base_inference_workflow.py
def do_preprocessing(self, input_data: Any) -> Any:
    """
    Implement any preprocessing of the raw user input. For example, you may need to
    apply feature engineering on the input before it is suitable for model inference.
    By default, this method returns the input data as is.

    Args:
        input_data (Any): raw input from user

    Returns:
        Any: preprocessed input
    """
    return input_data
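
A hypothetical override, again extending the EchoInferenceWorkflow sketch, that applies a trivial piece of feature engineering (whitespace stripping) before the model runs:

from typing import Any


class TrimmingEchoWorkflow(EchoInferenceWorkflow):
    def do_preprocessing(self, input_data: Any) -> Any:
        # Normalize the raw user input before it reaches do_run_model.
        return str(input_data).strip()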

do_run_model(preprocessed_data) abstractmethod

Run the model here. The type of preprocessed_data is left generic since it depends on the model type.

Parameters:

    preprocessed_data (Any): preprocessed input into the model. Required.

Returns:

    Any: result of running the model

Source code in src/infernet_ml/workflows/inference/base_inference_workflow.py
@abc.abstractmethod
def do_run_model(self, preprocessed_data: Any) -> Any:
    """run model here. preprocessed_data type is
        left generic since depending on model type
    Args:
        preprocessed_data (typing.Any): preprocessed input into model

    Returns:
        typing.Any: result of running model
    """
    pass
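
Because this hook is generic, a classical-ML workflow can plug in here just as easily as an LLM client. A hedged sketch assuming a scikit-learn estimator serialized with joblib; the class name, the model_path keyword, and the model file are hypothetical:

from typing import Any, Iterator

import joblib

from infernet_ml.workflows.inference.base_inference_workflow import (
    BaseInferenceWorkflow,
)


class JoblibModelWorkflow(BaseInferenceWorkflow):
    """Hypothetical classical-ML workflow; expects model_path in kwargs."""

    def do_setup(self) -> Any:
        # Load the serialized estimator once, up front.
        self.model = joblib.load(self.kwargs["model_path"])
        return True

    def do_run_model(self, preprocessed_data: Any) -> Any:
        # preprocessed_data is whatever do_preprocessing produced,
        # e.g. a 2-D feature array.
        return self.model.predict(preprocessed_data)

    def do_stream(self, preprocessed_input: Any) -> Iterator[Any]:
        # Streaming is not meaningful for a batch estimator; yield once.
        yield self.do_run_model(preprocessed_input)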

do_setup() abstractmethod

Set up your workflow here. For LLMs, this may be parameters like top_k or temperature; for classical ML models, this may be model hyperparameters.

Returns: Any

Source code in src/infernet_ml/workflows/inference/base_inference_workflow.py
@abc.abstractmethod
def do_setup(self) -> Any:
    """set up your workflow here.
    For LLMs, this may be parameters like top_k, temperature
    for classical LLM, this may be model hyperparams.

    Returns: Any
    """

do_stream(preprocessed_input) abstractmethod

Implement any streaming logic here. For example, you may want to stream data from OpenAI or any LLM. This method should return the data to be streamed.

Returns:

    Iterator[Any]: data to be streamed

Source code in src/infernet_ml/workflows/inference/base_inference_workflow.py
@abc.abstractmethod
def do_stream(self, preprocessed_input: Any) -> Iterator[Any]:
    """
    Implement any streaming logic here. For example, you may
    want to stream data from OpenAI or any LLM. This method
    should return the data to be streamed.

    Returns:
        typing.Any: data to be streamed
    """

generate_proof()

Generates a proof. Warns if at least as many proofs as inferences have been generated, since that suggests a duplicate proof.

Source code in src/infernet_ml/workflows/inference/base_inference_workflow.py
@final
def generate_proof(self) -> None:
    """
    Generates a proof. Warns if at least as many proofs as inferences
    have been generated, since that suggests a duplicate proof.
    """
    if self.__inference_count <= self.__proof_count:
        logging.warning(
            "generated %s inferences only but "
            + "already generated %s. Possibly duplicate proof.",
            self.__inference_count,
            self.__proof_count,
        )

    self.do_generate_proof()
    self.__proof_count += 1
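
The intended call order is setup(), one or more inference() calls, then generate_proof(). A short sketch using the hypothetical ProvableEchoWorkflow from above:

workflow = ProvableEchoWorkflow()
workflow.setup()
workflow.inference("hello world")
# Warns if a proof already exists for every inference, then delegates
# to do_generate_proof() and increments the proof counter.
workflow.generate_proof()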

inference(input_data)

Performs inference. Checks that the model is set up before performing inference. Subclasses should implement do_run_model.

Parameters:

    input_data (Any): input from user. Required.

Raises:

    ValueError: if setup not called beforehand

Returns:

    Any: result of inference

Source code in src/infernet_ml/workflows/inference/base_inference_workflow.py
def inference(self, input_data: Any) -> Any:
    """performs inference. Checks that model is set up before
    performing inference.
    Subclasses should implement do_inference.

    Args:
        input_data (typing.Any): input from user

    Raises:
        ValueError: if setup not called beforehand

    Returns:
        Any: result of inference
    """
    if not self.is_setup:
        raise ValueError("setup not called before inference")

    logging.info("preprocessing input_data %s", input_data)
    preprocessed_data = self.do_preprocessing(input_data)

    logging.info("querying model with %s", preprocessed_data)
    model_output = self.do_run_model(preprocessed_data)

    logging.info("postprocessing model_output %s", model_output)
    self.__inference_count += 1
    return self.do_postprocessing(input_data, model_output)
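
End-to-end usage of the hypothetical EchoInferenceWorkflow sketched earlier; note that inference() raises ValueError if setup() was not called first:

workflow = EchoInferenceWorkflow()
workflow.setup()                             # required before inference()
result = workflow.inference("hello world")   # preprocess, run model, postprocess
print(result)                                # -> "HELLO WORLD"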

setup()

Marks the workflow as set up and calls do_setup.

Source code in src/infernet_ml/workflows/inference/base_inference_workflow.py
def setup(self) -> Any:
    """
    Marks the workflow as set up and calls do_setup.
    """
    self.is_setup = True
    return self.do_setup()

stream(input_data)

Stream data for inference. Subclasses should implement do_stream.

Source code in src/infernet_ml/workflows/inference/base_inference_workflow.py
def stream(self, input_data: Any) -> Iterator[Any]:
    """
    Stream data for inference. Subclasses should implement do_stream.
    """
    logging.info("preprocessing input_data %s", input_data)
    preprocessed_data = self.do_preprocessing(input_data)
    yield from self.do_stream(preprocessed_data)
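
Streaming usage with the same hypothetical EchoInferenceWorkflow; stream() preprocesses the input and then yields whatever do_stream produces:

workflow = EchoInferenceWorkflow()
workflow.setup()
for chunk in workflow.stream("hello world"):
    print(chunk)   # -> "HELLO", then "WORLD"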