# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from multiprocessing.connection import Client

import numpy as np
import piq
import torch
from diffusers import EulerDiscreteScheduler, UNet2DConditionModel
from diffusers.models.embeddings import get_timestep_embedding

from executorch.backends.qualcomm.utils.utils import (
    ExecutorchBackendConfig,
    from_context_binary,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    get_soc_to_chipset_map,
    QcomChipset,
)

from executorch.examples.qualcomm.qaihub_scripts.stable_diffusion.stable_diffusion_lib import (
    StableDiffusion,
)
from executorch.examples.qualcomm.qaihub_scripts.utils.utils import (
    gen_pte_from_ctx_bin,
    get_encoding,
)
from executorch.examples.qualcomm.utils import (
    setup_common_args_and_variables,
    SimpleADB,
)
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from PIL import Image
from torchvision.transforms import ToTensor

target_names = ("text_encoder", "unet", "vae")


def get_quant_data(
    encoding: dict, data: torch.Tensor, input_model: str, input_index: int
):
    # Quantize float data to uint16 using the scale/offset recorded for this input.
    scale = encoding[f"{input_model}_input"]["scale"][input_index]
    offset = encoding[f"{input_model}_input"]["offset"][input_index]
    if offset < 0:
        quant_data = data.div(scale).sub(offset).clip(min=0, max=65535).detach()
    else:
        quant_data = data.div(scale).add(offset).clip(min=0, max=65535).detach()

    return quant_data.to(dtype=torch.uint16)


def get_encodings(
    path_to_shard_encoder: str,
    path_to_shard_unet: str,
    path_to_shard_vae: str,
    compiler_specs,
):
    text_encoder_encoding = get_encoding(
        path_to_shard=path_to_shard_encoder,
        compiler_specs=compiler_specs,
        get_input=False,
        get_output=True,
        num_input=1,
        num_output=1,
    )
    unet_encoding = get_encoding(
        path_to_shard=path_to_shard_unet,
        compiler_specs=compiler_specs,
        get_input=True,
        get_output=True,
        num_input=3,
        num_output=1,
    )
    vae_encoding = get_encoding(
        path_to_shard=path_to_shard_vae,
        compiler_specs=compiler_specs,
        get_input=True,
        get_output=True,
        num_input=1,
        num_output=1,
    )

    return (
        text_encoder_encoding[0],
        unet_encoding[0],
        unet_encoding[1],
        vae_encoding[0],
        vae_encoding[1],
    )


def get_time_embedding(timestep, time_embedding):
    timestep = torch.tensor([timestep])
    t_emb = get_timestep_embedding(timestep, 320, True, 0)
    emb = time_embedding(t_emb)

    return emb


def build_args_parser():
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="Path for storing artifacts generated by this example. Default ./stable_diffusion_qai_hub",
        default="./stable_diffusion_qai_hub",
        type=str,
    )

    parser.add_argument(
        "--pte_prefix",
        help="Prefix of the pte file names. Default qaihub_stable_diffusion",
        default="qaihub_stable_diffusion",
        type=str,
    )

    parser.add_argument(
        "--text_encoder_bin",
        type=str,
        default=None,
        help="[For AI hub ctx binary] Path to Text Encoder.",
        required=True,
    )

    parser.add_argument(
        "--unet_bin",
        type=str,
        default=None,
        help="[For AI hub ctx binary] Path to UNet.",
        required=True,
    )

    parser.add_argument(
        "--vae_bin",
        type=str,
        default=None,
        help="[For AI hub ctx binary] Path to VAE Decoder.",
        required=True,
    )

    parser.add_argument(
        "--prompt",
        default="a photo of an astronaut riding a horse on mars",
        type=str,
        help="Prompt to generate the image from.",
    )

    parser.add_argument(
        "--num_time_steps",
        default=20,
        type=int,
        help="The number of diffusion time steps.",
    )

    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=7.5,
        help="Strength of guidance (higher means more influence from the prompt).",
    )

    parser.add_argument(
        "--vocab_json",
        type=str,
        help="Path to the tokenizer vocab.json file. It can be downloaded from https://huggingface.co/openai/clip-vit-base-patch32/tree/main",
        required=True,
    )

    parser.add_argument(
        "--pre_gen_pte",
        help="Folder path to pre-compiled pte files.",
        default=None,
        type=str,
    )

    parser.add_argument(
        "--fix_latents",
        help="Enable this option to fix the latents in the unet diffusion steps.",
        action="store_true",
    )

    return parser


