This commit is contained in:
2022-09-28 21:11:28 +02:00
commit c68bd15bb8
98 changed files with 16302 additions and 0 deletions

0
ldm/data/__init__.py Normal file
View File

23
ldm/data/base.py Normal file
View File

@ -0,0 +1,23 @@
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
class Txt2ImgIterableBaseDataset(IterableDataset):
'''
Define an interface to make the IterableDatasets for text2img data chainable
'''
def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__()
self.num_records = num_records
self.valid_ids = valid_ids
self.sample_ids = valid_ids
self.size = size
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
def __len__(self):
return self.num_records
@abstractmethod
def __iter__(self):
pass

394
ldm/data/imagenet.py Normal file
View File

@ -0,0 +1,394 @@
import os, yaml, pickle, shutil, tarfile, glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, Subset
import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import ImagePaths
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
def synset2idx(path_to_yaml="data/index_synset.yaml"):
with open(path_to_yaml) as f:
di2s = yaml.load(f)
return dict((v,k) for k,v in di2s.items())
class ImageNetBase(Dataset):
def __init__(self, config=None):
self.config = config or OmegaConf.create()
if not type(self.config)==dict:
self.config = OmegaConf.to_container(self.config)
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
self._prepare()
self._prepare_synset_to_human()
self._prepare_idx_to_synset()
self._prepare_human_to_integer_label()
self._load()
def __len__(self):
return len(self.data)
def __getitem__(self, i):
return self.data[i]
def _prepare(self):
raise NotImplementedError()
def _filter_relpaths(self, relpaths):
ignore = set([
"n06596364_9591.JPEG",
])
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
if "sub_indices" in self.config:
indices = str_to_indices(self.config["sub_indices"])
synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
files = []
for rpath in relpaths:
syn = rpath.split("/")[0]
if syn in synsets:
files.append(rpath)
return files
else:
return relpaths
def _prepare_synset_to_human(self):
SIZE = 2655750
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
self.human_dict = os.path.join(self.root, "synset_human.txt")
if (not os.path.exists(self.human_dict) or
not os.path.getsize(self.human_dict)==SIZE):
download(URL, self.human_dict)
def _prepare_idx_to_synset(self):
URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
self.idx2syn = os.path.join(self.root, "index_synset.yaml")
if (not os.path.exists(self.idx2syn)):
download(URL, self.idx2syn)
def _prepare_human_to_integer_label(self):
URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
if (not os.path.exists(self.human2integer)):
download(URL, self.human2integer)
with open(self.human2integer, "r") as f:
lines = f.read().splitlines()
assert len(lines) == 1000
self.human2integer_dict = dict()
for line in lines:
value, key = line.split(":")
self.human2integer_dict[key] = int(value)
def _load(self):
with open(self.txt_filelist, "r") as f:
self.relpaths = f.read().splitlines()
l1 = len(self.relpaths)
self.relpaths = self._filter_relpaths(self.relpaths)
print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
self.synsets = [p.split("/")[0] for p in self.relpaths]
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
unique_synsets = np.unique(self.synsets)
class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
if not self.keep_orig_class_label:
self.class_labels = [class_dict[s] for s in self.synsets]
else:
self.class_labels = [self.synset2idx[s] for s in self.synsets]
with open(self.human_dict, "r") as f:
human_dict = f.read().splitlines()
human_dict = dict(line.split(maxsplit=1) for line in human_dict)
self.human_labels = [human_dict[s] for s in self.synsets]
labels = {
"relpath": np.array(self.relpaths),
"synsets": np.array(self.synsets),
"class_label": np.array(self.class_labels),
"human_label": np.array(self.human_labels),
}
if self.process_images:
self.size = retrieve(self.config, "size", default=256)
self.data = ImagePaths(self.abspaths,
labels=labels,
size=self.size,
random_crop=self.random_crop,
)
else:
self.data = self.abspaths
class ImageNetTrain(ImageNetBase):
NAME = "ILSVRC2012_train"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
FILES = [
"ILSVRC2012_img_train.tar",
]
SIZES = [
147897477120,
]
def __init__(self, process_images=True, data_root=None, **kwargs):
self.process_images = process_images
self.data_root = data_root
super().__init__(**kwargs)
def _prepare(self):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 1281167
self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
default=True)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
tar.extractall(path=datadir)
print("Extracting sub-tars.")
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
for subpath in tqdm(subpaths):
subdir = subpath[:-len(".tar")]
os.makedirs(subdir, exist_ok=True)
with tarfile.open(subpath, "r:") as tar:
tar.extractall(path=subdir)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)
tdu.mark_prepared(self.root)
class ImageNetValidation(ImageNetBase):
NAME = "ILSVRC2012_validation"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
FILES = [
"ILSVRC2012_img_val.tar",
"validation_synset.txt",
]
SIZES = [
6744924160,
1950000,
]
def __init__(self, process_images=True, data_root=None, **kwargs):
self.data_root = data_root
self.process_images = process_images
super().__init__(**kwargs)
def _prepare(self):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 50000
self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
default=False)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
tar.extractall(path=datadir)
vspath = os.path.join(self.root, self.FILES[1])
if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
download(self.VS_URL, vspath)
with open(vspath, "r") as f:
synset_dict = f.read().splitlines()
synset_dict = dict(line.split() for line in synset_dict)
print("Reorganizing into synset folders")
synsets = np.unique(list(synset_dict.values()))
for s in synsets:
os.makedirs(os.path.join(datadir, s), exist_ok=True)
for k, v in synset_dict.items():
src = os.path.join(datadir, k)
dst = os.path.join(datadir, v)
shutil.move(src, dst)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)
tdu.mark_prepared(self.root)
class ImageNetSR(Dataset):
def __init__(self, size=None,
degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
random_crop=True):
"""
Imagenet Superresolution Dataloader
Performs following ops in order:
1. crops a crop of size s from image either as random or center crop
2. resizes crop to size with cv2.area_interpolation
3. degrades resized crop with degradation_fn
:param size: resizing to size after cropping
:param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
:param downscale_f: Low Resolution Downsample factor
:param min_crop_f: determines crop size s,
where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
:param max_crop_f: ""
:param data_root:
:param random_crop:
"""
self.base = self.get_base()
assert size
assert (size / downscale_f).is_integer()
self.size = size
self.LR_size = int(size / downscale_f)
self.min_crop_f = min_crop_f
self.max_crop_f = max_crop_f
assert(max_crop_f <= 1.)
self.center_crop = not random_crop
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
if degradation == "bsrgan":
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
elif degradation == "bsrgan_light":
self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
else:
interpolation_fn = {
"cv_nearest": cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4,
"pil_nearest": PIL.Image.NEAREST,
"pil_bilinear": PIL.Image.BILINEAR,
"pil_bicubic": PIL.Image.BICUBIC,
"pil_box": PIL.Image.BOX,
"pil_hamming": PIL.Image.HAMMING,
"pil_lanczos": PIL.Image.LANCZOS,
}[degradation]
self.pil_interpolation = degradation.startswith("pil_")
if self.pil_interpolation:
self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
else:
self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
interpolation=interpolation_fn)
def __len__(self):
return len(self.base)
def __getitem__(self, i):
example = self.base[i]
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")
image = np.array(image).astype(np.uint8)
min_side_len = min(image.shape[:2])
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
crop_side_len = int(crop_side_len)
if self.center_crop:
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
else:
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
image = self.cropper(image=image)["image"]
image = self.image_rescaler(image=image)["image"]
if self.pil_interpolation:
image_pil = PIL.Image.fromarray(image)
LR_image = self.degradation_process(image_pil)
LR_image = np.array(LR_image).astype(np.uint8)
else:
LR_image = self.degradation_process(image=image)["image"]
example["image"] = (image/127.5 - 1.0).astype(np.float32)
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
return example
class ImageNetSRTrain(ImageNetSR):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def get_base(self):
with open("data/imagenet_train_hr_indices.p", "rb") as f:
indices = pickle.load(f)
dset = ImageNetTrain(process_images=False,)
return Subset(dset, indices)
class ImageNetSRValidation(ImageNetSR):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def get_base(self):
with open("data/imagenet_val_hr_indices.p", "rb") as f:
indices = pickle.load(f)
dset = ImageNetValidation(process_images=False,)
return Subset(dset, indices)

