112 lines
3.6 KiB
Python
112 lines
3.6 KiB
Python
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
|
|
from ldm.modules.embedding_manager import EmbeddingManager
|
|
|
|
import argparse, os
|
|
from functools import partial
|
|
|
|
import torch
|
|
|
|
def get_placeholder_loop(placeholder_string, embedder, is_sd):
    """Interactively ask the user for a replacement placeholder string.

    The first prompt reports that ``placeholder_string`` was already used.
    The user is then re-prompted until the entered string is accepted by the
    matching single-token helper (i.e. that helper returns something other
    than ``None``).

    Args:
        placeholder_string: The placeholder that collided; shown in the first prompt.
        embedder: Text encoder object. ``embedder.tokenizer`` is passed to
            ``get_clip_token_for_string`` when ``is_sd`` is true, otherwise
            ``embedder.tknz_fn`` is passed to ``get_bert_token_for_string``.
        is_sd: Selects the CLIP (Stable Diffusion) path when true, BERT otherwise.

    Returns:
        A ``(new_placeholder, token)`` tuple, where ``token`` is the value
        returned by the selected token helper for ``new_placeholder``.
    """
    # Choose the tokenizer path once; attributes are still read lazily on each call.
    if is_sd:
        def lookup(candidate):
            return get_clip_token_for_string(embedder.tokenizer, candidate)
    else:
        def lookup(candidate):
            return get_bert_token_for_string(embedder.tknz_fn, candidate)

    candidate = input(f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: ")
    token = lookup(candidate)

    # Keep asking until the candidate is accepted as a single token.
    while token is None:
        candidate = input(f"Placeholder string '{candidate}' maps to more than a single token. Please enter another string: ")
        token = lookup(candidate)

    return candidate, token
|
|
|
|
def get_clip_token_for_string(tokenizer, string, *, max_length=77, end_token_id=49407):
    """Return the CLIP token id for ``string`` if it encodes to exactly one token.

    ``string`` is tokenized with truncation and padding to ``max_length``.
    Every position whose id differs from ``end_token_id`` is counted. The
    string is accepted only when exactly two such positions exist: position 0
    (presumably the start-of-text token) and position 1 (the string's token).

    The check assumes padding is filled with ``end_token_id`` (49407 is the
    id the original code hard-coded for the CLIP vocabulary used here).
    NOTE(review): confirm this for any other tokenizer passed in.

    Args:
        tokenizer: Callable accepting Hugging Face tokenizer keyword arguments
            (``truncation``, ``padding``, ``return_tensors``, ...) and returning
            a mapping whose ``"input_ids"`` entry is a 2-D tensor.
        string: Candidate placeholder string.
        max_length: Length the encoding is padded/truncated to (default 77,
            the value previously hard-coded).
        end_token_id: Id treated as end/padding when counting real tokens
            (default 49407, the value previously hard-coded).

    Returns:
        ``input_ids[0, 1]`` (a 0-d tensor) when ``string`` maps to a single
        token, otherwise ``None``.
    """
    batch_encoding = tokenizer(string, truncation=True, max_length=max_length, return_length=True,
                               return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
    tokens = batch_encoding["input_ids"]

    # Non-end positions = start token + the string's own token(s); exactly 2
    # means the string encoded to a single token.
    if torch.count_nonzero(tokens - end_token_id) == 2:
        return tokens[0, 1]

    return None
|
|
|
|
def get_bert_token_for_string(tokenizer, string):
    """Return the BERT token id for ``string`` if it encodes to exactly one token.

    ``tokenizer(string)`` is expected to return a 2-D tensor of ids (it is
    indexed as ``[0, 1]``). A string is accepted when the encoding contains
    exactly three non-zero ids — presumably [CLS], the token, and [SEP], with
    zero padding; confirm against the tokenizer in use.

    Returns:
        The id at ``[0, 1]`` (a 0-d tensor) for a single-token string,
        otherwise ``None``.
    """
    ids = tokenizer(string)
    nonzero_count = torch.count_nonzero(ids)
    return ids[0, 1] if nonzero_count == 3 else None
|
|
|
|
|
|
if __name__ == "__main__":
    # Merge several embedding-manager checkpoints into a single manager.
    # Placeholder strings that collide with already-merged ones are renamed
    # interactively via get_placeholder_loop.

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--manager_ckpts",
        type=str,
        nargs="+",
        required=True,
        help="Paths to a set of embedding managers to be merged."
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path for the merged manager",
    )

    parser.add_argument(
        "-sd", "--stable_diffusion",
        action="store_true",
        help="Flag to denote that we are merging stable diffusion embeddings"
    )

    args = parser.parse_args()

    if args.stable_diffusion:
        embedder = FrozenCLIPEmbedder().cuda()
    else:
        embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()

    # Factory for managers bound to this embedder and the "*" placeholder list.
    # Bound to a new name so the imported EmbeddingManager class is not shadowed.
    make_manager = partial(EmbeddingManager, embedder, ["*"])

    string_to_token_dict = {}
    string_to_param_dict = torch.nn.ParameterDict()

    # placeholder string -> checkpoint path it came from (printed at the end)
    placeholder_to_src = {}

    for manager_ckpt in args.manager_ckpts:
        print(f"Parsing {manager_ckpt}...")

        manager = make_manager()
        manager.load(manager_ckpt)

        for placeholder_string in manager.string_to_token_dict:
            new_placeholder = placeholder_string
            new_token = manager.string_to_token_dict[placeholder_string]

            # Re-prompt until the name is unused. The replacement typed by the
            # user may itself collide with an already-merged placeholder;
            # previously that entry was silently overwritten.
            while new_placeholder in string_to_token_dict:
                new_placeholder, new_token = get_placeholder_loop(
                    new_placeholder, embedder, is_sd=args.stable_diffusion
                )

            string_to_token_dict[new_placeholder] = new_token
            string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]

            placeholder_to_src[new_placeholder] = manager_ckpt

    print("Saving combined manager...")
    merged_manager = make_manager()
    merged_manager.string_to_param_dict = string_to_param_dict
    merged_manager.string_to_token_dict = string_to_token_dict
    merged_manager.save(args.output_path)

    print("Managers merged. Final list of placeholders: ")
    print(placeholder_to_src)
|
|
|
|
|
|
|
|
|