---
name: AuraEquiVAE-SAFETENSORS
base_model: fal/AuraEquiVAE
license: apache-2.0
pipeline_tag: feature-extraction
tasks:
- feature-extraction
- image-to-image
language: en
---

Original Model Link: https://huggingface.co/fal/AuraEquiVAE
AuraEquiVAE-SAFETENSORS is an experimental, equivariant VAE converted to the safetensors format. Equivariant here means an image and a mirrored copy of that image are encoded equivalently, the mirrored image yielding a mirrored latent (good thing).
I wanted to try it ¯\\\_(ツ)\_/¯
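To illustrate what flip-equivariance means, here is a minimal NumPy sketch using a toy "encoder" (2x2 average pooling, which is flip-equivariant by construction; this is not the actual model, just the property it is trained to have):

```python
import numpy as np

def toy_encoder(img):
    """Toy flip-equivariant 'encoder': 2x2 average pooling over an HxW image."""
    h, w = img.shape
    return img.reshape(h // 2, 2, w // 2, 2).mean(axis=(1, 3))

rng = np.random.default_rng(0)
img = rng.random((8, 8))

z = toy_encoder(img)
z_mirror = toy_encoder(img[:, ::-1])  # encode the horizontally mirrored image

# Equivariance: encoding the mirror equals mirroring the encoding
print(np.allclose(z_mirror, z[:, ::-1]))  # True
```

For an equivariant VAE, the analogous check holds (approximately) on the real encoder: flipping the input flips the latent instead of producing an unrelated code.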
The class code is included in this repo, adapted from the original repo, which is itself adapted from the black-forest-labs Flux.1 AE. To use AuraEquiVAE:
```python
from ae import VAE  # file included in this repo
from safetensors.torch import load_file
from PIL import Image
import torch
from torchvision import transforms

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# Load an image and normalize it from [0, 1] to [-1, 1]
img_orig = Image.open("cat.jpg").convert("RGB")
img = transforms.ToTensor()(img_orig).unsqueeze(0).to(device)
img = (img - 0.5) / 0.5

vae = VAE(
    resolution=256,
    in_channels=3,
    ch=256,
    out_ch=3,
    ch_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    z_channels=16,
).to(device).bfloat16()

state_dict = load_file(vae_path)  # path to the .safetensors file
vae.load_state_dict(state_dict)

with torch.no_grad():
    z = vae.encoder(img.bfloat16())
    z = z.clamp(-8.0, 8.0)  # this is latent!!

...

with torch.no_grad():
    decz = vae.decoder(z)  # this is image!

# Denormalize from [-1, 1] back to [0, 1], then convert to uint8 HWC
decimg = ((decz + 1) / 2).clamp(0, 1).squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
decimg = (decimg * 255).astype('uint8')
decimg = Image.fromarray(decimg)  # PIL image
```
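The `(img - 0.5) / 0.5` step above maps pixel values from [0, 1] to the [-1, 1] range the VAE expects, and `(decz + 1) / 2` on the decode side is its exact inverse. A quick NumPy check of that round trip (illustrative only, no model needed):

```python
import numpy as np

pixels = np.array([0.0, 0.25, 0.5, 1.0])

normalized = (pixels - 0.5) / 0.5  # [0, 1] -> [-1, 1]
restored = (normalized + 1) / 2    # [-1, 1] -> [0, 1]

print(normalized)
print(np.allclose(restored, pixels))  # True
```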