# Eric Jang originally wrote an implementation of MAML in JAX
# (https://github.com/ericjang/maml-jax).
# We translated his implementation from JAX to PyTorch.

import math

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch.nn import functional as F


mpl.use("Agg")


# A small MLP (1 -> 40 -> 40 -> 1) written functionally: the parameters are
# passed in explicitly so the network can be evaluated with either the
# meta-learned parameters or the task-adapted parameters.
def net(x, params):
    x = F.linear(x, params[0], params[1])
    x = F.relu(x)

    x = F.linear(x, params[2], params[3])
    x = F.relu(x)

    x = F.linear(x, params[4], params[5])
    return x


# Weights and biases of the three linear layers, initialized by hand.
params = [
    torch.Tensor(40, 1).uniform_(-1.0, 1.0).requires_grad_(),
    torch.Tensor(40).zero_().requires_grad_(),
    torch.Tensor(40, 40)
    .uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
    .requires_grad_(),
    torch.Tensor(40).zero_().requires_grad_(),
    torch.Tensor(1, 40)
    .uniform_(-1.0 / math.sqrt(40), 1.0 / math.sqrt(40))
    .requires_grad_(),
    torch.Tensor(1).zero_().requires_grad_(),
]

opt = torch.optim.Adam(params, lr=1e-3)  # outer-loop (meta) optimizer
alpha = 0.1  # inner-loop (task adaptation) learning rate

K = 20  # samples per task batch
losses = []
num_tasks = 4  # meta-batch size


def sample_tasks(outer_batch_size, inner_batch_size):
    # Select amplitude and phase for each task: every task is a sine wave
    # y = A * sin(x + phase).
    As = []
    phases = []
    for _ in range(outer_batch_size):
        As.append(np.random.uniform(low=0.1, high=0.5))
        phases.append(np.random.uniform(low=0.0, high=np.pi))

    def get_batch():
        xs, ys = [], []
        for A, phase in zip(As, phases):
            x = np.random.uniform(low=-5.0, high=5.0, size=(inner_batch_size, 1))
            y = A * np.sin(x + phase)
            xs.append(x)
            ys.append(y)
        # Stack into (outer_batch_size, inner_batch_size, 1) arrays before
        # converting; building a tensor straight from a list of arrays is slow.
        return (
            torch.tensor(np.stack(xs), dtype=torch.float),
            torch.tensor(np.stack(ys), dtype=torch.float),
        )

    # (x1, y1) is the support batch used for adaptation,
    # (x2, y2) is the query batch used for the meta-update.
    x1, y1 = get_batch()
    x2, y2 = get_batch()
    return x1, y1, x2, y2


# Outer loop: meta-train the initial parameters with Adam.
for it in range(20000):
    loss2 = 0.0
    opt.zero_grad()

    def get_loss_for_task(x1, y1, x2, y2):
        # Inner loop: one manual SGD step on the support set ...
        f = net(x1, params)
        loss = F.mse_loss(f, y1)

        # create_graph=True because computing grads here is part of the forward pass.
        # We want to differentiate through the SGD update steps and get higher order
        # derivatives in the backward pass.
        grads = torch.autograd.grad(loss, params, create_graph=True)
        new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))]

        # ... then evaluate the adapted parameters on the query set.
        v_f = net(x2, new_params)
        return F.mse_loss(v_f, y2)

    task = sample_tasks(num_tasks, K)
    inner_losses = [
        get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i])
        for i in range(num_tasks)
    ]
    # The outer (meta) loss is the mean query loss across tasks.
    loss2 = sum(inner_losses) / len(inner_losses)
    loss2.backward()

    opt.step()

    if it % 100 == 0:
        print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
    losses.append(loss2.detach())

# Meta-test: adapt the meta-learned parameters to a brand-new sine task
# from only four examples.
t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
t_b = torch.tensor(0.0).uniform_(0.0, math.pi)

t_x = torch.empty(4, 1).uniform_(-5, 5)
t_y = t_A * torch.sin(t_x + t_b)

opt.zero_grad()

# Five inner-loop adaptation steps starting from the meta-learned parameters.
t_params = params
for k in range(5):
    t_f = net(t_x, t_params)
    t_loss = F.l1_loss(t_f, t_y)

    grads = torch.autograd.grad(t_loss, t_params, create_graph=True)
    t_params = [(t_params[i] - alpha * grads[i]) for i in range(len(t_params))]


# Plot the adapted network against the ground-truth sine wave ...
test_x = torch.arange(-2 * math.pi, 2 * math.pi, step=0.01).unsqueeze(1)
test_y = t_A * torch.sin(test_x + t_b)

test_f = net(test_x, t_params)

plt.plot(test_x.data.numpy(), test_y.data.numpy(), label="sin(x)")
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label="net(x)")
plt.plot(t_x.data.numpy(), t_y.data.numpy(), "o", label="Examples")
plt.legend()
plt.savefig("maml-sine.png")

# ... and a smoothed (window-20 moving average) curve of the outer loss.
plt.figure()
plt.plot(np.convolve(losses, [0.05] * 20))
plt.savefig("losses.png")