/**
 * Minimum Spanning Tree solver offering Prim's and Kruskal's algorithms.
 *
 * Graphs are given as a list of undirected weighted edges, each a vector
 * [start, end, distance]. Both algorithms index nodes directly with the
 * start/end values, so they assume 0-based indices in [0, num_nodes)
 * (NOTE(review): indices are not validated -- confirm callers respect this).
 * Weighted sums are accumulated in an int; very large weights may overflow.
 */
class MST {
private:
    std::vector<int> _pre;   // union-find parent of each node
    std::vector<int> _size;  // union-find component size (meaningful at roots)

    /*!
     * @brief  : find operation of the union-find (disjoint set) structure
     * @param [x] : node index
     * @retval : root (representative) node of x's component
     */
    int find(int x) {
        if (_pre[x] == x) return x;
        _pre[x] = find(_pre[x]);  // path compression: link x directly to its root
        return _pre[x];
    }

public:
    /*!
     * @brief  : Prim's minimum spanning tree algorithm, grown from node 0
     * @param [num_nodes]   : number of nodes
     * @param [connections] : inter-node connection distances [start, end, distance]
     * @retval : minimum weighted sum, or -1 if the graph is not connected
     */
    int prim(int num_nodes, std::vector<std::vector<int>>& connections) {
        if (num_nodes <= 0) return 0;  // empty graph: empty tree

        // adjacency list: edges[u] holds (neighbour, cost) pairs
        std::vector<std::vector<std::pair<int, int>>> edges(num_nodes);
        for (const std::vector<int>& conn : connections) {
            const int city_a = conn[0];
            const int city_b = conn[1];
            const int cost = conn[2];
            edges[city_a].emplace_back(city_b, cost);
            edges[city_b].emplace_back(city_a, cost);
        }

        // Candidate outward edges stored as (cost, target node). The inverted
        // comparator turns push_heap/pop_heap's max-heap into a min-heap, so
        // the cheapest edge is extracted in O(log E) instead of a linear scan.
        auto heap_cmp = [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
            return a.first > b.first;
        };
        std::vector<std::pair<int, int>> frontier;
        frontier.emplace_back(0, 0);  // start from node 0 at zero cost

        std::vector<char> intree(num_nodes, 0);  // visited flags
        int visited = 0;
        int ans = 0;
        // expand along the cheapest outward edge until all nodes are visited
        while (!frontier.empty() && visited != num_nodes) {
            std::pop_heap(frontier.begin(), frontier.end(), heap_cmp);
            const std::pair<int, int> out_edge = frontier.back();
            frontier.pop_back();

            const int cost = out_edge.first;
            const int node = out_edge.second;
            if (intree[node]) continue;  // stale entry: node already in the tree

            intree[node] = 1;
            ++visited;
            ans += cost;
            for (const std::pair<int, int>& edge : edges[node]) {
                if (!intree[edge.first]) {
                    frontier.emplace_back(edge.second, edge.first);
                    std::push_heap(frontier.begin(), frontier.end(), heap_cmp);
                }
            }
        }
        // some node was never reached: the graph is not connected
        return visited == num_nodes ? ans : -1;
    }

    /*!
     * @brief  : Kruskal's minimum spanning tree algorithm
     * @param [numNodes]    : number of nodes
     * @param [connections] : inter-node connection distances [start, end, distance];
     *                        NOTE: sorted in place by distance
     * @retval : minimum weighted sum, or -1 if the graph is not connected
     */
    int kruskal(int numNodes, std::vector<std::vector<int>>& connections) {
        if (numNodes <= 1) return 0;  // zero or one node: no edges needed, weight 0

        // assign (not resize) so repeated calls on one object start from fresh sets
        _pre.assign(numNodes, 0);
        std::iota(_pre.begin(), _pre.end(), 0);
        _size.assign(numNodes, 1);

        // consider edges in increasing order of distance
        std::sort(connections.begin(), connections.end(),
                  [](const std::vector<int>& elem1, const std::vector<int>& elem2) {
                      return elem1.at(2) < elem2.at(2);
                  });

        int ans = 0;         // minimum weighted sum
        int edge_count = 0;  // number of edges accepted into the tree
        for (const std::vector<int>& conn : connections) {
            int x = find(conn[0]);
            int y = find(conn[1]);
            if (x == y) continue;  // same component: edge would form a cycle

            // union by size: attach the smaller tree under the larger root
            if (_size[x] > _size[y]) std::swap(x, y);
            _pre[x] = y;
            _size[y] += _size[x];

            ans += conn[2];
            if (++edge_count == numNodes - 1) return ans;  // tree complete
        }
        return -1;  // fewer than numNodes - 1 edges joined: graph not connected
    }
};