xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/parse_string_literal.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <torch/csrc/jit/frontend/error_report.h>
3 #include <torch/csrc/jit/frontend/lexer.h>
4 #include <optional>
5 
6 namespace torch::jit {
7 
/// Reports whether the `len` characters of `str` in the half-open window
/// [start, start + len) are all equal to `c`.
///
/// @param c     character every position in the window must hold
/// @param str   string to inspect
/// @param start index of the first character of the window
/// @param len   number of characters in the window
/// @return true when the window lies entirely inside `str` and contains `len`
///         copies of `c`; false otherwise, including when `len` is negative.
///         An empty window (`len == 0`) with `start <= str.size()` is true.
inline bool isCharCount(char c, const std::string& str, size_t start, int len) {
  // A count can never equal a negative length; reject it up front so it is
  // not converted to a huge unsigned value below.
  if (len < 0 || start > str.size()) {
    return false;
  }
  const auto wanted = static_cast<size_t>(len);
  // Compare against the remaining length instead of forming start + len,
  // which could overflow for a large `start`.
  if (wanted > str.size() - start) {
    return false;
  }
  const size_t stop = start + wanted;
  for (size_t i = start; i < stop; ++i) {
    if (str[i] != c) {
      return false;
    }
  }
  return true;
}
16 
/// Decodes a three-digit octal escape of the form `\ooo` (each `o` in 0-7).
///
/// @param str string containing the escape
/// @param pos index of the escape's first character; the three digits are
///            read from `pos + 1` .. `pos + 3`. The character at `pos` itself
///            is not inspected.
/// @return the decoded character, or `std::nullopt` when the four-character
///         escape does not fit inside `str`, when any of the three characters
///         is not an octal digit, or when the value is 256 or larger.
inline std::optional<char> parseOctal(const std::string& str, size_t pos) {
  constexpr size_t kEscapeLen = 4; // leading character + three digits
  // Phrase the bound as a subtraction: `pos + 3` would wrap around for a huge
  // `pos` (e.g. std::string::npos) and let the loop read unrelated characters.
  if (str.size() < kEscapeLen || pos > str.size() - kEscapeLen) {
    return std::nullopt;
  }
  size_t value = 0;
  for (size_t i = 1; i < kEscapeLen; ++i) {
    const char digit = str[pos + i];
    if (digit < '0' || digit > '7') {
      return std::nullopt;
    }
    value = value * 8 + static_cast<size_t>(digit - '0');
  }
  // Three octal digits reach 0777 (511); only values that fit in one byte
  // are accepted.
  if (value >= 256) {
    return std::nullopt;
  }
  return static_cast<char>(value);
}
32 
parseStringLiteral(const SourceRange & range,const std::string & str)33 inline std::string parseStringLiteral(
34     const SourceRange& range,
35     const std::string& str) {
36   size_t quote_len = isCharCount(str[0], str, 0, 3) ? 3 : 1;
37   auto ret_str = str.substr(quote_len, str.size() - quote_len * 2);
38   size_t pos = ret_str.find('\\');
39   while (pos != std::string::npos) {
40     // invariant: pos has to escape a character because it is a valid string
41     char c = ret_str[pos + 1];
42     size_t to_erase = 2;
43     switch (ret_str[pos + 1]) {
44       case '\\':
45       case '\'':
46       case '\"':
47       case '\n':
48         break;
49       case 'a':
50         c = '\a';
51         break;
52       case 'b':
53         c = '\b';
54         break;
55       case 'f':
56         c = '\f';
57         break;
58       case 'n':
59         c = '\n';
60         break;
61       case 'v':
62         c = '\v';
63         break;
64       case 't':
65         c = '\t';
66         break;
67       case 'x':
68         throw(ErrorReport(range) << "unsupported hex specifier");
69       case 'u':
70       case 'U':
71         throw(ErrorReport(range) << "unsupported unicode specifier");
72       default:
73         // octal value in format \nnn, n is [0-7]
74         if (auto v = parseOctal(ret_str, pos)) {
75           to_erase = 4;
76           c = *v;
77         } else {
78           throw(ErrorReport(range) << " ill formed octal specifier");
79         }
80     }
81     ret_str.replace(pos, to_erase, /* num copies */ 1, c);
82     pos = ret_str.find('\\', pos + 1);
83   }
84   return ret_str;
85 }
86 
87 } // namespace torch::jit
88