Skip to content

Module: model_loader

Model Loader: depending on the model source, load the model from the local file system, Hugging Face Hub, or Arweave.

ArweaveLoadArgs

Bases: CommonLoadArgs

Arguments for loading a model from Arweave. Adds no fields beyond those inherited from CommonLoadArgs.

Source code in src/infernet_ml/utils/model_loader.py
class ArweaveLoadArgs(CommonLoadArgs):
    """
    Load arguments for a model stored on Arweave.

    Declares no fields of its own; everything is inherited from
    `CommonLoadArgs`. The separate type exists so the Arweave case can be
    told apart from the other load-argument classes.
    """

CommonLoadArgs

Bases: BaseModel

Common arguments for loading a model

Source code in src/infernet_ml/utils/model_loader.py
class CommonLoadArgs(BaseModel):
    """
    Fields shared by the load arguments of remotely hosted models.

    Instances are frozen, i.e. their fields cannot be reassigned after
    construction.
    """

    # frozen => immutable instances
    model_config = ConfigDict(frozen=True)

    # directory to download into / cache under; None defers to the loader's default
    cache_path: str | None = None
    # version (revision) of the artifact to fetch; None defers to the loader's default
    version: str | None = None
    # identifier of the repository that holds the model
    repo_id: str
    # name of the model file inside the repository
    filename: str

HFLoadArgs

Bases: CommonLoadArgs

Arguments for loading a model from Hugging Face Hub. Adds no fields beyond those inherited from CommonLoadArgs.

Source code in src/infernet_ml/utils/model_loader.py
class HFLoadArgs(CommonLoadArgs):
    """
    Load arguments for a model hosted on Hugging Face Hub.

    Declares no fields of its own; everything is inherited from
    `CommonLoadArgs`. The separate type exists so the Hugging Face case can
    be told apart from the other load-argument classes.
    """

LocalLoadArgs

Bases: BaseModel

Arguments for loading a model from the local file system.

Source code in src/infernet_ml/utils/model_loader.py
class LocalLoadArgs(BaseModel):
    """
    Load arguments for a model that already lives on the local file system.

    Instances are frozen, i.e. their fields cannot be reassigned after
    construction.
    """

    # frozen => immutable instances
    model_config = ConfigDict(frozen=True)

    # location of the model file on the local file system
    path: str

ModelSource

Bases: IntEnum

Enum for the model source

  • LOCAL: Load the model from the local file system
  • ARWEAVE: Load the model from Arweave
  • HUGGINGFACE_HUB: Load the model from Hugging Face Hub
Source code in src/infernet_ml/utils/model_loader.py
class ModelSource(IntEnum):
    """
    Where a model is loaded from.

    Members (the integer values are part of the interface and must not be
    renumbered):

    - `LOCAL` (0): the model is read from the local file system
    - `ARWEAVE` (1): the model is fetched from Arweave
    - `HUGGINGFACE_HUB` (2): the model is fetched from Hugging Face Hub
    """

    LOCAL = 0  # local file system
    ARWEAVE = 1  # Arweave
    HUGGINGFACE_HUB = 2  # Hugging Face Hub

download_model(model_source, load_args)

Load the model from the specified source.

Parameters:

Name Type Description Default
model_source ModelSource

the source of the model

required
load_args LoadArgs

the load arguments; one of LocalLoadArgs, HFLoadArgs, or ArweaveLoadArgs, matching the model source

required

Returns:

Name Type Description
str str

the path to the model

Source code in src/infernet_ml/utils/model_loader.py
def download_model(
    model_source: ModelSource,
    load_args: LoadArgs,
) -> str:
    """
    Load the model from the specified source.

    Args:
        model_source (ModelSource): the source of the model
        load_args (LoadArgs): the load arguments, options are:
            - LocalLoadArgs
            - HFLoadArgs
            - ArweaveLoadArgs

    Returns:
        str: the path to the model
    """
    logger.info(f"Downloading model from {model_source} with args {load_args}")

    match model_source:
        # load the model locally
        case ModelSource.LOCAL:
            local_args = cast(LocalLoadArgs, load_args)
            return local_args.path
        case ModelSource.HUGGINGFACE_HUB:
            hf_args = cast(HFLoadArgs, load_args)
            return cast(
                str,
                hf_hub_download(
                    repo_id=hf_args.repo_id,
                    filename=hf_args.filename,
                    revision=hf_args.version,
                    cache_dir=hf_args.cache_path,
                ),
            )
        case ModelSource.ARWEAVE:
            arweave_args = cast(ArweaveLoadArgs, load_args)
            cache_path = arweave_args.cache_path or os.path.expanduser("~/.cache/")
            version = arweave_args.version or "latest"
            base_path = f"{cache_path}/{arweave_args.repo_id}/{version}/."
            logging.info(
                f"Downloading model from Arweave "
                f"{cache_path}/{arweave_args.repo_id}/{version}/{arweave_args.filename}"
            )
            return RepoManager().download_artifact_file(
                repo_id=arweave_args.repo_id,
                file_name=arweave_args.filename,
                version=arweave_args.version,
                base_path=base_path,
            )
        case _:
            raise ValueError(f"Invalid model source {model_source}")

parse_load_args(model_source, config)

Parse the load arguments based on the model source.

Parameters:

Name Type Description Default
model_source ModelSource

the source of the model

required
config dict[str, str]

the configuration

required

Returns:

Name Type Description
LoadArgs LoadArgs

the load arguments

Raises:

Type Description
ValueError

if the model source is invalid

Source code in src/infernet_ml/utils/model_loader.py
def parse_load_args(model_source: ModelSource, config: Any) -> LoadArgs:
    """
    Parse the load arguments based on the model source.

    Args:
        model_source (ModelSource): the source of the model
        config (dict[str, str]): the configuration

    Returns:
        LoadArgs: the load arguments

    Raises:
        ValueError: if the model source is invalid
    """

    match model_source:
        # parse the load arguments for the local model
        case ModelSource.LOCAL:
            return LocalLoadArgs(path=config["model_path"])
        # parse the load arguments for the model from Hugging Face Hub
        case ModelSource.HUGGINGFACE_HUB:
            return HFLoadArgs(repo_id=config["repo_id"], filename=config["filename"])
        # parse the load arguments for the model from Arweave
        case ModelSource.ARWEAVE:
            return ArweaveLoadArgs(
                repo_id=config["repo_id"],
                filename=config["filename"],
                version=config.get("version"),
            )
        case _:
            raise ValueError(f"Invalid model source {model_source}")