92
ldm/data/lsun.py Normal file
View File

@ -0,0 +1,92 @@
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class LSUNBase(Dataset):
def __init__(self,
txt_file,
data_root,
size=None,
interpolation="bicubic",
flip_p=0.5
):
self.data_paths = txt_file
self.data_root = data_root
with open(self.data_paths, "r") as f:
self.image_paths = f.read().splitlines()
self._length = len(self.image_paths)
self.labels = {
"relative_file_path_": [l for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, l)
for l in self.image_paths],
}
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
return self._length
def __getitem__(self, i):
example = dict((k, self.labels[k][i]) for k in self.labels)
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
class LSUNChurchesTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
class LSUNChurchesValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
flip_p=flip_p, **kwargs)
class LSUNBedroomsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
class LSUNBedroomsValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
flip_p=flip_p, **kwargs)
class LSUNCatsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
class LSUNCatsValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
flip_p=flip_p, **kwargs)

115
ldm/data/personalized.py Normal file
View File

@ -0,0 +1,115 @@
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random
training_templates_smallest = [
'gijsbert {}',
]
reg_templates_smallest = [
'{}',
]
imagenet_templates_small = [
'{}',
]
imagenet_dual_templates_small = [
'{} with {}'
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="dog",
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
reg = False
):
self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.center_crop = center_crop
self.mixing_prob = mixing_prob
self.coarse_class_text = coarse_class_text
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.reg = reg
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
placeholder_string = self.placeholder_token
if self.coarse_class_text:
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
if not self.reg:
text = random.choice(training_templates_smallest).format(placeholder_string)
else:
text = random.choice(reg_templates_smallest).format(placeholder_string)
example["caption"] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example

