import argparse

import torch
import torch.nn as nn

from .factory import pytorch_lstm_creator, varlen_pytorch_lstm_creator
from .runner import get_nn_runners


def barf():
    import pdb

    pdb.set_trace()


def assertEqual(tensor, expected, threshold=0.001):
    if isinstance(tensor, (list, tuple)):
        for t, e in zip(tensor, expected):
            assertEqual(t, e)
    else:
        if (tensor - expected).abs().max() > threshold:
            barf()


def filter_requires_grad(tensors):
    return [t for t in tensors if t.requires_grad]


def test_rnns(
    experim_creator,
    control_creator,
    check_grad=True,
    verbose=False,
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=17,
):
    creator_args = dict(
        seqLength=seqLength,
        numLayers=numLayers,
        inputSize=inputSize,
        hiddenSize=hiddenSize,
        miniBatch=miniBatch,
        device=device,
        seed=seed,
    )

    print("Setting up...")
    control = control_creator(**creator_args)
    experim = experim_creator(**creator_args)

    # Precondition
    assertEqual(experim.inputs, control.inputs)
    assertEqual(experim.params, control.params)

    print("Checking outputs...")
    control_outputs = control.forward(*control.inputs)
    experim_outputs = experim.forward(*experim.inputs)
    assertEqual(experim_outputs, control_outputs)

    print("Checking grads...")
    assert control.backward_setup is not None
    assert experim.backward_setup is not None
    assert control.backward is not None
    assert experim.backward is not None
    control_backward_inputs = control.backward_setup(control_outputs, seed)
    experim_backward_inputs = experim.backward_setup(experim_outputs, seed)

    control.backward(*control_backward_inputs)
    experim.backward(*experim_backward_inputs)

    control_grads = [p.grad for p in control.params]
    experim_grads = [p.grad for p in experim.params]
    assertEqual(experim_grads, control_grads)

    if verbose:
        print(experim.forward.graph_for(*experim.inputs))
    print()


def test_vl_py(**test_args):
    # XXX: This compares vl_py with vl_lstm.
    # It's done this way because those two don't give the same outputs so
    # the result isn't an apples-to-apples comparison right now.
    control_creator = varlen_pytorch_lstm_creator
    name, experim_creator, context = get_nn_runners("vl_py")[0]
    with context():
        print(f"testing {name}...")
        creator_keys = [
            "seqLength",
            "numLayers",
            "inputSize",
            "hiddenSize",
            "miniBatch",
            "device",
            "seed",
        ]
        creator_args = {key: test_args[key] for key in creator_keys}

        print("Setting up...")
        control = control_creator(**creator_args)
        experim = experim_creator(**creator_args)

        # Precondition
        assertEqual(experim.inputs, control.inputs[:2])
        assertEqual(experim.params, control.params)

        print("Checking outputs...")
        control_out, control_hiddens = control.forward(*control.inputs)
        control_hx, control_cx = control_hiddens
        experim_out, experim_hiddens = experim.forward(*experim.inputs)
        experim_hx, experim_cx = experim_hiddens

        experim_padded = nn.utils.rnn.pad_sequence(experim_out).squeeze(-2)
        assertEqual(experim_padded, control_out)
        assertEqual(torch.cat(experim_hx, dim=1), control_hx)
        assertEqual(torch.cat(experim_cx, dim=1), control_cx)

        print("Checking grads...")
        assert control.backward_setup is not None
        assert experim.backward_setup is not None
        assert control.backward is not None
        assert experim.backward is not None
        control_backward_inputs = control.backward_setup(
            (control_out, control_hiddens), test_args["seed"]
        )
        experim_backward_inputs = experim.backward_setup(
            (experim_out, experim_hiddens), test_args["seed"]
        )

        control.backward(*control_backward_inputs)
        experim.backward(*experim_backward_inputs)

        control_grads = [p.grad for p in control.params]
        experim_grads = [p.grad for p in experim.params]
        assertEqual(experim_grads, control_grads)

        if test_args["verbose"]:
            print(experim.forward.graph_for(*experim.inputs))
        print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test lstm correctness")

    parser.add_argument("--seqLength", default="100", type=int)
    parser.add_argument("--numLayers", default="1", type=int)
    parser.add_argument("--inputSize", default="512", type=int)
    parser.add_argument("--hiddenSize", default="512", type=int)
    parser.add_argument("--miniBatch", default="64", type=int)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--check-grad", "--check_grad", default="True", type=bool)
    parser.add_argument("--variable-lstms", "--variable_lstms", action="store_true")
    parser.add_argument("--seed", default="17", type=int)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--rnns", nargs="*", help="What to run. jit_premul, jit, etc")
    args = parser.parse_args()
    if args.rnns is None:
        args.rnns = ["jit_premul", "jit"]
    print(args)

    if "cuda" in args.device:
        assert torch.cuda.is_available()

    rnn_runners = get_nn_runners(*args.rnns)

    should_test_varlen_lstms = args.variable_lstms
    test_args = vars(args)
    del test_args["rnns"]
    del test_args["variable_lstms"]

    if should_test_varlen_lstms:
        test_vl_py(**test_args)

    for name, creator, context in rnn_runners:
        with context():
            print(f"testing {name}...")
            test_rnns(creator, pytorch_lstm_creator, **test_args)