xref: /aosp_15_r20/external/pytorch/benchmarks/sparse/benchmark_semi_structured_sparsity.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import random
3
4import pandas as pd
5from tqdm import tqdm
6
7import torch
8import torch.utils.benchmark as benchmark
9from torch import nn
10from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
11
12
# Global tensor-printing configuration for this script: two decimal places,
# no scientific notation, 16 edge items per dimension, and very wide lines.
_PRINT_OPTIONS = {
    "sci_mode": False,
    "precision": 2,
    "edgeitems": 16,
    "linewidth": 480,
    "threshold": None,
    "profile": None,
}
torch.set_printoptions(**_PRINT_OPTIONS)
21
22
23# helper model definition for pruner
24class Model(nn.Module):
25    def __init__(self, m, k, dtype=None):
26        super().__init__()
27        # transposed so reversed
28        self.linear = nn.Linear(k, m)
29
30    def forward(self, x):
31        return self.linear(x)
32
33
def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """Return a random 1:2 sparse 0/1 mask of shape ``(r, c)``.

    Consecutive element pairs (in row-major order) each hold exactly one ``1``
    and one ``0``, so the mask is also 2:4 and 4:8 sparse.

    Args:
        r: Number of rows.
        c: Number of columns. ``r * c`` must be even, otherwise the final
            reshape fails; ``c`` should be even for pairs to stay within rows.
        dtype: Element dtype of the returned mask.
        device: Device on which the mask is created.
        choice: Optional fixed pair (e.g. ``[1, 0]``) repeated for every pair
            instead of sampling. Any falsy value means "sample randomly".

    Returns:
        A contiguous ``(r, c)`` tensor of zeros and ones.
    """
    n_pairs = r * c // 2

    if choice:
        pairs = torch.tensor([choice], dtype=dtype, device=device).expand(n_pairs, -1)
    else:
        # Sample the first element of every pair in a single vectorized call
        # instead of building an r*c/2-element Python list (prohibitively slow
        # and memory-hungry at benchmark sizes). The second element is its
        # complement, so every pair is either [0, 1] or [1, 0].
        first = torch.randint(0, 2, (n_pairs, 1), device=device)
        pairs = torch.cat((first, 1 - first), dim=1).to(dtype)

    return pairs.reshape(r, c).contiguous()
50
51
def test_linear(m, k, n, dtype, contiguous, backend):
    """Benchmark an ``nn.Linear`` forward pass, dense vs. semi-structured sparse.

    The dense and sparse runs use the same pruned weight, so the ``correct``
    field compares outputs computed from identical weights.

    Args:
        m: Output features (weight rows).
        k: Input features (weight columns).
        n: Number of input rows; the input has shape ``(n, k)``.
        dtype: Dtype of the weight and input.
        contiguous: Accepted for CLI compatibility; not used here.
        backend: ``"cutlass"`` sets ``SparseSemiStructuredTensor._FORCE_CUTLASS``
            to True; any other value sets it to False.

    Returns:
        A dict with the problem shape, dtype, backend, median sparse and dense
        latencies in ms, the dense/sparse speedup, an ``allclose`` ``correct``
        flag, and whether the sparse output is contiguous.
    """
    # Set before any sparse conversion so the backend choice is in effect.
    SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"

    mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
    # NOTE(review): the input is all zeros, so both outputs reduce to the bias
    # and ``correct`` is only a weak check.
    input_tensor = torch.zeros(n, k).to(dtype).cuda()
    model = Model(m, k).to(dtype).cuda().eval()

    # Give the dense baseline the same pruned weight the sparse run uses, so
    # dense and sparse outputs come from identical weights.
    with torch.no_grad():
        model.linear.weight.copy_(sparse_weight)

    dense_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    dense_output = model(input_tensor)
    print(dense_output.shape)

    # sparsify weights
    model.linear.weight = nn.Parameter(
        to_sparse_semi_structured(
            sparse_weight,
        )
    )

    sparse_output = model(input_tensor)
    print(sparse_output.shape)

    sparse_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "linear",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }
97
98
def test_tensor(m, k, n, dtype, contiguous, backend):
    """Benchmark ``torch.mm`` with a dense vs. semi-structured sparse LHS.

    Args:
        m: Rows of the sparse LHS ``A``.
        k: Shared inner dimension.
        n: Columns of the dense RHS ``B``.
        dtype: Dtype of both operands.
        contiguous: Accepted for CLI compatibility; not used here.
        backend: ``"cutlass"`` sets ``SparseSemiStructuredTensor._FORCE_CUTLASS``
            to True; any other value sets it to False. This mirrors
            ``test_linear``, so the reported ``backend`` matches what ran.

    Returns:
        A dict with the problem shape, dtype, backend, median sparse and dense
        latencies in ms, the dense/sparse speedup, an ``allclose`` ``correct``
        flag, and whether the sparse output is contiguous.
    """
    # Set before converting A so the requested backend is actually used.
    SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"

    A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    # NOTE(review): B is all zeros, so both products are zero and ``correct``
    # is only a weak check.
    B = torch.zeros(k, n).to(dtype).cuda()

    sA = to_sparse_semi_structured(A)

    # torch.mm calculation
    if dtype is not torch.int8:
        dense_output = torch.mm(A, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(A, B)",
            globals=locals(),
        ).blocked_autorange()

    else:
        # No dense int8 baseline: the "dense" numbers time the sparse mm too,
        # so the speedup compares the sparse kernel against itself.
        print("int8 baseline not supported")
        dense_output = torch.mm(sA, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(sA, B)",
            globals=locals(),
        ).blocked_autorange()

    sparse_output = torch.mm(sA, B)
    sparse_measurement = benchmark.Timer(
        stmt="torch.mm(sA, B)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "tensor",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }
145
146
if __name__ == "__main__":
    dtype_lookup = {
        "int8": torch.int8,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }

    # (m, k, n) problem sizes swept by each --mode.
    shapes_by_mode = {
        "nvidia-bert": [
            (3072, 1024, 16384),
            (4096, 1024, 16384),
            (1024, 1024, 16384),
            (1024, 4096, 16384),
        ],
        # k fixed at 10240; m == n from 3072 to 20480 in steps of 1024.
        "nvidia-fixed-k": [(mn, 10240, mn) for mn in range(3072, 20480 + 1, 1024)],
        # m == n fixed at 10240; k from 2560 to 20480 in steps of 1280.
        "nvidia-fixed-mn": [(10240, k, 10240) for k in range(2560, 20480 + 1, 1280)],
    }

    parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
    # Required: without a mode there is nothing to run, and previously the
    # script crashed later with a NameError instead of a usage message.
    parser.add_argument(
        "--mode",
        type=str,
        choices=list(shapes_by_mode),
        required=True,
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=list(dtype_lookup),
        default="fp16",
    )
    parser.add_argument(
        "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
    )
    parser.add_argument("-contiguous", action="store_true")
    parser.add_argument("-e2e", action="store_true")
    parser.add_argument("-save", action="store_true")
    args = parser.parse_args()

    # -e2e benchmarks an nn.Linear module; otherwise a bare torch.mm.
    eval_fn = test_linear if args.e2e else test_tensor

    print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
    dtype = dtype_lookup[args.dtype]

    results = [
        eval_fn(m, k, n, dtype, args.contiguous, args.backend)
        for (m, k, n) in tqdm(shapes_by_mode[args.mode])
    ]

    df = pd.DataFrame.from_records(results)
    if args.save:
        save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
        df.to_csv(save_file)
        print(f"Finished benchmark: {args.mode} saved results to {save_file}")
    print(df)
254