1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import json
8import os
9from multiprocessing.connection import Client
10
11import numpy as np
12import piq
13import torch
14from diffusers import EulerDiscreteScheduler, UNet2DConditionModel
15from diffusers.models.embeddings import get_timestep_embedding
16
17from executorch.backends.qualcomm.utils.utils import (
18    ExecutorchBackendConfig,
19    from_context_binary,
20    generate_htp_compiler_spec,
21    generate_qnn_executorch_compiler_spec,
22    get_soc_to_chipset_map,
23    QcomChipset,
24)
25
26from executorch.examples.qualcomm.qaihub_scripts.stable_diffusion.stable_diffusion_lib import (
27    StableDiffusion,
28)
29from executorch.examples.qualcomm.qaihub_scripts.utils.utils import (
30    gen_pte_from_ctx_bin,
31    get_encoding,
32)
33from executorch.examples.qualcomm.utils import (
34    setup_common_args_and_variables,
35    SimpleADB,
36)
37from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
38from PIL import Image
39from torchvision.transforms import ToTensor
40
41target_names = ("text_encoder", "unet", "vae")
42
43
44def get_quant_data(
45    encoding: dict, data: torch.Tensor, input_model: str, input_index: int
46):
47    scale = encoding[f"{input_model}_input"]["scale"][input_index]
48    offset = encoding[f"{input_model}_input"]["offset"][input_index]
49    if offset < 0:
50        quant_data = data.div(scale).sub(offset).clip(min=0, max=65535).detach()
51    else:
52        quant_data = data.div(scale).add(offset).clip(min=0, max=65535).detach()
53
54    return quant_data.to(dtype=torch.uint16)
55
56
57def get_encodings(
58    path_to_shard_encoder: str,
59    path_to_shard_unet: str,
60    path_to_shard_vae: str,
61    compiler_specs,
62):
63    text_encoder_encoding = get_encoding(
64        path_to_shard=path_to_shard_encoder,
65        compiler_specs=compiler_specs,
66        get_input=False,
67        get_output=True,
68        num_input=1,
69        num_output=1,
70    )
71    unet_encoding = get_encoding(
72        path_to_shard=path_to_shard_unet,
73        compiler_specs=compiler_specs,
74        get_input=True,
75        get_output=True,
76        num_input=3,
77        num_output=1,
78    )
79    vae_encoding = get_encoding(
80        path_to_shard=path_to_shard_vae,
81        compiler_specs=compiler_specs,
82        get_input=True,
83        get_output=True,
84        num_input=1,
85        num_output=1,
86    )
87
88    return (
89        text_encoder_encoding[0],
90        unet_encoding[0],
91        unet_encoding[1],
92        vae_encoding[0],
93        vae_encoding[1],
94    )
95
96
97def get_time_embedding(timestep, time_embedding):
98    timestep = torch.tensor([timestep])
99    t_emb = get_timestep_embedding(timestep, 320, True, 0)
100    emb = time_embedding(t_emb)
101
102    return emb
103
104
105def build_args_parser():
106    parser = setup_common_args_and_variables()
107
108    parser.add_argument(
109        "-a",
110        "--artifact",
111        help="Path for storing generated artifacts by this example. Default ./stable_diffusion_qai_hub",
112        default="./stable_diffusion_qai_hub",
113        type=str,
114    )
115
116    parser.add_argument(
117        "--pte_prefix",
118        help="Prefix of pte files name. Default qaihub_stable_diffusion",
119        default="qaihub_stable_diffusion",
120        type=str,
121    )
122
123    parser.add_argument(
124        "--text_encoder_bin",
125        type=str,
126        default=None,
127        help="[For AI hub ctx binary] Path to Text Encoder.",
128        required=True,
129    )
130
131    parser.add_argument(
132        "--unet_bin",
133        type=str,
134        default=None,
135        help="[For AI hub ctx binary] Path to UNet.",
136        required=True,
137    )
138
139    parser.add_argument(
140        "--vae_bin",
141        type=str,
142        default=None,
143        help="[For AI hub ctx binary] Path to Vae Decoder.",
144        required=True,
145    )
146
147    parser.add_argument(
148        "--prompt",
149        default="a photo of an astronaut riding a horse on mars",
150        type=str,
151        help="Prompt to generate image from.",
152    )
153
154    parser.add_argument(
155        "--num_time_steps",
156        default=20,
157        type=int,
158        help="The number of diffusion time steps.",
159    )
160
161    parser.add_argument(
162        "--guidance_scale",
163        type=float,
164        default=7.5,
165        help="Strength of guidance (higher means more influence from prompt).",
166    )
167
168    parser.add_argument(
169        "--vocab_json",
170        type=str,
171        help="Path to tokenizer vocab.json file. Can get vocab.json under https://huggingface.co/openai/clip-vit-base-patch32/tree/main",
172        required=True,
173    )
174
175    parser.add_argument(
176        "--pre_gen_pte",
177        help="folder path to pre-compiled ptes",
178        default=None,
179        type=str,
180    )
181
182    parser.add_argument(
183        "--fix_latents",
184        help="Enable this option to fix the latents in the unet diffuse step.",
185        action="store_true",
186    )
187
188    return parser
189
190
191def broadcast_ut_result(output_image, seed):
192    sd = StableDiffusion(seed)
193    to_tensor = ToTensor()
194    target = sd(args.prompt, 512, 512, args.num_time_steps)
195    target = to_tensor(target).unsqueeze(0)
196    output_tensor = to_tensor(
197        Image.fromarray(np.round(output_image[0] * 255).astype(np.uint8)[0])
198    ).unsqueeze(0)
199
200    psnr_piq = piq.psnr(target, output_tensor)
201    ssim_piq = piq.ssim(target, output_tensor)
202    print(f"PSNR: {round(psnr_piq.item(), 3)}, SSIM: {round(ssim_piq.item(), 3)}")
203    if args.ip and args.port != -1:
204        with Client((args.ip, args.port)) as conn:
205            conn.send(json.dumps({"PSNR": psnr_piq.item(), "SSIM": ssim_piq.item()}))
206
207
208def save_result(output_image):
209    img = Image.fromarray(np.round(output_image[0] * 255).astype(np.uint8)[0])
210    save_path = f"{args.artifact}/outputs/output_image.jpg"
211    img.save(save_path)
212    print(f"Output image saved at {save_path}")
213
214
215def inference(args, compiler_specs, pte_files):
216    # Loading a pretrained EulerDiscreteScheduler from the https://huggingface.co/stabilityai/stable-diffusion-2-1-base.
217    scheduler = EulerDiscreteScheduler.from_pretrained(
218        "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler", revision="main"
219    )
220
221    #  Loading a pretrained UNet2DConditionModel (which includes the time embedding) from the https://huggingface.co/stabilityai/stable-diffusion-2-1-base.
222    time_embedding = UNet2DConditionModel.from_pretrained(
223        "stabilityai/stable-diffusion-2-1-base", subfolder="unet", revision="main"
224    ).time_embedding
225
226    scheduler.set_timesteps(args.num_time_steps)
227    scheduler.config.prediction_type = "epsilon"
228    # Get encoding of unet and vae
229    (
230        encoder_output,
231        unet_input,
232        unet_output,
233        vae_input,
234        vae_output,
235    ) = get_encodings(
236        args.text_encoder_bin,
237        args.unet_bin,
238        args.vae_bin,
239        compiler_specs,
240    )
241    encoding = {
242        "encoder_output": encoder_output,
243        "unet_input": unet_input,
244        "unet_output": unet_output,
245        "vae_input": vae_input,
246        "vae_output": vae_output,
247    }
248
249    adb = SimpleADB(
250        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
251        build_path=args.build_folder,
252        pte_path=pte_files,
253        workspace=f"/data/local/tmp/executorch/{args.pte_prefix}",
254        device_id=args.device,
255        host_id=args.host,
256        soc_model=args.model,
257        runner="examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner",
258    )
259
260    input_unet = ()
261    input_list_unet = ""
262
263    for i, t in enumerate(scheduler.timesteps):
264        time_emb = get_quant_data(
265            encoding, get_time_embedding(t, time_embedding), "unet", 1
266        )
267        input_list_unet += f"input_{i}_0.raw\n"
268        input_unet = input_unet + (time_emb,)
269
270    qnn_executor_runner_args = [
271        f"--text_encoder_path {adb.workspace}/{args.pte_prefix}_text_encoder.pte",
272        f"--unet_path {adb.workspace}/{args.pte_prefix}_unet.pte",
273        f"--vae_path {adb.workspace}/{args.pte_prefix}_vae.pte",
274        f"--input_list_path {adb.workspace}/input_list.txt",
275        f"--output_folder_path {adb.output_folder}",
276        f'--prompt "{args.prompt}"',
277        f"--guidance_scale {args.guidance_scale}",
278        f"--num_time_steps {args.num_time_steps}",
279        f"--vocab_json {adb.workspace}/vocab.json",
280    ]
281    if args.fix_latents:
282        qnn_executor_runner_args.append("--fix_latents")
283
284    text_encoder_output_scale = encoding["encoder_output"]["scale"][0]
285    text_encoder_output_offset = encoding["encoder_output"]["offset"][0]
286    unet_input_latent_scale = encoding["unet_input"]["scale"][0]
287    unet_input_latent_offset = encoding["unet_input"]["offset"][0]
288    unet_input_text_emb_scale = encoding["unet_input"]["scale"][2]
289    unet_input_text_emb_offset = encoding["unet_input"]["offset"][2]
290    unet_output_scale = encoding["unet_output"]["scale"][0]
291    unet_output_offset = encoding["unet_output"]["offset"][0]
292    vae_input_scale = encoding["vae_input"]["scale"][0]
293    vae_input_offset = encoding["vae_input"]["offset"][0]
294    vae_output_scale = encoding["vae_output"]["scale"][0]
295    vae_output_offset = encoding["vae_output"]["offset"][0]
296
297    qnn_executor_runner_args = qnn_executor_runner_args + [
298        f"--text_encoder_output_scale {text_encoder_output_scale}",
299        f"--text_encoder_output_offset {text_encoder_output_offset}",
300        f"--unet_input_latent_scale {unet_input_latent_scale}",
301        f"--unet_input_latent_offset {unet_input_latent_offset}",
302        f"--unet_input_text_emb_scale {unet_input_text_emb_scale}",
303        f"--unet_input_text_emb_offset {unet_input_text_emb_offset}",
304        f"--unet_output_scale {unet_output_scale}",
305        f"--unet_output_offset {unet_output_offset}",
306        f"--vae_input_scale {vae_input_scale}",
307        f"--vae_input_offset {vae_input_offset}",
308        f"--vae_output_scale {vae_output_scale}",
309        f"--vae_output_offset {vae_output_offset}",
310    ]
311
312    qnn_executor_runner_args = " ".join(
313        [
314            f"cd {adb.workspace} &&",
315            f"./qaihub_stable_diffusion_runner {' '.join(qnn_executor_runner_args)}",
316        ]
317    )
318
319    files = [args.vocab_json]
320
321    if args.fix_latents:
322        seed = 42
323        latents = torch.randn((1, 4, 64, 64), generator=torch.manual_seed(seed)).to(
324            "cpu"
325        )
326        # We need to explicitly permute after init tensor or else the random value will be different
327        latents = latents.permute(0, 2, 3, 1).contiguous()
328        latents = latents * scheduler.init_noise_sigma
329        flattened_tensor = latents.view(-1)
330        # Save the flattened tensor to a .raw file
331        with open(os.path.join(args.artifact, "latents.raw"), "wb") as file:
332            file.write(flattened_tensor.numpy().tobytes())
333        files.append(os.path.join(args.artifact, "latents.raw"))
334
335    if not args.skip_push:
336        adb.push(inputs=input_unet, input_list=input_list_unet, files=files)
337    adb.execute(custom_runner_cmd=qnn_executor_runner_args)
338
339    output_image = []
340
341    def post_process_vae():
342        with open(f"{args.artifact}/outputs/output_0_0.raw", "rb") as f:
343            output_image.append(
344                np.fromfile(f, dtype=np.float32).reshape(1, 512, 512, 3)
345            )
346
347    adb.pull(output_path=args.artifact, callback=post_process_vae)
348
349    if args.fix_latents:
350        broadcast_ut_result(output_image, seed)
351    else:
352        save_result(output_image)
353
354
355def main(args):
356    os.makedirs(args.artifact, exist_ok=True)
357    # common part for compile & inference
358    backend_options = generate_htp_compiler_spec(
359        use_fp16=False,
360        use_multi_contexts=True,
361    )
362    compiler_specs = generate_qnn_executorch_compiler_spec(
363        soc_model=getattr(QcomChipset, args.model),
364        backend_options=backend_options,
365        is_from_context_binary=True,
366    )
367
368    if args.pre_gen_pte is None:
369        # Create custom operators as context loader
370        soc_model = get_soc_to_chipset_map()[args.model]
371        bundle_programs = [
372            from_context_binary(args.text_encoder_bin, "ctx_loader_0", soc_model),
373            from_context_binary(args.unet_bin, "ctx_loader_1", soc_model),
374            from_context_binary(args.vae_bin, "ctx_loader_2", soc_model),
375        ]
376        pte_names = [f"{args.pte_prefix}_{target_name}" for target_name in target_names]
377        memory_planning_pass = MemoryPlanningPass(
378            alloc_graph_input=False,
379            alloc_graph_output=False,
380        )
381        pte_files = gen_pte_from_ctx_bin(
382            artifact=args.artifact,
383            pte_names=pte_names,
384            bundle_programs=bundle_programs,
385            backend_config=ExecutorchBackendConfig(
386                memory_planning_pass=memory_planning_pass
387            ),
388        )
389        assert (
390            len(pte_files) == 3
391        ), f"Error: Expected 3 PTE files, but got {len(pte_files)} files."
392
393    else:
394        pte_files = [
395            f"{args.pre_gen_pte}/{args.pte_prefix}_{target_name}.pte"
396            for target_name in target_names
397        ]
398    if args.compile_only:
399        return
400
401    inference(args, compiler_specs, pte_files)
402
403
404if __name__ == "__main__":  # noqa: C901
405    parser = build_args_parser()
406    args = parser.parse_args()
407
408    try:
409        main(args)
410    except Exception as e:
411        if args.ip and args.port != -1:
412            with Client((args.ip, args.port)) as conn:
413                conn.send(json.dumps({"Error": str(e)}))
414        else:
415            raise Exception(e)
416