xref: /aosp_15_r20/external/pytorch/c10/util/NetworkFlow.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <c10/macros/Macros.h>

#include <cstdint>
#include <string>
#include <vector>

/**
 * This file provides a network flow (minimum cut / maximum flow)
 * implementation.
 * https://en.wikipedia.org/wiki/Flow_network
 *
 * It aims to mirror some of the behavior of networkx, which is (or was) used
 * by the functorch partitioners to split a graph into a forward graph and a
 * backward graph.
 */
16  
17  namespace c10 {
18  
19  enum class C10_API_ENUM MinCutStatus {
20    SUCCESS = 0,
21    UNBOUNDED = 1,
22    OVERFLOW_INF = 2,
23    INVALID = 3,
24  };
25  
26  struct MinCutResult {
27    MinCutStatus status;
28    int64_t max_flow;
29    std::vector<std::string> reachable;
30    std::vector<std::string> unreachable;
31  };
32  
33  // Modeled after networkx implementation
34  class C10_API NetworkFlowGraph {
35   public:
36    // selected such that INF + INF is < INT64_MAX
37    constexpr static int64_t INF = (1LL << 62) - 1;
38  
39    struct Edge {
40      std::string source, dest;
41      int64_t capacity;
42    };
43  
44    MinCutStatus add_edge(
45        const std::string& source,
46        const std::string& dest,
47        int64_t capacity = 1);
48  
49    MinCutResult minimum_cut(const std::string& s, const std::string& t) const;
50  
51    std::vector<Edge> edges;
52  };
53  
54  } // namespace c10
55