xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/fft.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 
5 namespace torch {
6 namespace fft {
7 
8 /// Computes the 1 dimensional fast Fourier transform over a given dimension.
9 /// See https://pytorch.org/docs/main/fft.html#torch.fft.fft.
10 ///
11 /// Example:
12 /// ```
13 /// auto t = torch::randn(128, dtype=kComplexDouble);
14 /// torch::fft::fft(t);
15 /// ```
16 inline Tensor fft(
17     const Tensor& self,
18     std::optional<SymInt> n = std::nullopt,
19     int64_t dim = -1,
20     std::optional<c10::string_view> norm = std::nullopt) {
21   return torch::fft_fft_symint(self, n, dim, norm);
22 }
23 
24 /// Computes the 1 dimensional inverse Fourier transform over a given dimension.
25 /// See https://pytorch.org/docs/main/fft.html#torch.fft.ifft.
26 ///
27 /// Example:
28 /// ```
29 /// auto t = torch::randn(128, dtype=kComplexDouble);
30 /// torch::fft::ifft(t);
31 /// ```
32 inline Tensor ifft(
33     const Tensor& self,
34     std::optional<SymInt> n = std::nullopt,
35     int64_t dim = -1,
36     std::optional<c10::string_view> norm = std::nullopt) {
37   return torch::fft_ifft_symint(self, n, dim, norm);
38 }
39 
40 /// Computes the 2-dimensional fast Fourier transform over the given dimensions.
41 /// See https://pytorch.org/docs/main/fft.html#torch.fft.fft2.
42 ///
43 /// Example:
44 /// ```
45 /// auto t = torch::randn({128, 128}, dtype=kComplexDouble);
46 /// torch::fft::fft2(t);
47 /// ```
48 inline Tensor fft2(
49     const Tensor& self,
50     OptionalIntArrayRef s = std::nullopt,
51     IntArrayRef dim = {-2, -1},
52     std::optional<c10::string_view> norm = std::nullopt) {
53   return torch::fft_fft2(self, s, dim, norm);
54 }
55 
56 /// Computes the inverse of torch.fft.fft2
57 /// See https://pytorch.org/docs/main/fft.html#torch.fft.ifft2.
58 ///
59 /// Example:
60 /// ```
61 /// auto t = torch::randn({128, 128}, dtype=kComplexDouble);
62 /// torch::fft::ifft2(t);
63 /// ```
64 inline Tensor ifft2(
65     const Tensor& self,
66     at::OptionalIntArrayRef s = std::nullopt,
67     IntArrayRef dim = {-2, -1},
68     std::optional<c10::string_view> norm = std::nullopt) {
69   return torch::fft_ifft2(self, s, dim, norm);
70 }
71 
72 /// Computes the N dimensional fast Fourier transform over given dimensions.
73 /// See https://pytorch.org/docs/main/fft.html#torch.fft.fftn.
74 ///
75 /// Example:
76 /// ```
77 /// auto t = torch::randn({128, 128}, dtype=kComplexDouble);
78 /// torch::fft::fftn(t);
79 /// ```
80 inline Tensor fftn(
81     const Tensor& self,
82     at::OptionalIntArrayRef s = std::nullopt,
83     at::OptionalIntArrayRef dim = std::nullopt,
84     std::optional<c10::string_view> norm = std::nullopt) {
85   return torch::fft_fftn(self, s, dim, norm);
86 }
87 
88 /// Computes the N dimensional fast Fourier transform over given dimensions.
89 /// See https://pytorch.org/docs/main/fft.html#torch.fft.ifftn.
90 ///
91 /// Example:
92 /// ```
93 /// auto t = torch::randn({128, 128}, dtype=kComplexDouble);
94 /// torch::fft::ifftn(t);
95 /// ```
96 inline Tensor ifftn(
97     const Tensor& self,
98     at::OptionalIntArrayRef s = std::nullopt,
99     at::OptionalIntArrayRef dim = std::nullopt,
100     std::optional<c10::string_view> norm = std::nullopt) {
101   return torch::fft_ifftn(self, s, dim, norm);
102 }
103 
104 /// Computes the 1 dimensional FFT of real input with onesided Hermitian output.
105 /// See https://pytorch.org/docs/main/fft.html#torch.fft.rfft.
106 ///
107 /// Example:
108 /// ```
109 /// auto t = torch::randn(128);
110 /// auto T = torch::fft::rfft(t);
111 /// assert(T.is_complex() && T.numel() == 128 / 2 + 1);
112 /// ```
113 inline Tensor rfft(
114     const Tensor& self,
115     std::optional<SymInt> n = std::nullopt,
116     int64_t dim = -1,
117     std::optional<c10::string_view> norm = std::nullopt) {
118   return torch::fft_rfft_symint(self, n, dim, norm);
119 }
120 
121 /// Computes the inverse of torch.fft.rfft
122 ///
123 /// The input is a onesided Hermitian Fourier domain signal, with real-valued
124 /// output. See https://pytorch.org/docs/main/fft.html#torch.fft.irfft
125 ///
126 /// Example:
127 /// ```
128 /// auto T = torch::randn(128 / 2 + 1, torch::kComplexDouble);
129 /// auto t = torch::fft::irfft(t, /*n=*/128);
130 /// assert(t.is_floating_point() && T.numel() == 128);
131 /// ```
132 inline Tensor irfft(
133     const Tensor& self,
134     std::optional<SymInt> n = std::nullopt,
135     int64_t dim = -1,
136     std::optional<c10::string_view> norm = std::nullopt) {
137   return torch::fft_irfft_symint(self, n, dim, norm);
138 }
139 
140 /// Computes the 2-dimensional FFT of real input. Returns a onesided Hermitian
141 /// output. See https://pytorch.org/docs/main/fft.html#torch.fft.rfft2
142 ///
143 /// Example:
144 /// ```
145 /// auto t = torch::randn({128, 128}, dtype=kDouble);
146 /// torch::fft::rfft2(t);
147 /// ```
148 inline Tensor rfft2(
149     const Tensor& self,
150     at::OptionalIntArrayRef s = std::nullopt,
151     IntArrayRef dim = {-2, -1},
152     std::optional<c10::string_view> norm = std::nullopt) {
153   return torch::fft_rfft2(self, s, dim, norm);
154 }
155 
156 /// Computes the inverse of torch.fft.rfft2.
157 /// See https://pytorch.org/docs/main/fft.html#torch.fft.irfft2.
158 ///
159 /// Example:
160 /// ```
161 /// auto t = torch::randn({128, 128}, dtype=kComplexDouble);
162 /// torch::fft::irfft2(t);
163 /// ```
164 inline Tensor irfft2(
165     const Tensor& self,
166     at::OptionalIntArrayRef s = std::nullopt,
167     IntArrayRef dim = {-2, -1},
168     std::optional<c10::string_view> norm = std::nullopt) {
169   return torch::fft_irfft2(self, s, dim, norm);
170 }
171 
172 /// Computes the N dimensional FFT of real input with onesided Hermitian output.
173 /// See https://pytorch.org/docs/main/fft.html#torch.fft.rfftn
174 ///
175 /// Example:
176 /// ```
177 /// auto t = torch::randn({128, 128}, dtype=kDouble);
178 /// torch::fft::rfftn(t);
179 /// ```
180 inline Tensor rfftn(
181     const Tensor& self,
182     at::OptionalIntArrayRef s = std::nullopt,
183     at::OptionalIntArrayRef dim = std::nullopt,
184     std::optional<c10::string_view> norm = std::nullopt) {
185   return torch::fft_rfftn(self, s, dim, norm);
186 }
187 
188 /// Computes the inverse of torch.fft.rfftn.
189 /// See https://pytorch.org/docs/main/fft.html#torch.fft.irfftn.
190 ///
191 /// Example:
192 /// ```
193 /// auto t = torch::randn({128, 128}, dtype=kComplexDouble);
194 /// torch::fft::irfftn(t);
195 /// ```
196 inline Tensor irfftn(
197     const Tensor& self,
198     at::OptionalIntArrayRef s = std::nullopt,
199     at::OptionalIntArrayRef dim = std::nullopt,
200     std::optional<c10::string_view> norm = std::nullopt) {
201   return torch::fft_irfftn(self, s, dim, norm);
202 }
203 
204 /// Computes the 1 dimensional FFT of a onesided Hermitian signal
205 ///
206 /// The input represents a Hermitian symmetric time domain signal. The returned
207 /// Fourier domain representation of such a signal is a real-valued. See
208 /// https://pytorch.org/docs/main/fft.html#torch.fft.hfft
209 ///
210 /// Example:
211 /// ```
212 /// auto t = torch::randn(128 / 2 + 1, torch::kComplexDouble);
213 /// auto T = torch::fft::hfft(t, /*n=*/128);
214 /// assert(T.is_floating_point() && T.numel() == 128);
215 /// ```
216 inline Tensor hfft(
217     const Tensor& self,
218     std::optional<SymInt> n = std::nullopt,
219     int64_t dim = -1,
220     std::optional<c10::string_view> norm = std::nullopt) {
221   return torch::fft_hfft_symint(self, n, dim, norm);
222 }
223 
224 /// Computes the inverse FFT of a real-valued Fourier domain signal.
225 ///
226 /// The output is a onesided representation of the Hermitian symmetric time
227 /// domain signal. See https://pytorch.org/docs/main/fft.html#torch.fft.ihfft.
228 ///
229 /// Example:
230 /// ```
231 /// auto T = torch::randn(128, torch::kDouble);
232 /// auto t = torch::fft::ihfft(T);
233 /// assert(t.is_complex() && T.numel() == 128 / 2 + 1);
234 /// ```
235 inline Tensor ihfft(
236     const Tensor& self,
237     std::optional<SymInt> n = std::nullopt,
238     int64_t dim = -1,
239     std::optional<c10::string_view> norm = std::nullopt) {
240   return torch::fft_ihfft_symint(self, n, dim, norm);
241 }
242 
243 /// Computes the 2-dimensional FFT of a Hermitian symmetric input signal.
244 ///
245 /// The input is a onesided representation of the Hermitian symmetric time
246 /// domain signal. See https://pytorch.org/docs/main/fft.html#torch.fft.hfft2.
247 ///
248 /// Example:
249 /// ```
250 /// auto t = torch::randn({128, 65}, torch::kComplexDouble);
251 /// auto T = torch::fft::hfft2(t, /*s=*/{128, 128});
252 /// assert(T.is_floating_point() && T.numel() == 128 * 128);
253 /// ```
254 inline Tensor hfft2(
255     const Tensor& self,
256     at::OptionalIntArrayRef s = std::nullopt,
257     IntArrayRef dim = {-2, -1},
258     std::optional<c10::string_view> norm = std::nullopt) {
259   return torch::fft_hfft2(self, s, dim, norm);
260 }
261 
262 /// Computes the 2-dimensional IFFT of a real input signal.
263 ///
264 /// The output is a onesided representation of the Hermitian symmetric time
265 /// domain signal. See
266 /// https://pytorch.org/docs/main/fft.html#torch.fft.ihfft2.
267 ///
268 /// Example:
269 /// ```
270 /// auto T = torch::randn({128, 128}, torch::kDouble);
271 /// auto t = torch::fft::hfft2(T);
272 /// assert(t.is_complex() && t.size(1) == 65);
273 /// ```
274 inline Tensor ihfft2(
275     const Tensor& self,
276     at::OptionalIntArrayRef s = std::nullopt,
277     IntArrayRef dim = {-2, -1},
278     std::optional<c10::string_view> norm = std::nullopt) {
279   return torch::fft_ihfft2(self, s, dim, norm);
280 }
281 
282 /// Computes the N-dimensional FFT of a Hermitian symmetric input signal.
283 ///
284 /// The input is a onesided representation of the Hermitian symmetric time
285 /// domain signal. See https://pytorch.org/docs/main/fft.html#torch.fft.hfftn.
286 ///
287 /// Example:
288 /// ```
289 /// auto t = torch::randn({128, 65}, torch::kComplexDouble);
290 /// auto T = torch::fft::hfftn(t, /*s=*/{128, 128});
291 /// assert(T.is_floating_point() && T.numel() == 128 * 128);
292 /// ```
293 inline Tensor hfftn(
294     const Tensor& self,
295     at::OptionalIntArrayRef s = std::nullopt,
296     IntArrayRef dim = {-2, -1},
297     std::optional<c10::string_view> norm = std::nullopt) {
298   return torch::fft_hfftn(self, s, dim, norm);
299 }
300 
301 /// Computes the N-dimensional IFFT of a real input signal.
302 ///
303 /// The output is a onesided representation of the Hermitian symmetric time
304 /// domain signal. See
305 /// https://pytorch.org/docs/main/fft.html#torch.fft.ihfftn.
306 ///
307 /// Example:
308 /// ```
309 /// auto T = torch::randn({128, 128}, torch::kDouble);
310 /// auto t = torch::fft::hfft2(T);
311 /// assert(t.is_complex() && t.size(1) == 65);
312 /// ```
313 inline Tensor ihfftn(
314     const Tensor& self,
315     at::OptionalIntArrayRef s = std::nullopt,
316     IntArrayRef dim = {-2, -1},
317     std::optional<c10::string_view> norm = std::nullopt) {
318   return torch::fft_ihfftn(self, s, dim, norm);
319 }
320 
321 /// Computes the discrete Fourier Transform sample frequencies for a signal of
322 /// size n.
323 ///
324 /// See https://pytorch.org/docs/main/fft.html#torch.fft.fftfreq
325 ///
326 /// Example:
327 /// ```
328 /// auto frequencies = torch::fft::fftfreq(128, torch::kDouble);
329 /// ```
330 inline Tensor fftfreq(int64_t n, double d, const TensorOptions& options = {}) {
331   return torch::fft_fftfreq(n, d, options);
332 }
333 
334 inline Tensor fftfreq(int64_t n, const TensorOptions& options = {}) {
335   return torch::fft_fftfreq(n, /*d=*/1.0, options);
336 }
337 
338 /// Computes the sample frequencies for torch.fft.rfft with a signal of size n.
339 ///
340 /// Like torch.fft.rfft, only the positive frequencies are included.
341 /// See https://pytorch.org/docs/main/fft.html#torch.fft.rfftfreq
342 ///
343 /// Example:
344 /// ```
345 /// auto frequencies = torch::fft::rfftfreq(128, torch::kDouble);
346 /// ```
rfftfreq(int64_t n,double d,const TensorOptions & options)347 inline Tensor rfftfreq(int64_t n, double d, const TensorOptions& options) {
348   return torch::fft_rfftfreq(n, d, options);
349 }
350 
rfftfreq(int64_t n,const TensorOptions & options)351 inline Tensor rfftfreq(int64_t n, const TensorOptions& options) {
352   return torch::fft_rfftfreq(n, /*d=*/1.0, options);
353 }
354 
355 /// Reorders n-dimensional FFT output to have negative frequency terms first, by
356 /// a torch.roll operation.
357 ///
358 /// See https://pytorch.org/docs/main/fft.html#torch.fft.fftshift
359 ///
360 /// Example:
361 /// ```
362 /// auto x = torch::randn({127, 4});
363 /// auto centred_fft = torch::fft::fftshift(torch::fft::fftn(x));
364 /// ```
365 inline Tensor fftshift(
366     const Tensor& x,
367     at::OptionalIntArrayRef dim = std::nullopt) {
368   return torch::fft_fftshift(x, dim);
369 }
370 
371 /// Inverse of torch.fft.fftshift
372 ///
373 /// See https://pytorch.org/docs/main/fft.html#torch.fft.ifftshift
374 ///
375 /// Example:
376 /// ```
377 /// auto x = torch::randn({127, 4});
378 /// auto shift = torch::fft::fftshift(x)
379 /// auto unshift = torch::fft::ifftshift(shift);
380 /// assert(torch::allclose(x, unshift));
381 /// ```
382 inline Tensor ifftshift(
383     const Tensor& x,
384     at::OptionalIntArrayRef dim = std::nullopt) {
385   return torch::fft_ifftshift(x, dim);
386 }
387 
388 } // namespace fft
389 } // namespace torch
390