"""Benchmark semi-structured (2:4) sparse matmuls against dense baselines.

Compares torch.mm (or a full nn.Linear forward, with -e2e) on dense tensors
vs. their semi-structured sparse counterparts across the GEMM shapes from
NVIDIA's sparsity benchmarks, and reports latency and speedup as a pandas
DataFrame.
"""
import argparse
import random

import pandas as pd
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured


torch.set_printoptions(
    precision=2,
    threshold=None,
    edgeitems=16,
    linewidth=480,
    profile=None,
    sci_mode=False,
)


# helper model definition for the e2e (-e2e) linear benchmark
class Model(nn.Module):
    def __init__(self, m, k):
        super().__init__()
        # nn.Linear stores its weight as (out_features, in_features) = (m, k),
        # so the constructor arguments are reversed relative to the weight shape.
        self.linear = nn.Linear(k, m)

    def forward(self, x):
        return self.linear(x)


def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    Return an (r, c) mask that is 1:2 sparse: each contiguous pair of elements
    along a row holds exactly one nonzero, e.g. [1, 0, 0, 1, ...]. A 1:2
    sparse matrix is also 2:4 and 4:8 sparse.
    """
    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for _ in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )


def test_linear(m, k, n, dtype, contiguous, backend):
    # `contiguous` is accepted for CLI symmetry but unused; contiguity of the
    # sparse output is reported in the results instead.
    SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
    mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
    # A nonzero input keeps the dense/sparse allclose comparison meaningful.
    input_tensor = torch.rand(n, k).to(dtype).cuda()
    model = Model(m, k).to(dtype).cuda().eval()
    # Give the dense model the same masked weight, so the dense and sparse
    # outputs are computed from identical weights and can be compared.
    model.linear.weight = nn.Parameter(sparse_weight)

    dense_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    dense_output = model(input_tensor)
    print(dense_output.shape)

    # sparsify weights
    model.linear.weight = nn.Parameter(
        to_sparse_semi_structured(
            sparse_weight,
        )
    )

    sparse_output = model(input_tensor)
    print(sparse_output.shape)

    sparse_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "linear",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }


def test_tensor(m, k, n, dtype, contiguous, backend):
    SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
    A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    # A nonzero dense operand keeps the correctness check meaningful;
    # torch.rand values truncate to zero under int8, so draw integers instead.
    if dtype is torch.int8:
        B = torch.randint(-8, 8, (k, n), dtype=dtype, device="cuda")
    else:
        B = torch.rand(k, n).to(dtype).cuda()

    sA = to_sparse_semi_structured(A)

    # torch.mm calculation
    if dtype is not torch.int8:
        dense_output = torch.mm(A, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(A, B)",
            globals=locals(),
        ).blocked_autorange()

    else:
        # Dense int8 matmul is not supported on CUDA, so the sparse kernel
        # doubles as its own baseline (speedup will be ~1.0).
        print("int8 baseline not supported")
        dense_output = torch.mm(sA, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(sA, B)",
            globals=locals(),
        ).blocked_autorange()

    sparse_output = torch.mm(sA, B)
    sparse_measurement = benchmark.Timer(
        stmt="torch.mm(sA, B)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "tensor",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }


if __name__ == "__main__":
    dtype_lookup = {
        "int8": torch.int8,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }

    parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=[
            "nvidia-bert",
            "nvidia-fixed-k",
            "nvidia-fixed-mn",
        ],
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=dtype_lookup.keys(),
        default="fp16",
    )
    parser.add_argument(
        "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
    )
    parser.add_argument("-contiguous", action="store_true")
    parser.add_argument("-e2e", action="store_true")
    parser.add_argument("-save", action="store_true")
    args = parser.parse_args()

    # -e2e benchmarks a full nn.Linear forward; the default benchmarks torch.mm.
    eval_fn = test_linear if args.e2e else test_tensor

    print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
    dtype = dtype_lookup[args.dtype]

    if args.mode == "nvidia-bert":
        # (m, k, n) GEMM shapes from the BERT layers in NVIDIA's benchmarks.
        bert_shapes = [
            (3072, 1024, 16384),
            (4096, 1024, 16384),
            (1024, 1024, 16384),
            (1024, 4096, 16384),
        ]
        results = (
            eval_fn(m, k, n, dtype, args.contiguous, args.backend)
            for (m, k, n) in tqdm(bert_shapes)
        )

    elif args.mode == "nvidia-fixed-k":
        # Sweep m = n with k fixed at 10240.
        mn_vals = [
            3072, 4096, 5120, 6144, 7168, 8192, 9216, 10240, 11264,
            12288, 13312, 14336, 15360, 16384, 17408, 18432, 19456, 20480,
        ]
        results = (
            eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
            for mn in tqdm(mn_vals)
        )

    elif args.mode == "nvidia-fixed-mn":
        # Sweep k with m = n fixed at 10240.
        k_vals = [
            2560, 3840, 5120, 6400, 7680, 8960, 10240, 11520,
            12800, 14080, 15360, 16640, 17920, 19200, 20480,
        ]
        results = (
            eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend)
            for k in tqdm(k_vals)
        )

    df = pd.DataFrame.from_records(results)
    if args.save:
        save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
        df.to_csv(save_file)
        print(f"Finished benchmark: {args.mode} saved results to {save_file}")
    print(df)
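
# Example invocations (a sketch, not part of the benchmark itself; the script
# filename below is an assumption). Running requires a CUDA GPU with 2:4
# structured-sparsity support (e.g. A100) and a PyTorch build where the chosen
# cuSPARSELt or CUTLASS backend is available:
#
#   python benchmark_semi_structured_sparsity.py --mode nvidia-bert --dtype fp16 --backend cusparselt -save
#   python benchmark_semi_structured_sparsity.py --mode nvidia-fixed-k --dtype bf16 --backend cutlass -e2e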