from torch.utils.benchmark import Timer


def time_with_torch_timer(fn, args, kwargs=None, iters=100):
    """Benchmark ``fn(*args, **kwargs)`` end to end with ``torch.utils.benchmark.Timer``.

    Args:
        fn: Callable to benchmark.
        args: Positional arguments, unpacked with ``*`` into every call.
        kwargs: Keyword arguments, unpacked with ``**`` into every call.
            Any falsy value, including ``None``, is replaced by an empty dict.
        iters: Number of timed executions. It is forwarded as ``number`` to
            ``Timer.timeit``.

    Returns:
        The ``Measurement`` object returned by ``Timer.timeit``.
    """
    if not kwargs:
        kwargs = {}

    # The Timer evaluates the statement string itself. The callable and its
    # arguments therefore reach it through the ``globals`` namespace.
    namespace = dict(args=args, kwargs=kwargs, fn=fn)
    end_to_end = Timer(stmt="fn(*args, **kwargs)", globals=namespace)
    return end_to_end.timeit(iters)