{ "cells": [ { "cell_type": "markdown", "id": "aa2c1ada", "metadata": { "id": "aa2c1ada" }, "source": [ "# Dreambooth\n", "### Notebook implementation by Joe Penna (@MysteryGuitarM on Twitter) - Improvements by David Bielejeski\n", "https://github.com/JoePenna/Dreambooth-Stable-Diffusion\n", "\n", "### If on runpod / vast.ai / etc., spin up an A6000 or A100 pod using a Stable Diffusion template with Jupyter pre-installed." ] }, { "cell_type": "markdown", "id": "7b971cc0", "metadata": { "id": "7b971cc0" }, "source": [ "## Build Environment" ] }, { "cell_type": "code", "execution_count": null, "id": "9e1bc458-091b-42f4-a125-c3f0df20f29d", "metadata": { "id": "9e1bc458-091b-42f4-a125-c3f0df20f29d", "scrolled": true }, "outputs": [], "source": [ "#BUILD ENV\n", "!pip install omegaconf\n", "!pip install einops\n", "!pip install pytorch-lightning==1.6.5\n", "!pip install test-tube\n", "!pip install transformers\n", "!pip install kornia\n", "!pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers\n", "!pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip\n", "!pip install setuptools==59.5.0\n", "!pip install pillow==9.0.1\n", "!pip install torchmetrics==0.6.0\n", "!pip install -e .\n", "!pip install protobuf==3.20.1\n", "!pip install gdown\n", "!pip install pydrive\n", "!pip install -qq diffusers[\"training\"]==0.3.0 transformers ftfy\n", "!pip install -qq \"ipywidgets>=7,<8\"\n", "!pip install huggingface_hub\n", "!pip install ipywidgets" ] }, { "cell_type": "code", "execution_count": null, "id": "ddf7a43d", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "## Download sd-v1-4.ckpt and copy it to the root of this directory as \"model.ckpt\"\n", "#actual_locations_of_model_blob = !readlink -f {downloaded_model_path}\n", "#!cp {actual_locations_of_model_blob[-1]} model.ckpt\n", "!wget 'https://prodesk.home.thijn.ovh/sd-v1-4.ckpt'\n", "!cp sd-v1-4.ckpt model.ckpt" ] }, { 
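"cell_type": "markdown", "id": "3b9f0c1e", "metadata": {}, "source": [ "A minimal sketch, not part of the original repo, that checks `model.ckpt` exists and is not an obviously truncated download before you start training. The helper `check_checkpoint` and its 1 GB minimum are assumptions. The threshold only catches failed or partial downloads (for example an error page saved by `wget`). It is not the official file size." ] }, { "cell_type": "code", "execution_count": null, "id": "5d2e8a47", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# Hypothetical helper (not from the original notebook): return the checkpoint size\n", "# in bytes, raising if the file is missing or smaller than min_bytes.\n", "# The 1 GB default is an assumed sanity threshold, not the real sd-v1-4 size.\n", "def check_checkpoint(path=\"model.ckpt\", min_bytes=1_000_000_000):\n", "    if not os.path.isfile(path):\n", "        raise FileNotFoundError(f\"{path} not found; re-run the download cell above\")\n", "    size = os.path.getsize(path)\n", "    if size < min_bytes:\n", "        raise ValueError(f\"{path} is only {size} bytes; the download may have failed\")\n", "    return size\n", "\n", "if os.path.isfile(\"model.ckpt\"):\n", "    print(f\"model.ckpt: {check_checkpoint() / 1e9:.2f} GB\")\n", "else:\n", "    print(\"model.ckpt not found\")" ] }, { 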
"cell_type": "markdown", "id": "mxPL2O0OLvBW", "metadata": { "id": "mxPL2O0OLvBW" }, "source": [ "## Download pre-generated regularization images\n", "\n", "We've created the following image sets:\n", "\n", "* man_euler - provided by Niko Pueringer (Corridor Digital) - euler @ 40 steps, CFG 7.5\n", "* man_unsplash - pictures from various photographers\n", "* person_ddim\n", "* woman_ddim - provided by David Bielejeski - ddim @ 50 steps, CFG 10.0\n", "\n", "`person_ddim` is recommended." ] }, { "cell_type": "code", "execution_count": null, "id": "e7EydXCjOV1v", "metadata": { "id": "e7EydXCjOV1v" }, "outputs": [], "source": [ "# Grab the existing regularization images\n", "# Choose the dataset that best represents what you are trying to do and matches the class word you will use for training\n", "# man_euler, man_unsplash, person_ddim, woman_ddim\n", "dataset = \"person_ddim\"\n", "!rm -rf Stable-Diffusion-Regularization-Images-{dataset}\n", "!git clone https://github.com/djbielejeski/Stable-Diffusion-Regularization-Images-{dataset}.git\n", "\n", "!mkdir -p outputs/txt2img-samples/samples/{dataset}\n", "!mv -v Stable-Diffusion-Regularization-Images-{dataset}/{dataset}/*.* outputs/txt2img-samples/samples/{dataset}" ] }, { "cell_type": "markdown", "id": "zshrC_JuMXmM", "metadata": { "id": "zshrC_JuMXmM" }, "source": [ "## Upload your training images\n", "Upload 10-20 images of your subject to\n", "\n", "```\n", "/workspace/Dreambooth-Stable-Diffusion/training_samples\n", "```\n", "\n", "WARNING: Be sure to upload an *even* number of images; otherwise, training inexplicably stops at 1500 steps.\n", "\n", "* 2-3 full body\n", "* 3-5 upper body\n", "* 5-12 close-up on face" ] }, { "cell_type": "code", "execution_count": null, "id": "60e37ee0", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "#@markdown Add the URLs of the images of the concept you are adding\n", "urls = [\n", 
"\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121054.png\",\n", "\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121057.png\",\n", "\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121100.png\",\n", "\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121102.png\",\n", "\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121104.png\",\n", "\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121106.png\",\n", "\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121108.png\",\n", "\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121112.png\",\n", "\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121116.png\",\n", "\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121118.png\",\n", "\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121120.png\",\n", "\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121153.png\",\n", "\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121155.png\",\n", "\"https://prodesk.home.thijn.ovh/benji/IMG_20221011_121157.png\",\n", "    # You can add additional images here\n", "]" ] }, { "cell_type": "code", "execution_count": null, "id": "df314e2e", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "#@title Download and check the images you have just added\n", "import os\n", "import requests\n", "from io import BytesIO\n", "from PIL import Image\n", "\n", "\n", "def image_grid(imgs, rows, cols):\n", "    # Paste equally sized images into a rows x cols grid\n", "    assert len(imgs) == rows*cols\n", "\n", "    w, h = imgs[0].size\n", "    grid = Image.new('RGB', size=(cols*w, rows*h))\n", "\n", "    for i, img in enumerate(imgs):\n", "        grid.paste(img, box=(i%cols*w, i//cols*h))\n", "    return grid\n", "\n", "def download_image(url):\n", "    # Return the image as RGB, or None if the download or decoding fails\n", "    try:\n", "        response = requests.get(url, timeout=30)\n", "        response.raise_for_status()\n", "        return Image.open(BytesIO(response.content)).convert(\"RGB\")\n", "    except (requests.RequestException, OSError) as e:\n", "        print(f\"Skipping {url}: {e}\")\n", "        return None\n", "\n", "images = list(filter(None, [download_image(url) for url in urls]))\n", "save_path = \"./training_samples\"\n", "if not 
os.path.exists(save_path):\n", "    os.mkdir(save_path)\n", "for i, image in enumerate(images):\n", "    image.save(f\"{save_path}/{i}.png\", format=\"png\")\n", "print(f\"Saved {len(images)} of {len(urls)} images to {save_path}\")\n" ] }, { "cell_type": "markdown", "id": "ad4e50df", "metadata": { "id": "ad4e50df" }, "source": [ "## Training\n", "\n", "If training a person or subject, keep an eye on your project's `logs/{folder}/images/train/samples_scaled_gs-00xxxx` generations.\n", "\n", "If training a style, keep an eye on your project's `logs/{folder}/images/train/samples_gs-00xxxx` generations." ] }, { "cell_type": "code", "execution_count": null, "id": "6fa5dd66-2ca0-4819-907e-802e25583ae6", "metadata": { "id": "6fa5dd66-2ca0-4819-907e-802e25583ae6", "tags": [] }, "outputs": [], "source": [ "# START THE TRAINING\n", "project_name = \"benjiman\"\n", "batch_size = 1000\n", "class_word = \"person\" # << match this word to the class word from regularization images above\n", "reg_data_root = \"/workspace/Dreambooth-Stable-Diffusion/outputs/txt2img-samples/samples/\" + dataset\n", "\n", "!rm -rf training_samples/.ipynb_checkpoints\n", "!python \"main.py\" \\\n", " --base configs/stable-diffusion/v1-finetune_unfrozen.yaml \\\n", " -t \\\n", " --actual_resume \"model.ckpt\" \\\n", " --reg_data_root {reg_data_root} \\\n", " -n {project_name} \\\n", " --gpus 0, \\\n", " --data_root \"/workspace/Dreambooth-Stable-Diffusion/training_samples\" \\\n", " --batch_size {batch_size} \\\n", " --class_word {class_word}" ] }, { "cell_type": "markdown", "id": "dc49d0bd", "metadata": {}, "source": [ "## Pruning (12GB to 2GB)\n", "We are working on having this happen automatically (TODO: PRs welcome)" ] }, { "cell_type": "code", "execution_count": null, "id": "27cea333", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "directory_paths = !ls -d logs/*" ] }, { "cell_type": "code", "execution_count": null, "id": "965b4654", "metadata": { "tags": [] }, "outputs": [], "source": [ "# This version should automatically 
prune around 10GB from the ckpt file\n", "last_checkpoint_file = directory_paths[-1] + \"/checkpoints/last.ckpt\"\n", "!python \"scripts/prune-ckpt.py\" --ckpt {last_checkpoint_file}" ] }, { "cell_type": "code", "execution_count": null, "id": "b7a8cec3", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "last_checkpoint_file_pruned = directory_paths[-1] + \"/checkpoints/last-pruned.ckpt\"\n", "training_samples = !ls training_samples\n", "date_string = !date +\"%Y-%m-%dT%H-%M-%S\"\n", "file_name = date_string[-1] + \"_\" + project_name + \"_\" + str(len(training_samples)) + \"_training_images_\" + str(batch_size) + \"_batch_size_\" + class_word + \"_class_word.ckpt\"\n", "!mkdir -p trained_models\n", "!mv {last_checkpoint_file_pruned} trained_models/{file_name}" ] }, { "cell_type": "code", "execution_count": null, "id": "ff1a46d9", "metadata": {}, "outputs": [], "source": [ "# Copy your trained model files from `trained_models` to a remote server over SSH (or download them and use them in your favorite Stable Diffusion repo!)\n", "!wget -nc 'https://prodesk.home.thijn.ovh/gijs/ai'\n", "!chmod 600 ai\n", "!scp -r -i ai -P 2387 -o StrictHostKeyChecking=no ./trained_models ai@home.thijn.ovh:/mnt/hdd/ai/" ] }, { "cell_type": "markdown", "id": "d28d0139", "metadata": {}, "source": [ "## Generate Images With Your Trained Model!" 
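] }, { "cell_type": "markdown", "id": "8e41c6d2", "metadata": {}, "source": [ "A minimal sketch, not part of the original repo, for finding the checkpoint to pass as `--ckpt` in the `stable_txt2img.py` cell. The hypothetical helper `newest_checkpoint` returns the most recently modified `.ckpt` file in `trained_models`, which is where the pruned checkpoint is moved and renamed after pruning. It assumes the newest file is the one you want, so check the printed path before using it in place of `CHANGEME.ckpt`." ] }, { "cell_type": "code", "execution_count": null, "id": "c7a05f93", "metadata": {}, "outputs": [], "source": [ "import glob\n", "import os\n", "\n", "# Hypothetical helper (not from the original notebook): return the most recently\n", "# modified .ckpt file in `folder`, raising if there is none.\n", "def newest_checkpoint(folder=\"trained_models\"):\n", "    checkpoints = glob.glob(os.path.join(folder, \"*.ckpt\"))\n", "    if not checkpoints:\n", "        raise FileNotFoundError(f\"No .ckpt files found in {folder}\")\n", "    return max(checkpoints, key=os.path.getmtime)\n", "\n", "if os.path.isdir(\"trained_models\"):\n", "    # Absolute path to use in place of CHANGEME.ckpt in the next cell\n", "    print(os.path.abspath(newest_checkpoint()))"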
] }, { "cell_type": "code", "execution_count": null, "id": "80ddb03b", "metadata": {}, "outputs": [], "source": [ "!python scripts/stable_txt2img.py \\\n", " --ddim_eta 0.0 \\\n", " --n_samples 1 \\\n", " --n_iter 10 \\\n", " --scale 7.0 \\\n", " --ddim_steps 50 \\\n", " --ckpt \"/workspace/Dreambooth-Stable-Diffusion/trained_models/CHANGEME.ckpt\" \\\n", " --prompt \"beautiful oil painting of benjiman with high detail of a gorgeous wood-elf ranger from dungeons and dragons in the desert by artgerm and greg rutkowski and thomas kinkade\"" ] }, { "cell_type": "code", "execution_count": null, "id": "0e3c10d9-2c40-4f50-9cf4-97e88a57288c", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "colab": { "collapsed_sections": [], "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" } } }, "nbformat": 4, "nbformat_minor": 5 }