View File

@ -0,0 +1,129 @@
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random
imagenet_templates_small = [
'a painting in the style of {}',
'a rendering in the style of {}',
'a cropped painting in the style of {}',
'the painting in the style of {}',
'a clean painting in the style of {}',
'a dirty painting in the style of {}',
'a dark painting in the style of {}',
'a picture in the style of {}',
'a cool painting in the style of {}',
'a close-up painting in the style of {}',
'a bright painting in the style of {}',
'a cropped painting in the style of {}',
'a good painting in the style of {}',
'a close-up painting in the style of {}',
'a rendition in the style of {}',
'a nice painting in the style of {}',
'a small painting in the style of {}',
'a weird painting in the style of {}',
'a large painting in the style of {}',
]
imagenet_dual_templates_small = [
'a painting in the style of {} with {}',
'a rendering in the style of {} with {}',
'a cropped painting in the style of {} with {}',
'the painting in the style of {} with {}',
'a clean painting in the style of {} with {}',
'a dirty painting in the style of {} with {}',
'a dark painting in the style of {} with {}',
'a cool painting in the style of {} with {}',
'a close-up painting in the style of {} with {}',
'a bright painting in the style of {} with {}',
'a cropped painting in the style of {} with {}',
'a good painting in the style of {} with {}',
'a painting of one {} in the style of {}',
'a nice painting in the style of {} with {}',
'a small painting in the style of {} with {}',
'a weird painting in the style of {} with {}',
'a large painting in the style of {} with {}',
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
per_image_tokens=False,
center_crop=False,
):
self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.center_crop = center_crop
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
if self.per_image_tokens and np.random.uniform() < 0.25:
text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images])
else:
text = random.choice(imagenet_templates_small).format(self.placeholder_token)
example["caption"] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example