xref: /aosp_15_r20/external/clpeak/src/compute_hp.cpp (revision 1cd03ba3888297bc945f2c84574e105e3ced3e34)
1  #include <clpeak.h>
2  
runComputeHP(cl::CommandQueue & queue,cl::Program & prog,device_info_t & devInfo)3  int clPeak::runComputeHP(cl::CommandQueue &queue, cl::Program &prog, device_info_t &devInfo)
4  {
5    float timed, gflops;
6    cl_uint workPerWI;
7    cl::NDRange globalSize, localSize;
8    cl_float A = 1.3f;
9    uint iters = devInfo.computeIters;
10  
11    if (!isComputeHP)
12      return 0;
13  
14    if (!devInfo.halfSupported)
15    {
16      log->print(NEWLINE TAB TAB "No half precision support! Skipped" NEWLINE);
17      return 0;
18    }
19  
20    try
21    {
22      log->print(NEWLINE TAB TAB "Half-precision compute (GFLOPS)" NEWLINE);
23      log->xmlOpenTag("half_precision_compute");
24      log->xmlAppendAttribs("unit", "gflops");
25  
26      cl::Context ctx = queue.getInfo<CL_QUEUE_CONTEXT>();
27  
28      uint64_t globalWIs = (devInfo.numCUs) * (devInfo.computeWgsPerCU) * (devInfo.maxWGSize);
29      uint64_t t = std::min((globalWIs * sizeof(cl_half)), devInfo.maxAllocSize) / sizeof(cl_half);
30      globalWIs = roundToMultipleOf(t, devInfo.maxWGSize);
31  
32      cl::Buffer outputBuf = cl::Buffer(ctx, CL_MEM_WRITE_ONLY, (globalWIs * sizeof(cl_half)));
33  
34      globalSize = globalWIs;
35      localSize = devInfo.maxWGSize;
36  
37      cl::Kernel kernel_v1(prog, "compute_hp_v1");
38      kernel_v1.setArg(0, outputBuf), kernel_v1.setArg(1, A);
39  
40      cl::Kernel kernel_v2(prog, "compute_hp_v2");
41      kernel_v2.setArg(0, outputBuf), kernel_v2.setArg(1, A);
42  
43      cl::Kernel kernel_v4(prog, "compute_hp_v4");
44      kernel_v4.setArg(0, outputBuf), kernel_v4.setArg(1, A);
45  
46      cl::Kernel kernel_v8(prog, "compute_hp_v8");
47      kernel_v8.setArg(0, outputBuf), kernel_v8.setArg(1, A);
48  
49      cl::Kernel kernel_v16(prog, "compute_hp_v16");
50      kernel_v16.setArg(0, outputBuf), kernel_v16.setArg(1, A);
51  
52      ///////////////////////////////////////////////////////////////////////////
53      // Vector width 1
54      if (!forceTest || strcmp(specifiedTestName, "half") == 0)
55      {
56        log->print(TAB TAB TAB "half   : ");
57  
58        workPerWI = 4096; // Indicates flops executed per work-item
59  
60        timed = run_kernel(queue, kernel_v1, globalSize, localSize, iters);
61  
62        gflops = (static_cast<float>(globalWIs) * static_cast<float>(workPerWI)) / timed / 1e3f;
63  
64        log->print(gflops);
65        log->print(NEWLINE);
66        log->xmlRecord("half", gflops);
67      }
68      ///////////////////////////////////////////////////////////////////////////
69  
70      // Vector width 2
71      if (!forceTest || strcmp(specifiedTestName, "half2") == 0)
72      {
73        log->print(TAB TAB TAB "half2  : ");
74  
75        workPerWI = 4096;
76  
77        timed = run_kernel(queue, kernel_v2, globalSize, localSize, iters);
78  
79        gflops = (static_cast<float>(globalWIs) * static_cast<float>(workPerWI)) / timed / 1e3f;
80  
81        log->print(gflops);
82        log->print(NEWLINE);
83        log->xmlRecord("half2", gflops);
84      }
85      ///////////////////////////////////////////////////////////////////////////
86  
87      // Vector width 4
88      if (!forceTest || strcmp(specifiedTestName, "half4") == 0)
89      {
90        log->print(TAB TAB TAB "half4  : ");
91  
92        workPerWI = 4096;
93  
94        timed = run_kernel(queue, kernel_v4, globalSize, localSize, iters);
95  
96        gflops = (static_cast<float>(globalWIs) * static_cast<float>(workPerWI)) / timed / 1e3f;
97  
98        log->print(gflops);
99        log->print(NEWLINE);
100        log->xmlRecord("half4", gflops);
101      }
102      ///////////////////////////////////////////////////////////////////////////
103  
104      // Vector width 8
105      if (!forceTest || strcmp(specifiedTestName, "half8") == 0)
106      {
107        log->print(TAB TAB TAB "half8  : ");
108        workPerWI = 4096;
109  
110        timed = run_kernel(queue, kernel_v8, globalSize, localSize, iters);
111  
112        gflops = (static_cast<float>(globalWIs) * static_cast<float>(workPerWI)) / timed / 1e3f;
113  
114        log->print(gflops);
115        log->print(NEWLINE);
116        log->xmlRecord("half8", gflops);
117      }
118      ///////////////////////////////////////////////////////////////////////////
119  
120      // Vector width 16
121      if (!forceTest || strcmp(specifiedTestName, "half16") == 0)
122      {
123        log->print(TAB TAB TAB "half16 : ");
124  
125        workPerWI = 4096;
126  
127        timed = run_kernel(queue, kernel_v16, globalSize, localSize, iters);
128  
129        gflops = (static_cast<float>(globalWIs) * static_cast<float>(workPerWI)) / timed / 1e3f;
130  
131        log->print(gflops);
132        log->print(NEWLINE);
133        log->xmlRecord("half16", gflops);
134      }
135      ///////////////////////////////////////////////////////////////////////////
136      log->xmlCloseTag(); // half_precision_compute
137    }
138    catch (cl::Error &error)
139    {
140      stringstream ss;
141      ss << error.what() << " (" << error.err() << ")" NEWLINE
142         << TAB TAB TAB "Tests skipped" NEWLINE;
143      log->print(ss.str());
144      return -1;
145    }
146  
147    return 0;
148  }
149