notebook-examples/text_to_image.ipynb

346 lines
18 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "e355f137",
"metadata": {},
"source": [
"*This notebook requires a working PyTorch GPU environment.*"
]
},
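{
"cell_type": "markdown",
"id": "e04b5280",
"metadata": {},
"source": [
"# Stable Diffusion Text-to-Image Model\n",
"\n",
"Loads a Stable Diffusion checkpoint with Hugging Face `diffusers` and provides a small `ipywidgets` form for generating images from a text prompt and an optional negative prompt."
]
},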
{
"cell_type": "code",
"execution_count": 1,
"id": "1f417fcf-52d5-4e49-9e26-3834eba323cf",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"# Lower bounds match the versions this notebook was last run with.\n",
"%pip install -q \"transformers>=4.33\" \"diffusers>=0.21\" \"accelerate>=0.23\" \"ipywidgets>=8\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6814d202-9131-4170-a967-6527a504eff4",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-09-27 13:15:44.716139: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [
"import torch\n",
"from diffusers import StableDiffusionPipeline\n",
"from matplotlib import pyplot as plt\n",
"import datetime"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9a48b8ed-6f07-4f27-95eb-7ec45c4a7d2f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Hugging Face Hub checkpoint that StableDiffusionPipeline loads in the next cell.\n",
"# Pipeline reference: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#texttoimage-generation\n",
"# Other checkpoints: uncomment one AND comment out the active assignment (the last one wins).\n",
"# model_id = \"dreamlike-art/dreamlike-diffusion-1.0\"\n",
"# model_id = \"prompthero/openjourney\"\n",
"model_id = \"XpucT/Deliberate\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "895b7852-6e82-4d74-bada-4e5b92ef90f4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"text_encoder/model.safetensors not found\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "369649662a124d3da6cbd06fbae927d9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading pipeline components...: 0%| | 0/6 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/transformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.\n",
" warnings.warn(\n",
"You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\n"
]
}
],
"source": [
"# Prefer CUDA, then Apple MPS, and fall back to CPU.\n",
"device = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
"# Many PyTorch CPU kernels have no float16 implementation, so half-precision\n",
"# inference fails there: load float32 weights on CPU and float16 on GPUs.\n",
"dtype = torch.float32 if device == \"cpu\" else torch.float16\n",
"\n",
"# safety_checker=None disables the pipeline's NSFW filter (diffusers logs a licence warning about it).\n",
"pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, safety_checker=None)\n",
"pipe = pipe.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2b343ab6-7c44-4a19-a4ee-e5cf24dc60ca",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "095dd8981c254c0ba98197b39b78e38b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Textarea(value='', layout=Layout(height='50px', width='auto'), placeholder='cats, sharks, ships…"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import ipywidgets as widgets\n",
"from IPython.display import clear_output, display\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import set_matplotlib_formats\n",
"%matplotlib inline\n",
"\n",
"# Seeds NumPy only; per-image seeds are drawn with torch in the generation callback.\n",
"np.random.seed(42)\n",
"\n",
"std_layout = widgets.Layout(width=\"auto\")\n",
"\n",
"# Prompt inputs\n",
"prompt_box = widgets.Textarea(value='', placeholder='cats, sharks, ships, underwater, cinematic composition', description='', layout=widgets.Layout(height=\"50px\", width=\"auto\"))\n",
"neg_prompt_box = widgets.Textarea(value='', placeholder='negative prompt', description='', layout=widgets.Layout(height=\"30px\", width=\"auto\"))\n",
"\n",
"# How many images a single click generates\n",
"img2gen_lab = widgets.Label(value=\"Images to Generate:\", layout=std_layout)\n",
"img2gen_var = widgets.IntSlider(value=1, min=1, max=10, layout=std_layout)\n",
"\n",
"# Output size in pixels. diffusers requires multiples of 8; 512x512 is the\n",
"# typical native resolution of SD 1.x-family checkpoints, so it is the default.\n",
"size_x_lab = widgets.Label(value=\"Image Width:\", layout=std_layout)\n",
"size_x_var = widgets.BoundedIntText(value=512, min=32, max=512, step=8, layout=std_layout)\n",
"size_y_lab = widgets.Label(value=\"Image Height:\", layout=std_layout)\n",
"size_y_var = widgets.BoundedIntText(value=512, min=32, max=512, step=8, layout=std_layout)\n",
"\n",
"# Denoising steps (more steps take proportionally longer)\n",
"steps_lab = widgets.Label(value=\"Steps:\", layout=std_layout)\n",
"steps_var = widgets.BoundedIntText(value=50, min=11, max=100, step=1, layout=std_layout)\n",
"\n",
"btn = widgets.Button(description=\"Generate\", tooltip=\"Click me\",\n",
"                     layout=widgets.Layout(width=\"auto\", height=\"30px\", margin=\"30px\"))\n",
"\n",
"# One container for the whole form; generated images go to a separate Output widget.\n",
"box = widgets.VBox([prompt_box, neg_prompt_box,\n",
"                    img2gen_lab, img2gen_var,\n",
"                    steps_lab, steps_var,\n",
"                    size_x_lab, size_x_var,\n",
"                    size_y_lab, size_y_var,\n",
"                    btn])\n",
"\n",
"box"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "71055e73-a072-41d7-84c6-602fa69ff46f",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb2114daea2b446ca0b07565f6ce88ba",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Display area that the Generate callback clears and then fills with seeds and images.\n",
"out = widgets.Output()\n",
"out"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d2645fdf-4bed-43f1-b3d1-f06c174d45ad",
"metadata": {},
"outputs": [],
"source": [
"def on_value_change(change):\n",
"    \"\"\"Generate images from the current form values and show them in `out`.\n",
"\n",
"    Registered with `btn.on_click`, so `change` is the clicked Button; it is not used.\n",
"    Every image gets its own random seed (unless `seed` is set below), which is\n",
"    printed so that a particular result can be regenerated.\n",
"    \"\"\"\n",
"    prompt = prompt_box.value\n",
"    negative_prompt = neg_prompt_box.value\n",
"    images_to_generate = int(img2gen_var.value)\n",
"    steps = int(steps_var.value)\n",
"    # diffusers rejects sizes not divisible by 8, and BoundedIntText only clamps\n",
"    # typed values to min/max (not to `step`), so round down to a multiple of 8.\n",
"    width = int(size_x_var.value) // 8 * 8\n",
"    height = int(size_y_var.value) // 8 * 8\n",
"    seed = None  # set to an int to reuse one fixed seed for every image\n",
"    guidance = 7.5  # Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality.\n",
"\n",
"    with out:\n",
"        clear_output(wait=True)\n",
"        for _ in range(images_to_generate):\n",
"            # Compare against None rather than using `or`, so that seed=0 is honoured.\n",
"            current_seed = seed if seed is not None else torch.randint(0, 100_000, (1,)).item()\n",
"            generator = torch.Generator().manual_seed(int(current_seed))\n",
"            img = pipe(prompt=prompt, negative_prompt=negative_prompt,\n",
"                       width=width, height=height,\n",
"                       num_inference_steps=steps, guidance_scale=guidance,\n",
"                       generator=generator).images[0]\n",
"            print(\"Current Seed:\", current_seed)\n",
"            fig, ax = plt.subplots()\n",
"            ax.imshow(img)\n",
"            ax.axis(\"off\")\n",
"            plt.show()\n",
"            # To keep a copy of each image, uncomment the next line (':' is avoided\n",
"            # in the timestamp because it is not allowed in Windows file names).\n",
"            # img.save(f\"./{datetime.datetime.now():%y.%m.%d_%H-%M-%S}_{current_seed}.jpg\")\n",
"\n",
"\n",
"# Re-running this cell creates a new function object; without detaching the one\n",
"# from a previous run, both stay registered and every click generates twice.\n",
"_previous_handler = globals().get(\"_registered_click_handler\")\n",
"if _previous_handler is not None:\n",
"    btn.on_click(_previous_handler, remove=True)\n",
"_registered_click_handler = on_value_change\n",
"btn.on_click(_registered_click_handler)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}