1 #pragma once
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <algorithm>
#include <cstddef>
#include <optional>
#include <string>
5
6 namespace torch::jit {
7
/// Checks whether `str` contains `len` consecutive copies of `c` starting at
/// index `start`, i.e. every character in [start, start + len) equals `c`.
///
/// @param c     character every position in the range must equal
/// @param str   string to inspect
/// @param start index of the first character of the range
/// @param len   number of characters in the range
/// @return true if the whole range lies inside `str` and consists only of
///         `c`. An empty range (`len == 0`) matches whenever
///         `start <= str.size()`. A negative `len`, or a range that runs past
///         the end of `str`, never matches.
inline bool isCharCount(char c, const std::string& str, size_t start, int len) {
  // Validate the range without computing `start + len`: that sum mixes
  // size_t with a signed int and can wrap, which would let a bogus range
  // through to the iterator arithmetic below.
  if (len < 0 || start > str.size() ||
      static_cast<size_t>(len) > str.size() - start) {
    return false;
  }
  const auto first = str.begin() + static_cast<std::ptrdiff_t>(start);
  return std::all_of(first, first + len, [c](char ch) { return ch == c; });
}
16
/// Decodes a three-digit octal escape (`\ooo`) whose backslash is at
/// `str[pos]`.
///
/// Exactly the three characters `str[pos + 1]`, `str[pos + 2]` and
/// `str[pos + 3]` are read. Each must be an octal digit ('0'..'7').
///
/// @param str string containing the escape
/// @param pos index of the backslash introducing the escape
/// @return the encoded byte as a char, or std::nullopt when fewer than three
///         characters follow `pos`, any of them is not an octal digit, or the
///         encoded value is 256 or more
inline std::optional<char> parseOctal(const std::string& str, size_t pos) {
  constexpr size_t kDigitCount = 3;
  if (str.size() <= pos + kDigitCount) {
    return std::nullopt;
  }

  // Horner's scheme: shift the accumulated value one octal place per digit.
  size_t value = 0;
  for (size_t offset = 1; offset <= kDigitCount; ++offset) {
    const char digit = str[pos + offset];
    const bool is_octal_digit = digit >= '0' && digit <= '7';
    if (!is_octal_digit) {
      return std::nullopt;
    }
    value = value * 8 + static_cast<size_t>(digit - '0');
  }

  // Three octal digits can encode up to 0777 (511); only one byte is allowed.
  if (value > 255) {
    return std::nullopt;
  }
  return static_cast<char>(value);
}
32
parseStringLiteral(const SourceRange & range,const std::string & str)33 inline std::string parseStringLiteral(
34 const SourceRange& range,
35 const std::string& str) {
36 size_t quote_len = isCharCount(str[0], str, 0, 3) ? 3 : 1;
37 auto ret_str = str.substr(quote_len, str.size() - quote_len * 2);
38 size_t pos = ret_str.find('\\');
39 while (pos != std::string::npos) {
40 // invariant: pos has to escape a character because it is a valid string
41 char c = ret_str[pos + 1];
42 size_t to_erase = 2;
43 switch (ret_str[pos + 1]) {
44 case '\\':
45 case '\'':
46 case '\"':
47 case '\n':
48 break;
49 case 'a':
50 c = '\a';
51 break;
52 case 'b':
53 c = '\b';
54 break;
55 case 'f':
56 c = '\f';
57 break;
58 case 'n':
59 c = '\n';
60 break;
61 case 'v':
62 c = '\v';
63 break;
64 case 't':
65 c = '\t';
66 break;
67 case 'x':
68 throw(ErrorReport(range) << "unsupported hex specifier");
69 case 'u':
70 case 'U':
71 throw(ErrorReport(range) << "unsupported unicode specifier");
72 default:
73 // octal value in format \nnn, n is [0-7]
74 if (auto v = parseOctal(ret_str, pos)) {
75 to_erase = 4;
76 c = *v;
77 } else {
78 throw(ErrorReport(range) << " ill formed octal specifier");
79 }
80 }
81 ret_str.replace(pos, to_erase, /* num copies */ 1, c);
82 pos = ret_str.find('\\', pos + 1);
83 }
84 return ret_str;
85 }
86
87 } // namespace torch::jit
88