diff --git a/lit_nlp/examples/dalle_mini/README.md b/lit_nlp/examples/dalle_mini/README.md new file mode 100644 index 00000000..957cb1cf --- /dev/null +++ b/lit_nlp/examples/dalle_mini/README.md @@ -0,0 +1,27 @@ +Dalle_Mini Demo for the Learning Interpretability Tool +======================================================= + +This demo showcases how LIT can be used in text-to-image generation mode. It is +based on the min-dalle model +(https://pypi.org/project/min-dalle/). + +You will need a standalone virtual environment for the Python libraries, which +you can set up using the following commands from the root of the LIT repo. + +```sh +# Create the virtual environment. You may want to use python3 or python3.10 +# depending on how many Python versions you have installed and their aliases. +python -m venv .dalle-mini +source .dalle-mini/bin/activate +# This requirements.txt file will also install the core LIT library deps. +pip install -r ./lit_nlp/examples/dalle_mini/requirements.txt +# The LIT web app still needs to be built in the usual way. +(cd ./lit_nlp && yarn && yarn build) +``` + +Once your virtual environment is set up, you can launch the demo with the +following command. 
+ +```sh +python -m lit_nlp.examples.dalle_mini.demo +``` \ No newline at end of file diff --git a/lit_nlp/examples/dalle_mini/data.py b/lit_nlp/examples/dalle_mini/data.py new file mode 100644 index 00000000..e54b1ca3 --- /dev/null +++ b/lit_nlp/examples/dalle_mini/data.py @@ -0,0 +1,18 @@ +"""Data loaders for dalle-mini model.""" + +from lit_nlp.api import dataset as lit_dataset +from lit_nlp.api import types as lit_types + + +class DallePrompts(lit_dataset.Dataset): + + def __init__(self, prompts: list[str]): + self.examples = [] + for prompt in prompts: + self.examples.append({"prompt": prompt}) + + def spec(self) -> lit_types.Spec: + return {"prompt": lit_types.TextSegment()} + + def __iter__(self): + return iter(self.examples) diff --git a/lit_nlp/examples/dalle_mini/demo.py b/lit_nlp/examples/dalle_mini/demo.py new file mode 100644 index 00000000..18cbc885 --- /dev/null +++ b/lit_nlp/examples/dalle_mini/demo.py @@ -0,0 +1,98 @@ +r"""Example for dalle-mini demo model. + +To run locally with a small number of examples: + python -m lit_nlp.examples.dalle_mini.demo + + +Then navigate to localhost:5432 to access the demo UI. 
+""" + +from collections.abc import Sequence +import sys +from typing import Optional + +from absl import app +from absl import flags +from lit_nlp import app as lit_app +from lit_nlp import dev_server +from lit_nlp import server_flags +from lit_nlp.api import layout +from lit_nlp.examples.dalle_mini import data as dalle_data +from lit_nlp.examples.dalle_mini import model as dalle_model + + +# NOTE: additional flags defined in server_flags.py +_FLAGS = flags.FLAGS +_FLAGS.set_default("development_demo", True) +_FLAGS.set_default("default_layout", "DALLE_LAYOUT") + +_FLAGS.DEFINE_integer("grid_size", 4, "The grid size to use for the model.") + +_MODELS = (["dalle-mini"],) + +_CANNED_PROMPTS = ["I have a dream", "I have a shiba dog named cola"] + +# Custom frontend layout; see api/layout.py +_modules = layout.LitModuleName +_DALLE_LAYOUT = layout.LitCanonicalLayout( + upper={ + "Main": [ + _modules.DataTableModule, + _modules.DatapointEditorModule, + ] + }, + lower={ + "Predictions": [ + _modules.GeneratedImageModule, + _modules.GeneratedTextModule, + ], + }, + description="Custom layout for Text to Image models.", +) + + +CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {"DALLE_LAYOUT": _DALLE_LAYOUT} + + +def get_wsgi_app() -> Optional[dev_server.LitServerType]: + _FLAGS.set_default("server_type", "external") + _FLAGS.set_default("demo_mode", True) + # Parse flags without calling app.run(main), to avoid conflict with + # gunicorn command line flags. + unused = _FLAGS(sys.argv, known_only=True) + return main(unused) + + +def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + # Load models, according to the --models flag. 
+ models = {} + + model_loaders: lit_app.ModelLoadersMap = {} + model_loaders["dalle-mini"] = ( + dalle_model.DalleMiniModel, + dalle_model.DalleMiniModel.init_spec(), + ) + + datasets = {"examples": dalle_data.DallePrompts(_CANNED_PROMPTS)} + dataset_loaders: lit_app.DatasetLoadersMap = {} + dataset_loaders["text_to_image"] = ( + dalle_data.DallePrompts, + dalle_data.DallePrompts.init_spec(), + ) + + lit_demo = dev_server.Server( + models=models, + model_loaders=model_loaders, + datasets=datasets, + dataset_loaders=dataset_loaders, + layouts=CUSTOM_LAYOUTS, + **server_flags.get_flags(), + ) + return lit_demo.serve() + + +if __name__ == "__main__": + app.run(main) diff --git a/lit_nlp/examples/dalle_mini/model.py b/lit_nlp/examples/dalle_mini/model.py new file mode 100644 index 00000000..487072d6 --- /dev/null +++ b/lit_nlp/examples/dalle_mini/model.py @@ -0,0 +1,111 @@ +"""LIT wrappers for MiniDalleModel.""" + +from collections.abc import Iterable + +from lit_nlp.api import model as lit_model +from lit_nlp.api import types as lit_types +from lit_nlp.lib import image_utils +from min_dalle import MinDalle +import numpy as np +from PIL import Image +import torch + + +class DalleMiniModel(lit_model.Model): + """LIT model wrapper for Dalle-Mini Text-to-Image model. + + This wrapper simplifies the pipeline using Dalle-Mini for text-to-image + generation. + + + The basic flow within this model wrapper's predict() function is: + + + 1. Dalle-Mini processes the text prompt. + 2. Images are directly generated by Dalle-Mini. 
+ """ + + def __init__( + self, + device: str = "cuda", # Use "cuda" for GPU or "cpu" for CPU + grid_size: int = 4, # each batch will generate grid_size**2 images + temperature: float = 0.5, + top_k: int = 256, + supercondition_factor: int = 32, + ): + super().__init__() + self.grid_size = grid_size + self.temperature = temperature + self.top_k = top_k + self.supercondition_factor = supercondition_factor + + # Load Dalle-Mini model + self.model = MinDalle( + models_root="./pretrained", + dtype=torch.float32, + device=device, + is_mega=True, + is_reusable=True, + ) + + def max_minibatch_size(self) -> int: + return 8 + + def predict( + self, inputs: Iterable[lit_types.JsonDict], **unused_kw + ) -> Iterable[lit_types.JsonDict]: + """Generate images based on the input prompts.""" + + def tensor_to_pil_image(tensor): + img_np = tensor.detach().cpu().numpy() + img_np = np.squeeze(img_np) + if img_np.ndim == 2: + img_np = np.stack([img_np] * 3, axis=-1) + elif img_np.ndim != 3 or img_np.shape[2] != 3: + raise ValueError( + f"Unexpected image shape: {img_np.shape}. Expected (H, W, 3)." 
+ ) + + img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min()) * 255 + img_np = img_np.clip(0, 255).astype(np.uint8) + return Image.fromarray(img_np) + + prompts = [ex["prompt"] for ex in inputs] + images = [] + for prompt in prompts: + # Generate images using the model + generated_images = self.model.generate_images( + text=prompt, + seed=-1, + grid_size=self.grid_size, + is_seamless=False, + temperature=self.temperature, + top_k=self.top_k, + supercondition_factor=self.supercondition_factor, + is_verbose=False, + ) + pil_images = [] + for img_tensor in generated_images: + pil_images.append(tensor_to_pil_image(img_tensor)) + images.append({ + "image": [ + image_utils.convert_pil_to_image_str(img) for img in pil_images + ], + "prompt": prompt, + }) + + return images + + def input_spec(self): + return { + "grid_size": lit_types.Scalar(), + "temperature": lit_types.Scalar(), + "top_k": lit_types.Scalar(), + "supercondition_factor": lit_types.Scalar(), + } + + def output_spec(self): + return { + "image": lit_types.ImageBytesList(), + "prompt": lit_types.TextSegment(), + } diff --git a/lit_nlp/examples/dalle_mini/requirements.txt b/lit_nlp/examples/dalle_mini/requirements.txt new file mode 100644 index 00000000..b5199a94 --- /dev/null +++ b/lit_nlp/examples/dalle_mini/requirements.txt @@ -0,0 +1,19 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +-r ../../../requirements.txt + +# Dalle-Mini dependencies +min_dalle==0.4.11