/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "ModelUtils"

#include "ModelUtils.h"

#include <android-base/logging.h>

#include <algorithm>
#include <numeric>
#include <unordered_set>
#include <utility>
#include <vector>

#include "nnapi/TypeUtils.h"
#include "nnapi/Types.h"
#include "nnapi/Validation.h"

namespace android::nn {
namespace {

// Number the `true` entries of `includes` consecutively, starting at 0 and in order of appearance.
// The returned vector has the same size as `includes`. The slot of a `false` entry simply holds
// the count of `true` entries seen before it and is not meaningful as a mapping (shown as X):
//   includes = {false, true, true, false, true}
//   returned = {    X,    0,    1,     X,    2}
std::vector<uint32_t> getMapping(const std::vector<bool>& includes) {
    std::vector<uint32_t> mapping(includes.size());
    uint32_t nextIndex = 0;
    for (size_t position = 0; position < includes.size(); ++position) {
        // Exclusive prefix count: record the running total before accounting for this entry.
        mapping[position] = nextIndex;
        if (includes[position]) {
            ++nextIndex;
        }
    }
    return mapping;
}

// Rewrite every entry of `indexes` in place, replacing each old index `i` with `mapping.at(i)`.
// An entry that is out of range for `mapping` is reported by `std::vector::at`.
// Precondition: indexes != nullptr
void remapIndexes(std::vector<uint32_t>* indexes, const std::vector<uint32_t>& mapping) {
    CHECK(indexes != nullptr);
    for (auto it = indexes->begin(); it != indexes->end(); ++it) {
        const uint32_t oldIndex = *it;
        *it = mapping.at(oldIndex);
    }
}

// Compact `elements` in place so that only the entries whose matching flag in `elementsToKeep` is
// `true` remain. Surviving entries keep their relative order.
// Precondition: elements != nullptr
// Precondition: elements->size() == elementsToKeep.size()
template <typename Type>
void keepSelectedElements(std::vector<Type>* elements, const std::vector<bool>& elementsToKeep) {
    CHECK(elements != nullptr);
    CHECK_EQ(elements->size(), elementsToKeep.size());

    std::vector<Type>& items = *elements;
    size_t writePos = 0;
    for (size_t readPos = 0; readPos < elementsToKeep.size(); ++readPos) {
        if (!elementsToKeep[readPos]) {
            continue;
        }
        // Skip the move when the element is already in its final slot (avoids a self-move).
        if (writePos != readPos) {
            items[writePos] = std::move(items[readPos]);
        }
        ++writePos;
    }
    // Everything at or beyond `writePos` is either discarded or moved-from.
    items.resize(writePos);
}

// Find which operands in model.main.operands are referenced by the main subgraph, i.e. those that
// are read or written by model.main.operations or listed in model.main.inputIndexes or
// model.main.outputIndexes.
//
// Output indexes are marked as well because removeDeadOperands remaps model.main.outputIndexes
// through the mapping derived from this result; an output operand that was dropped would leave its
// output index pointing at an unrelated operand (or out of range).
//
// An index that is out of range for model.main.operands is reported by `std::vector::at`.
// Postcondition: returned.size() == model.main.operands.size()
std::vector<bool> identifyUsedOperands(const Model& model) {
    std::vector<bool> used(model.main.operands.size(), false);
    const auto markUsed = [&used](const std::vector<uint32_t>& indexes) {
        for (const uint32_t index : indexes) {
            used.at(index) = true;
        }
    };
    for (const auto& operation : model.main.operations) {
        markUsed(operation.inputs);
        markUsed(operation.outputs);
    }
    markUsed(model.main.inputIndexes);
    markUsed(model.main.outputIndexes);
    CHECK_EQ(used.size(), model.main.operands.size());
    return used;
}

// Forward declaration. The two helper overloads of identifyUsedSubgraphs below are mutually
// recursive: the operand-scanning overload calls this index-based overload, which in turn calls
// the operand-scanning overload. This overload must therefore be declared before it is defined.
void identifyUsedSubgraphs(uint32_t current, const std::vector<Model::Subgraph>& subgraphs,
                           std::vector<bool>* used);

// Helper that scans `operands` for SUBGRAPH-lifetime operands and marks, in `used`, every subgraph
// reachable through them. For such an operand, `location.offset` is treated as the index of the
// referenced subgraph within `subgraphs`.
// Precondition: used != nullptr
// Precondition: subgraphs.size() == used->size()
void identifyUsedSubgraphs(const std::vector<Operand>& operands,
                           const std::vector<Model::Subgraph>& subgraphs, std::vector<bool>* used) {
    for (const Operand& operand : operands) {
        if (operand.lifetime != Operand::LifeTime::SUBGRAPH) {
            continue;
        }
        const uint32_t referencedIndex = operand.location.offset;
        identifyUsedSubgraphs(referencedIndex, subgraphs, used);
    }
}

// Helper that marks the subgraph at index `current` as used and then recursively marks every
// subgraph referenced by its operands. `used` doubles as the visited set: a subgraph that is
// already marked is not processed again, so each subgraph is expanded at most once.
// Precondition: used != nullptr
// Precondition: subgraphs.size() == used->size()
// Precondition: current < subgraphs.size()
void identifyUsedSubgraphs(uint32_t current, const std::vector<Model::Subgraph>& subgraphs,
                           std::vector<bool>* used) {
    CHECK(used != nullptr);
    CHECK_EQ(subgraphs.size(), used->size());
    CHECK_LT(current, subgraphs.size());

    std::vector<bool>& visited = *used;
    const bool alreadyProcessed = visited[current];
    if (!alreadyProcessed) {
        // Mark before descending so that reference cycles cannot recurse back into `current`.
        visited[current] = true;
        identifyUsedSubgraphs(subgraphs[current].operands, subgraphs, used);
    }
}

// Compute, for each entry of model.referenced, whether it is reachable (directly or transitively)
// from the operands of the main subgraph.
// Postcondition: returned.size() == model.referenced.size()
std::vector<bool> identifyUsedSubgraphs(const Model& model) {
    const auto& mainOperands = model.main.operands;
    const auto& referencedSubgraphs = model.referenced;

    std::vector<bool> used(referencedSubgraphs.size(), false);
    identifyUsedSubgraphs(mainOperands, referencedSubgraphs, &used);
    CHECK_EQ(used.size(), model.referenced.size());
    return used;
}

// Helper that sets the flag in `used` for every pool referenced by a CONSTANT_REFERENCE operand of
// `subgraph`. Flags that are already set are left untouched. A pool index that is out of range for
// `used` is reported by `std::vector::at`.
// Precondition: used != nullptr
void identifyUsedPools(const Model::Subgraph& subgraph, std::vector<bool>* used) {
    CHECK(used != nullptr);
    std::vector<bool>& poolFlags = *used;
    for (const auto& operand : subgraph.operands) {
        if (operand.lifetime != Operand::LifeTime::CONSTANT_REFERENCE) {
            continue;
        }
        poolFlags.at(operand.location.poolIndex) = true;
    }
}

// Find which pools are used by `model`.
// Postcondition: returned.size() == model.pools.size()
std::vector<bool> identifyUsedPools(const Model& model) {
    std::vector<bool> used(model.pools.size(), false);
    identifyUsedPools(model.main, &used);
    for (const auto& subgraph : model.referenced) {
        identifyUsedPools(subgraph, &used);
    }
    CHECK_EQ(used.size(), model.pools.size());
    return used;
}

// Update `operand->location` so that it is valid after pools/subgraphs have been renumbered and the
// constant-copy storage has been rebuilt:
//   - CONSTANT_COPY: the bytes are copied from `oldOperandValues` into `newOperandValues` and the
//     location is replaced by the one returned from `append`.
//   - CONSTANT_REFERENCE: `location.poolIndex` is translated through `remappedPoolIndex`.
//   - SUBGRAPH: `location.offset` (the subgraph index) is translated through
//     `remappedSubgraphIndex`.
//   - All other lifetimes: the location is left unchanged.
// Precondition: operand != nullptr
// Precondition: newOperandValues != nullptr
void fixOperandDataLocation(Operand* operand, Model::OperandValues* newOperandValues,
                            const Model::OperandValues& oldOperandValues,
                            const std::vector<uint32_t>& remappedPoolIndex,
                            const std::vector<uint32_t>& remappedSubgraphIndex) {
    CHECK(operand != nullptr);
    CHECK(newOperandValues != nullptr);

    auto& location = operand->location;
    // The switch intentionally has no `default` so that every lifetime is listed explicitly.
    switch (operand->lifetime) {
        case Operand::LifeTime::TEMPORARY_VARIABLE:
        case Operand::LifeTime::SUBGRAPH_INPUT:
        case Operand::LifeTime::SUBGRAPH_OUTPUT:
        case Operand::LifeTime::NO_VALUE:
        case Operand::LifeTime::POINTER:
            // Nothing to fix for these lifetimes.
            break;
        case Operand::LifeTime::CONSTANT_COPY: {
            const uint32_t byteCount = location.length;
            const uint8_t* source = oldOperandValues.data() + location.offset;
            location = newOperandValues->append(source, byteCount);
            break;
        }
        case Operand::LifeTime::CONSTANT_REFERENCE: {
            const uint32_t oldPoolIndex = location.poolIndex;
            location.poolIndex = remappedPoolIndex.at(oldPoolIndex);
            break;
        }
        case Operand::LifeTime::SUBGRAPH: {
            const uint32_t oldSubgraphIndex = location.offset;
            location.offset = remappedSubgraphIndex.at(oldSubgraphIndex);
            break;
        }
    }
}

// Fix all DataLocations in `operands` by either remapping an index or by copying constant data.
// Precondition: operands != nullptr
// Precondition: newOperandValues != nullptr
void fixOperandDataLocations(std::vector<Operand>* operands, Model::OperandValues* newOperandValues,
                             const Model::OperandValues& oldOperandValues,
                             const std::vector<uint32_t>& remappedPoolIndex,
                             const std::vector<uint32_t>& remappedSubgraphIndex) {
    for (Operand& operand : (*operands)) {
        fixOperandDataLocation(&operand, newOperandValues, oldOperandValues, remappedPoolIndex,
                               remappedSubgraphIndex);
    }
}

// Fix all operands' DataLocations in `model` by either remapping an index or by copying constant
// data. The model's existing operandValues are moved out and replaced by an empty store; the
// operands of the main subgraph and then of every referenced subgraph are processed in that order,
// re-appending their CONSTANT_COPY data into the new store.
// The precondition is enforced up front (matching the other helpers in this file) so that a null
// `model` is reported instead of being dereferenced.
// Precondition: model != nullptr
void fixOperandDataLocations(Model* model, const std::vector<uint32_t>& remappedPoolIndex,
                             const std::vector<uint32_t>& remappedSubgraphIndex) {
    CHECK(model != nullptr);

    // Keep the old constant data alive while the new store is being populated from it.
    const auto operandValues = std::exchange(model->operandValues, Model::OperandValues{});
    fixOperandDataLocations(&model->main.operands, &model->operandValues, operandValues,
                            remappedPoolIndex, remappedSubgraphIndex);
    for (auto& subgraph : model->referenced) {
        fixOperandDataLocations(&subgraph.operands, &model->operandValues, operandValues,
                                remappedPoolIndex, remappedSubgraphIndex);
    }
}

// Compute, for each entry of model.extensionNameToPrefix, whether its prefix is the extension
// prefix of the type of at least one operand or operation in the main subgraph or in any
// referenced subgraph. A prefix of 0 is treated as "no extension" and is never collected.
// Postcondition: returned.size() == model.extensionNameToPrefix.size()
std::vector<bool> identifyUsedExtensions(const Model& model) {
    std::unordered_set<uint16_t> prefixes;

    // Works for both operands and operations: each exposes a `type` member.
    const auto recordPrefixOf = [&prefixes](const auto& operandOrOperation) {
        constexpr uint16_t kStandardPrefix = 0u;
        const auto prefix = getExtensionPrefix(static_cast<uint32_t>(operandOrOperation.type));
        if (prefix == kStandardPrefix) {
            return;
        }
        prefixes.insert(prefix);
    };
    const auto scanSubgraph = [&recordPrefixOf](const Model::Subgraph& subgraph) {
        for (const auto& operand : subgraph.operands) {
            recordPrefixOf(operand);
        }
        for (const auto& operation : subgraph.operations) {
            recordPrefixOf(operation);
        }
    };

    scanSubgraph(model.main);
    for (const auto& referenced : model.referenced) {
        scanSubgraph(referenced);
    }

    std::vector<bool> used;
    used.reserve(model.extensionNameToPrefix.size());
    for (const auto& extension : model.extensionNameToPrefix) {
        const bool isReferenced = prefixes.find(extension.prefix) != prefixes.end();
        used.push_back(isReferenced);
    }
    CHECK_EQ(used.size(), model.extensionNameToPrefix.size());
    return used;
}

}  // anonymous namespace

// Prune `model` in place:
//   1. Drop main-subgraph operands that identifyUsedOperands reports as unused and renumber the
//      operand indexes held by the main operations, inputIndexes and outputIndexes.
//   2. Drop referenced subgraphs that are no longer reachable from the remaining main operands.
//   3. Drop pools that no remaining operand refers to.
//   4. Rewrite operand locations to use the new pool/subgraph numbering and rebuild operandValues.
//   5. Drop extensionNameToPrefix entries whose prefix is no longer used.
// The steps are order dependent: each "identify" step inspects the model as left by the previous
// pruning step.
// Precondition: model != nullptr
void removeDeadOperands(Model* model) {
    CHECK(model != nullptr);
    Model::Subgraph& mainSubgraph = model->main;

    // Step 1: operands of the main subgraph.
    const std::vector<bool> operandsUsed = identifyUsedOperands(*model);
    const std::vector<uint32_t> mappedOperandIndices = getMapping(operandsUsed);
    keepSelectedElements(&mainSubgraph.operands, operandsUsed);
    for (auto& operation : mainSubgraph.operations) {
        remapIndexes(&operation.inputs, mappedOperandIndices);
        remapIndexes(&operation.outputs, mappedOperandIndices);
    }
    remapIndexes(&mainSubgraph.inputIndexes, mappedOperandIndices);
    remapIndexes(&mainSubgraph.outputIndexes, mappedOperandIndices);

    // Step 2: referenced subgraphs (evaluated against the already-pruned main operands).
    const std::vector<bool> subgraphsUsed = identifyUsedSubgraphs(*model);
    const std::vector<uint32_t> mappedSubgraphIndices = getMapping(subgraphsUsed);
    keepSelectedElements(&model->referenced, subgraphsUsed);

    // Step 3: pools (evaluated against the already-pruned subgraphs).
    const std::vector<bool> poolsUsed = identifyUsedPools(*model);
    const std::vector<uint32_t> mappedPoolIndices = getMapping(poolsUsed);
    keepSelectedElements(&model->pools, poolsUsed);

    // Step 4: operand locations.
    fixOperandDataLocations(model, mappedPoolIndices, mappedSubgraphIndices);

    // Step 5: extension name/prefix entries.
    const std::vector<bool> extensionsUsed = identifyUsedExtensions(*model);
    keepSelectedElements(&model->extensionNameToPrefix, extensionsUsed);
}

}  // namespace android::nn