def broadcast_ut_result(output_image, seed):
    sd = StableDiffusion(seed)
    to_tensor = ToTensor()
    target = sd(args.prompt, 512, 512, args.num_time_steps)
    target = to_tensor(target).unsqueeze(0)
    output_tensor = to_tensor(
        Image.fromarray(np.round(output_image[0] * 255).astype(np.uint8)[0])
    ).unsqueeze(0)

    psnr_piq = piq.psnr(target, output_tensor)
    ssim_piq = piq.ssim(target, output_tensor)
    print(f"PSNR: {round(psnr_piq.item(), 3)}, SSIM: {round(ssim_piq.item(), 3)}")
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({"PSNR": psnr_piq.item(), "SSIM": ssim_piq.item()}))


def save_result(output_image):
    img = Image.fromarray(np.round(output_image[0] * 255).astype(np.uint8)[0])
    save_path = f"{args.artifact}/outputs/output_image.jpg"
    img.save(save_path)
    print(f"Output image saved at {save_path}")


def inference(args, compiler_specs, pte_files):
    # Load the pretrained EulerDiscreteScheduler from https://huggingface.co/stabilityai/stable-diffusion-2-1-base.
    scheduler = EulerDiscreteScheduler.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler", revision="main"
    )

    # Load the pretrained UNet2DConditionModel (which includes the time embedding) from https://huggingface.co/stabilityai/stable-diffusion-2-1-base.
    time_embedding = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base", subfolder="unet", revision="main"
    ).time_embedding

    scheduler.set_timesteps(args.num_time_steps)
    scheduler.config.prediction_type = "epsilon"
    # Get the quantization encodings of the text encoder, unet, and vae
    (
        encoder_output,
        unet_input,
        unet_output,
        vae_input,
        vae_output,
    ) = get_encodings(
        args.text_encoder_bin,
        args.unet_bin,
        args.vae_bin,
        compiler_specs,
    )
    encoding = {
        "encoder_output": encoder_output,
        "unet_input": unet_input,
        "unet_output": unet_output,
        "vae_input": vae_input,
        "vae_output": vae_output,
    }

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=args.build_folder,
        pte_path=pte_files,
        workspace=f"/data/local/tmp/executorch/{args.pte_prefix}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        runner="examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner",
    )

    input_unet = ()
    input_list_unet = ""

    # Pre-compute the quantized time embedding for every diffusion time step.
    for i, t in enumerate(scheduler.timesteps):
        time_emb = get_quant_data(
            encoding, get_time_embedding(t, time_embedding), "unet", 1
        )
        input_list_unet += f"input_{i}_0.raw\n"
        input_unet = input_unet + (time_emb,)

    qnn_executor_runner_args = [
        f"--text_encoder_path {adb.workspace}/{args.pte_prefix}_text_encoder.pte",
        f"--unet_path {adb.workspace}/{args.pte_prefix}_unet.pte",
        f"--vae_path {adb.workspace}/{args.pte_prefix}_vae.pte",
        f"--input_list_path {adb.workspace}/input_list.txt",
        f"--output_folder_path {adb.output_folder}",
        f'--prompt "{args.prompt}"',
        f"--guidance_scale {args.guidance_scale}",
        f"--num_time_steps {args.num_time_steps}",
        f"--vocab_json {adb.workspace}/vocab.json",
    ]
    if args.fix_latents:
        qnn_executor_runner_args.append("--fix_latents")

    text_encoder_output_scale = encoding["encoder_output"]["scale"][0]
    text_encoder_output_offset = encoding["encoder_output"]["offset"][0]
    unet_input_latent_scale = encoding["unet_input"]["scale"][0]
    unet_input_latent_offset = encoding["unet_input"]["offset"][0]
    unet_input_text_emb_scale = encoding["unet_input"]["scale"][2]
    unet_input_text_emb_offset = encoding["unet_input"]["offset"][2]
    unet_output_scale = encoding["unet_output"]["scale"][0]
    unet_output_offset = encoding["unet_output"]["offset"][0]
    vae_input_scale = encoding["vae_input"]["scale"][0]
    vae_input_offset = encoding["vae_input"]["offset"][0]
    vae_output_scale = encoding["vae_output"]["scale"][0]
    vae_output_offset = encoding["vae_output"]["offset"][0]

    qnn_executor_runner_args = qnn_executor_runner_args + [
        f"--text_encoder_output_scale {text_encoder_output_scale}",
        f"--text_encoder_output_offset {text_encoder_output_offset}",
        f"--unet_input_latent_scale {unet_input_latent_scale}",
        f"--unet_input_latent_offset {unet_input_latent_offset}",
        f"--unet_input_text_emb_scale {unet_input_text_emb_scale}",
        f"--unet_input_text_emb_offset {unet_input_text_emb_offset}",
        f"--unet_output_scale {unet_output_scale}",
        f"--unet_output_offset {unet_output_offset}",
        f"--vae_input_scale {vae_input_scale}",
        f"--vae_input_offset {vae_input_offset}",
        f"--vae_output_scale {vae_output_scale}",
        f"--vae_output_offset {vae_output_offset}",
    ]

    qnn_executor_runner_args = " ".join(
        [
            f"cd {adb.workspace} &&",
f"./qaihub_stable_diffusion_runner {' '.join(qnn_executor_runner_args)}", 316 ] 317 ) 318 319 files = [args.vocab_json] 320 321 if args.fix_latents: 322 seed = 42 323 latents = torch.randn((1, 4, 64, 64), generator=torch.manual_seed(seed)).to( 324 "cpu" 325 ) 326 # We need to explicitly permute after init tensor or else the random value will be different 327 latents = latents.permute(0, 2, 3, 1).contiguous() 328 latents = latents * scheduler.init_noise_sigma 329 flattened_tensor = latents.view(-1) 330 # Save the flattened tensor to a .raw file 331 with open(os.path.join(args.artifact, "latents.raw"), "wb") as file: 332 file.write(flattened_tensor.numpy().tobytes()) 333 files.append(os.path.join(args.artifact, "latents.raw")) 334 335 if not args.skip_push: 336 adb.push(inputs=input_unet, input_list=input_list_unet, files=files) 337 adb.execute(custom_runner_cmd=qnn_executor_runner_args) 338 339 output_image = [] 340 341 def post_process_vae(): 342 with open(f"{args.artifact}/outputs/output_0_0.raw", "rb") as f: 343 output_image.append( 344 np.fromfile(f, dtype=np.float32).reshape(1, 512, 512, 3) 345 ) 346 347 adb.pull(output_path=args.artifact, callback=post_process_vae) 348 349 if args.fix_latents: 350 broadcast_ut_result(output_image, seed) 351 else: 352 save_result(output_image) 353 354 355def main(args): 356 os.makedirs(args.artifact, exist_ok=True) 357 # common part for compile & inference 358 backend_options = generate_htp_compiler_spec( 359 use_fp16=False, 360 use_multi_contexts=True, 361 ) 362 compiler_specs = generate_qnn_executorch_compiler_spec( 363 soc_model=getattr(QcomChipset, args.model), 364 backend_options=backend_options, 365 is_from_context_binary=True, 366 ) 367 368 if args.pre_gen_pte is None: 369 # Create custom operators as context loader 370 soc_model = get_soc_to_chipset_map()[args.model] 371 bundle_programs = [ 372 from_context_binary(args.text_encoder_bin, "ctx_loader_0", soc_model), 373 from_context_binary(args.unet_bin, "ctx_loader_1", soc_model), 374 from_context_binary(args.vae_bin, "ctx_loader_2", soc_model), 375 ] 376 pte_names = [f"{args.pte_prefix}_{target_name}" for target_name in target_names] 377 memory_planning_pass = MemoryPlanningPass( 378 alloc_graph_input=False, 379 alloc_graph_output=False, 380 ) 381 pte_files = gen_pte_from_ctx_bin( 382 artifact=args.artifact, 383 pte_names=pte_names, 384 bundle_programs=bundle_programs, 385 backend_config=ExecutorchBackendConfig( 386 memory_planning_pass=memory_planning_pass 387 ), 388 ) 389 assert ( 390 len(pte_files) == 3 391 ), f"Error: Expected 3 PTE files, but got {len(pte_files)} files." 392 393 else: 394 pte_files = [ 395 f"{args.pre_gen_pte}/{args.pte_prefix}_{target_name}.pte" 396 for target_name in target_names 397 ] 398 if args.compile_only: 399 return 400 401 inference(args, compiler_specs, pte_files) 402 403 404if __name__ == "__main__": # noqa: C901 405 parser = build_args_parser() 406 args = parser.parse_args() 407 408 try: 409 main(args) 410 except Exception as e: 411 if args.ip and args.port != -1: 412 with Client((args.ip, args.port)) as conn: 413 conn.send(json.dumps({"Error": str(e)})) 414 else: 415 raise Exception(e) 416