Module: diffusion_utils
VaeConfig
Bases: BaseModel
Configuration class for VAE (Variational Autoencoder).
Attributes:
Name | Type | Description |
---|---|---|
type | VaeType | The type of VAE. |
model | str | The model name. |
torch_dtype | Any | The torch data type. |
Source code in src/infernet_ml/utils/diffusion_utils.py
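The fields above can be illustrated with a minimal sketch. It uses a standard-library dataclass as a stand-in for the pydantic BaseModel so it runs without extra dependencies. The VaeType member name, the example model ID, and the string placeholder passed as torch_dtype are all hypothetical; the reference does not list the real enum values.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any


class VaeType(str, Enum):
    # Hypothetical member: the real VaeType values are not listed in this reference.
    AUTOENCODER_KL = "AutoencoderKL"


@dataclass
class VaeConfigSketch:
    """Stdlib stand-in for VaeConfig (the documented class is a pydantic BaseModel)."""

    type: VaeType     # which VAE family to load
    model: str        # the model name
    torch_dtype: Any  # typed Any in the reference; presumably a torch.dtype in practice


# Hypothetical values; "float16" is a plain string placeholder, not a torch.dtype.
config = VaeConfigSketch(
    type=VaeType.AUTOENCODER_KL,
    model="example-org/example-vae",
    torch_dtype="float16",
)
print(config.type.value, config.model, config.torch_dtype)
```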
get_torch_dtype(dtype)
Returns the torch.dtype corresponding to the input string dtype.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dtype | str | The input string representing the desired dtype. | required |
Returns:
Type | Description |
---|---|
torch.dtype | The corresponding torch.dtype. |
Examples:
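A minimal sketch of the string-to-dtype lookup this function describes, assuming an alias table of common spellings. The alias set, the names normalize_dtype_name and get_torch_dtype_sketch, and the ValueError raised on unknown input are hypothetical; the reference does not say which strings the real get_torch_dtype accepts or how it handles bad input. PyTorch is imported only inside get_torch_dtype_sketch, so the normalization step runs without it.

```python
# Hypothetical alias table: the strings accepted by the real get_torch_dtype
# are not documented here.
_DTYPE_ALIASES = {
    "float32": "float32",
    "fp32": "float32",
    "float16": "float16",
    "fp16": "float16",
    "half": "float16",
    "bfloat16": "bfloat16",
    "bf16": "bfloat16",
}


def normalize_dtype_name(dtype: str) -> str:
    """Map a user-supplied string to a canonical torch dtype attribute name."""
    key = dtype.strip().lower()
    if key.startswith("torch."):
        key = key[len("torch."):]  # accept "torch.float16" as well as "float16"
    try:
        return _DTYPE_ALIASES[key]
    except KeyError:
        raise ValueError(f"Unsupported dtype string: {dtype!r}") from None


def get_torch_dtype_sketch(dtype: str):
    """Resolve the canonical name to a torch.dtype (needs PyTorch at call time)."""
    import torch  # lazy import: normalize_dtype_name works without torch installed

    return getattr(torch, normalize_dtype_name(dtype))
```

With PyTorch installed, get_torch_dtype_sketch("fp16") returns torch.float16, since getattr(torch, "float16") resolves to that dtype object.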
get_vae_kl(config)
Retrieves a Variational Autoencoder (VAE) model based on the given configuration.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | VaeConfig | The configuration object specifying the type of VAE model to retrieve. | required |
Returns:
Type | Description |
---|---|
Any | The VAE model. |
Examples: