# Evaluation script: loads a text-to-image LDM checkpoint plus a learned
# embedding-manager checkpoint, then scores CLIP image/text similarity of
# generated samples against the embedding's training images.
import argparse
import glob
import os
import sys

# Make the repository root importable when this script is run from a subdir.
sys.path.append(os.path.join(sys.path[0], '..'))

import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from torchvision.utils import make_grid

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.data.personalized import PersonalizedBase
from evaluation.clip_eval import LDMCLIPEvaluator
|
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config`` and load ``ckpt`` weights.

    Args:
        config: OmegaConf config whose ``model`` node is understood by
            ``instantiate_from_config``.
        ckpt: Path to a PyTorch Lightning checkpoint containing a
            ``state_dict`` entry.
        verbose: If True, print missing/unexpected state-dict keys.

    Returns:
        The model in eval mode, moved to the GPU when CUDA is available.
    """
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False: fine-tuned / embedding checkpoints may not cover every
    # parameter; collect the mismatches instead of raising.
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    # Bug fix: the original unconditionally called model.cuda(), which
    # crashes on CPU-only machines (the caller already re-places the model
    # on the right device afterwards). Only move to GPU when one exists.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )

    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="/data/pretrained_models/ldm/text2img-large/model.ckpt",
        help="Path to pretrained ldm text2img model")

    parser.add_argument(
        "--embedding_path",
        type=str,
        help="Path to a pre-trained embedding manager checkpoint")

    parser.add_argument(
        "--data_dir",
        type=str,
        help="Path to directory with images used to train the embedding vectors"
    )

    # Bug fix: opt.out_dir was read below but the argument was never
    # registered, which raised AttributeError at runtime.
    parser.add_argument(
        "--out_dir",
        type=str,
        default="outputs/evaluations",
        help="Directory in which to place per-prompt evaluation outputs"
    )

    opt = parser.parse_args()

    config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml")  # TODO: Optionally download from same location as ckpt and change this logic
    model = load_model_from_config(config, opt.ckpt_path)  # TODO: check path
    model.embedding_manager.load(opt.embedding_path)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    evaluator = LDMCLIPEvaluator(device)

    prompt = opt.prompt

    # Training images of the personalized concept, resized to 256 px.
    data_loader = PersonalizedBase(opt.data_dir, size=256, flip_p=0.0)

    # HWC numpy images -> CHW tensors, stacked into an (N, C, H, W) batch.
    images = [torch.from_numpy(data_loader[i]["image"]).permute(2, 0, 1)
              for i in range(data_loader.num_images)]
    images = torch.stack(images, dim=0)

    sim_img, sim_text = evaluator.evaluate(model, images, prompt)

    # Bug fix: output_dir was computed but the directory was never created.
    output_dir = os.path.join(opt.out_dir, prompt.replace(" ", "-"))
    os.makedirs(output_dir, exist_ok=True)

    print("Image similarity: ", sim_img)
    print("Text similarity: ", sim_text)