# Eric Jang originally wrote an implementation of MAML in JAX
# (https://github.com/ericjang/maml-jax).
# We translated his implementation from JAX to PyTorch.

import math

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch.nn import functional as F


mpl.use("Agg")


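# A small three-layer MLP (1 -> 40 -> 40 -> 1). The weights are passed in
# explicitly as a list so the same function can be evaluated with either the
# meta-parameters or the per-task adapted ("fast") parameters.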
def net(x, params):
    x = F.linear(x, params[0], params[1])
    x = F.relu(x)

    x = F.linear(x, params[2], params[3])
    x = F.relu(x)

    x = F.linear(x, params[4], params[5])
    return x


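# Meta-parameters: weight/bias pairs for the three linear layers, with weights
# drawn uniformly from +/- 1/sqrt(fan_in) and biases initialized to zero.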
params = [
    torch.Tensor(40, 1).uniform_(-1.0, 1.0).requires_grad_(),
    torch.Tensor(40).zero_().requires_grad_(),
    torch.Tensor(40, 40)
    .uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
    .requires_grad_(),
    torch.Tensor(40).zero_().requires_grad_(),
    torch.Tensor(1, 40)
    .uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
    .requires_grad_(),
    torch.Tensor(1).zero_().requires_grad_(),
]

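# Adam is the outer (meta) optimizer; alpha is the inner-loop SGD step size.
# K is the number of examples drawn per task, num_tasks the meta-batch size.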
opt = torch.optim.Adam(params, lr=1e-3)
alpha = 0.1

K = 20
losses = []
num_tasks = 4


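# Each task is a sinusoid with a random amplitude and phase. For every task we
# draw two independent batches of points: (x1, y1) for the inner update and
# (x2, y2) for the outer (meta) loss.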
def sample_tasks(outer_batch_size, inner_batch_size):
    # Select amplitude and phase for the task
    As = []
    phases = []
    for _ in range(outer_batch_size):
        As.append(np.random.uniform(low=0.1, high=0.5))
        phases.append(np.random.uniform(low=0.0, high=np.pi))

    def get_batch():
        xs, ys = [], []
        for A, phase in zip(As, phases):
            x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1))
            y = A * np.sin(x + phase)
            xs.append(x)
            ys.append(y)
        # Stack into single ndarrays first; torch.tensor on a list of ndarrays is slow.
        return torch.tensor(np.array(xs), dtype=torch.float), torch.tensor(
            np.array(ys), dtype=torch.float
        )

    x1, y1 = get_batch()
    x2, y2 = get_batch()
    return x1, y1, x2, y2


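# Outer MAML loop: for each sampled task, take one inner SGD step on (x1, y1),
# evaluate the adapted parameters on (x2, y2), and backprop the averaged query
# loss through the inner update into the meta-parameters.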
for it in range(20000):
    loss2 = 0.0
    opt.zero_grad()

    def get_loss_for_task(x1, y1, x2, y2):
        f = net(x1, params)
        loss = F.mse_loss(f, y1)

        # create_graph=True because computing grads here is part of the forward pass.
        # We want to differentiate through the SGD update steps and get higher order
        # derivatives in the backward pass.
        grads = torch.autograd.grad(loss, params, create_graph=True)
        new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))]

        v_f = net(x2, new_params)
        return F.mse_loss(v_f, y2)

    task = sample_tasks(num_tasks, K)
    inner_losses = [
        get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i])
        for i in range(num_tasks)
    ]
    loss2 = sum(inner_losses) / len(inner_losses)
    loss2.backward()

    opt.step()

    if it % 100 == 0:
        print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
    losses.append(loss2.detach())

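# Test-time adaptation: sample a brand-new sinusoid and fit it from just four
# points by taking a few gradient steps starting from the meta-learned weights.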
t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
t_b = torch.tensor(0.0).uniform_(0.0, math.pi)

t_x = torch.empty(4, 1).uniform_(-5, 5)
t_y = t_A * torch.sin(t_x + t_b)

opt.zero_grad()

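# Rebinding t_params each step leaves the meta-parameters themselves untouched.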
t_params = params
for k in range(5):
    t_f = net(t_x, t_params)
    t_loss = F.l1_loss(t_f, t_y)

    grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
    t_params = [(t_params[i] - alpha * grads[i]) for i in range(len(t_params))]


test_x = torch.arange(-2 * math.pi, 2 * math.pi, step=0.01).unsqueeze(1)
test_y = t_A * torch.sin(test_x + t_b)

test_f = net(test_x, t_params)

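# Plot the target sinusoid, the adapted network's prediction, and the four
# adaptation examples.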
plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)")
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)")
plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples")
plt.legend()
plt.savefig("maml-sine.png")
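# Smoothed outer-loss curve: convolving with [0.05] * 20 is a 20-step moving average.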
plt.figure()
plt.plot(np.convolve(losses, [0.05] * 20))
plt.savefig("losses.png")