// File: general-solver/src/solver.cpp (C++, 450 lines, 9.5 KiB)

#include "solver.hpp"
#include "common.hpp"
#include <algorithm>
#include <cassert>
#include <climits>
#include <cmath>
#include <cstring>
#include <iostream>
#include <limits>
#include <stack>
using namespace sv;
using std::make_pair;
using std::pair;
using std::cout;
using std::endl;
using std::vector;
// Search node used by Model::optimize(): a solver snapshot for one subproblem
// together with the objective bounds recorded for it.
// Field order matters: the struct is initialised positionally.
struct Node
{
LinSolver solver; // private copy of the solver, extended with branching constraints
double lower_bound; // lower bound on the objective recorded for this node
double upper_bound; // upper bound on the objective recorded for this node
};
// Creates an empty solver: no variable storage, zero variables, a row count
// of 1, no objective sense chosen yet (sense == 0), and status LOADED.
LinSolver::LinSolver()
    : vars(nullptr), cn(0), bn(1), sense(0), rtn_(LOADED), obj_(0)
{
}
// Releases the heap array of variables owned by this solver.
sv::LinSolver::~LinSolver()
{
delete[] vars; // delete[] on nullptr is a no-op, so a solver without variables is fine
}
// Copy constructor: duplicates the model tableau, basis, status, objective
// value and sense, and deep-copies the variable array so both solvers own
// independent storage. The working tableau (ope_table) is not copied.
sv::LinSolver::LinSolver(const LinSolver& solver)
    : vars(nullptr),
      cn(solver.cn),
      bn(solver.bn),
      table(solver.table),
      basic(solver.basic),
      rtn_(solver.rtn_),
      obj_(solver.obj_),
      sense(solver.sense)
{
    // Nothing to duplicate when the source has no variable storage.
    const bool has_vars = solver.vars != nullptr && cn > 0;
    if (!has_vars) {
        return;
    }
    vars = new Var[cn];
    std::copy(solver.vars, solver.vars + cn, vars);
}
// Copy assignment: makes *this an independent copy of `solver`.
//
// The variable array is deep-copied. The replacement array is built before
// the old one is released, so a failed allocation leaves *this untouched.
// A source without variable storage (vars == nullptr) yields vars == nullptr
// here as well, matching the copy constructor. The working tableau
// (ope_table) is not copied.
//
// Returns *this.
LinSolver& sv::LinSolver::operator=(const LinSolver& solver)
{
    if (this == &solver) {
        return *this; // self-assignment: nothing to do
    }

    Var* fresh = nullptr;
    if (solver.vars != nullptr && solver.cn > 0) {
        fresh = new Var[solver.cn];
        std::copy(solver.vars, solver.vars + solver.cn, fresh); // deep copy
    }

    delete[] vars; // release the previous storage only after the copy succeeded
    vars = fresh;

    cn = solver.cn;
    bn = solver.bn;
    table = solver.table;
    basic = solver.basic;
    rtn_ = solver.rtn_;
    obj_ = solver.obj_;
    sense = solver.sense;
    return *this;
}
// Appends `num` variables of the given type to the solver.
//
// The variable array is reallocated: existing variables are copied into the
// new array and the new ones receive consecutive column indices starting at
// the current variable count. Any pointer previously returned by addVars()
// is invalidated.
//
// Returns a pointer to the start of the (whole) variable array.
Var* LinSolver::addVars(int num, VarType type)
{
    assert(num >= 0);

    Var* grown = new Var[cn + num];
    // Element-wise copy instead of memcpy: valid for any copy-assignable Var,
    // and well defined on the first call when there is no old array yet.
    if (vars != nullptr) {
        std::copy(vars, vars + cn, grown);
    }
    for (int c = cn; c < cn + num; ++c) {
        grown[c].col = c;
        grown[c].type = type;
    }

    delete[] vars;
    vars = grown;
    cn += num;
    return vars;
}
// Returns a read-only reference to the variable at position idx.
// idx must satisfy 0 <= idx < cn; this is enforced only through assert.
const Var& sv::LinSolver::getVar(int idx)
{
assert(idx >= 0 && idx < cn); // bounds check
return vars[idx];
}
void LinSolver::addConstr(const Expr& expr, ConstrOper sense, double rhs)
{
if (sense == ConstrOper::LESS_EQUAL) {
bn++;
table.push_back(vector<double>(1, rhs - expr.constant));
table.back().insert(table.back().end(), expr.coeffs.begin(), expr.coeffs.end());
}
else if (sense == ConstrOper::GREATER_EQUAL) {
bn++;
table.push_back(vector<double>(1, expr.constant - rhs));
for (int coeff : expr.coeffs) {
table.back().push_back(-coeff);
}
}
else {
addConstr(expr, ConstrOper::LESS_EQUAL, rhs);
addConstr(expr, ConstrOper::GREATER_EQUAL, rhs);
}
for (int c = table.back().size(); c <= cn; c++) {
table.back().push_back(0);
}
}
// Sets (or replaces) the objective row, which is kept as row 0 of the tableau.
//
// The row is laid out as [-constant, c_1, ..., c_cn] and then multiplied by
// _sense, which must be 1 or -1. On the first call (member sense == 0) the
// row is inserted in front of any existing constraint rows; on later calls
// the existing row 0 is overwritten in place.
void LinSolver::setObjective(Expr obje, int _sense)
{
    assert(_sense == 1 || _sense == -1);
    if (sense == 0) {
        table.insert(table.begin(), obje.coeffs);
        table.front().insert(table.front().begin(), -obje.constant);
        for (int c = obje.coeffs.size() + 1; c <= cn; c++) {
            table.front().push_back(0);
        }
    }
    else {
        table.front().front() = -obje.constant;
        for (int col = 0; col < cn; col++) {
            // Variable `col` lives in tableau column col + 1 (column 0 holds
            // the constant), both when copying a coefficient and when
            // clearing a variable the new objective does not mention.
            const bool has_coeff = static_cast<size_t>(col) < obje.coeffs.size();
            table.front().at(col + 1) = has_coeff ? obje.coeffs.at(col) : 0;
        }
    }
    // Apply the sign to the freshly written row.
    for (auto& entry : table.front()) {
        entry = _sense * entry;
    }
    sense = _sense;
}
// Solves the current model with the simplex method.
//
// Works on a scratch copy of the model tableau (ope_table). On OPTIMAL the
// value of every variable is written to vars[i].val: basic variables take the
// right-hand side of their row, all others are 0. An objective must have been
// set beforehand (sense != 0).
//
// Returns the resulting status, which is also stored in rtn_.
rtn LinSolver::optimize()
{
    assert(sense);

    // feasible_solution() reuses cn as the working tableau width. Remember
    // the real variable count so it is restored on every outcome, not only
    // when the solve ends OPTIMAL.
    const int num_vars = cn;

    ope_table = table;
    rtn_ = LOADED; // feasible_solution() returns rtn_ unchanged when no phase 1 is needed
    rtn_ = feasible_solution();
    if (rtn_ == LOADED) {
        obj_ = _simplex();
    }
    cn = num_vars;

    if (rtn_ == OPTIMAL) {
        // Non-basic variables are zero at a basic solution; clear values left
        // over from an earlier solve (or from the solver this one was copied from).
        for (int i = 0; i < cn; ++i) {
            vars[i].val = 0;
        }
        for (int row = 1; row < bn; row++) {
            // Tableau column k corresponds to vars[k - 1]; larger indices are slacks.
            const auto var_index = basic.at(row - 1) - 1;
            if (var_index < cn) {
                vars[var_index].val = ope_table.at(row).front();
            }
        }
    }
    return rtn_;
}
// Dumps the working tableau to stdout, one row per line with tab-separated
// entries. The number of columns printed per row is the width of row 0.
void LinSolver::print()
{
    for (const auto& line : ope_table) {
        const size_t width = ope_table.front().size();
        for (size_t c = 0; c != width; ++c) {
            cout << line.at(c) << "\t";
        }
        cout << endl;
    }
}
// Default construction: every member is default-initialised.
Model::Model() = default;
// Solves the model, enforcing integrality of INTEGER variables by
// depth-first branch and bound on top of the LP solver.
//
// The internal objective value obj_ is treated as "larger is better": a node
// replaces the incumbent only when its obj_ is strictly larger. The LP
// relaxation at the root therefore gives the global upper bound.
//
// On success the incumbent's status, objective value and variable values are
// copied into this model's solver. If the relaxation is not OPTIMAL its
// status is returned unchanged; if no integer-feasible node is found the
// status becomes INFEASIBLE.
rtn Model::optimize()
{
    solver.optimize();
    if (solver.rtn_ != OPTIMAL) {
        return solver.rtn_;
    }

    constexpr double kTol = 1e-10;
    const double global_upper_bound = solver.obj_;
    // Start from -infinity rather than 0 so that problems whose internal
    // objective value is negative are still explored.
    double global_lower_bound = -std::numeric_limits<double>::infinity();
    bool found_incumbent = false;

    std::stack<Node> list_;
    Node root_node{ solver, global_lower_bound, global_upper_bound };
    Node incumbent_node = root_node;
    list_.push(root_node);

    while (!list_.empty() && global_upper_bound - global_lower_bound > kTol) {
        Node current_node = list_.top();
        list_.pop();

        current_node.solver.optimize();
        if (current_node.solver.get(IntAttr::Status) != OPTIMAL) {
            continue; // subproblem has no optimal solution: discard it
        }

        // The relaxation value bounds everything below this node; prune when
        // it cannot beat the incumbent.
        current_node.upper_bound = current_node.solver.obj_;
        if (current_node.upper_bound <= global_lower_bound) {
            continue;
        }

        // Find the first INTEGER variable whose value is fractional.
        int branch_var_index = -1;
        const int num_vars = current_node.solver.get(IntAttr::NumVars);
        for (int i = 0; i < num_vars; i++) {
            const Var& candidate = current_node.solver.vars[i];
            if (candidate.type == VarType::INTEGER
                && std::fabs(candidate.val - std::round(candidate.val)) > kTol) {
                branch_var_index = i;
                break;
            }
        }

        if (branch_var_index == -1) {
            // Integer feasible and strictly better than the incumbent.
            current_node.lower_bound = current_node.upper_bound;
            global_lower_bound = current_node.lower_bound;
            incumbent_node = current_node;
            found_incumbent = true;
            continue;
        }

        // Branch: x <= floor(v) on one side, x >= floor(v) + 1 on the other.
        const Var& branch_var = current_node.solver.getVar(branch_var_index);
        const double left_var_bound = std::floor(branch_var.val);
        const double right_var_bound = left_var_bound + 1;

        Node left_node = current_node;
        left_node.solver.addConstr(branch_var, ConstrOper::LESS_EQUAL, left_var_bound);
        list_.push(left_node);

        Node right_node = current_node;
        right_node.solver.addConstr(branch_var, ConstrOper::GREATER_EQUAL, right_var_bound);
        list_.push(right_node);
    }

    if (!found_incumbent) {
        // Never report the fractional relaxation as an optimal solution.
        solver.rtn_ = INFEASIBLE;
        return solver.rtn_;
    }

    solver.rtn_ = incumbent_node.solver.rtn_;
    solver.obj_ = incumbent_node.solver.obj_;
    for (int i = 0; i < solver.cn; i++) {
        solver.vars[i].val = incumbent_node.solver.vars[i].val;
    }
    return solver.rtn_;
}
// Forwards to LinSolver::addVars on the underlying solver; `col` is passed
// as the number of variables to add. Returns whatever the solver returns.
Var* sv::Model::addVars(int col, VarType type)
{
return solver.addVars(col, type);
}
// Forwards the constraint `expr <sense> rhs` to the underlying solver.
void sv::Model::addConstr(const Expr& expr, ConstrOper sense, double rhs)
{
return solver.addConstr(expr, sense, rhs);
}
// Forwards the objective expression and its sense to the underlying solver.
void sv::Model::setObjective(Expr obje, int sense)
{
return solver.setObjective(obje, sense);
}
// Returns the requested double-valued attribute from the underlying solver.
double sv::Model::get(DoubleAttr attr)
{
return solver.get(attr);
}
// Returns the requested integer-valued attribute from the underlying solver.
int sv::Model::get(IntAttr attr)
{
return solver.get(attr);
}
// Returns the stored objective value obj_ multiplied by -sense.
// NOTE(review): `attr` is not inspected, so every DoubleAttr yields this same
// value — confirm that the objective value is the only double attribute.
double LinSolver::get(DoubleAttr attr)
{
return -sense * obj_;
}
// Integer attribute lookup: NumVars yields the variable count, Status yields
// the last solve status as an int, and anything else yields -1.
int LinSolver::get(IntAttr attr)
{
    if (attr == IntAttr::NumVars) {
        return cn;
    }
    if (attr == IntAttr::Status) {
        return rtn_;
    }
    return -1;
}
double LinSolver::_simplex()
{
pair<size_t, size_t> t;
while (1) {
rtn_ = _pivot(t);
if (rtn_ == OPTIMAL || rtn_ == UNBOUNDED) {
break;
}
_gaussian(t);
}
return obj_ = ope_table.front().front();
}
// Brings the working tableau into a feasible basic form.
//
// Steps:
//   1. Append one slack column per constraint row (identity block) and make
//      the slacks the initial basis.
//   2. If every right-hand side is non-negative the slack basis is already
//      feasible and rtn_ is returned unchanged.
//   3. Otherwise add an auxiliary variable x0 (column of -1, objective
//      "minimise x0"), pivot it in on the most negative row and run simplex.
//      A non-zero optimum means the constraints are INFEASIBLE. Otherwise x0
//      is removed and the original objective row is restored and re-expressed
//      in terms of the current basis; rtn_ is LOADED.
//
// Side effect: cn is overwritten with the working tableau width (before x0 is
// added) and is not restored here.
rtn LinSolver::feasible_solution()
{
    // Step 1: slack columns and initial basis.
    for (int row = 1; row < bn; row++) {
        ope_table.front().push_back(0);
        for (int col = 1; col < bn; col++) {
            ope_table.at(row).push_back(col == row ? 1 : 0);
        }
    }
    cn = ope_table.front().size();
    basic.clear();
    for (int i = 1; i < bn; i++) {
        basic.push_back(cn - bn + i);
    }

    // Step 2: is the slack basis already feasible?
    bool initial_feasible = true;
    for (int row = 1; row < bn; row++) {
        if (ope_table.at(row).front() < 0) {
            initial_feasible = false;
            break;
        }
    }
    if (initial_feasible) {
        return rtn_;
    }

    // Step 3: build and solve the auxiliary problem.
    const vector<double> coeff = ope_table.front(); // original objective row
    ope_table.front() = vector<double>(cn, .0);
    ope_table.front().push_back(1);

    // Append the x0 column and find the row with the most negative rhs.
    // Row 0 is the objective, so 0 doubles as "not chosen yet".
    size_t worst_row = 0;
    for (int row = 1; row < bn; row++) {
        ope_table.at(row).push_back(-1);
        if (worst_row == 0 || ope_table.at(row).front() < ope_table.at(worst_row).front()) {
            worst_row = row;
        }
    }
    pair<size_t, size_t> t(worst_row, static_cast<size_t>(cn));
    _gaussian(t);

    if (std::fabs(_simplex()) > 1e-10) {
        return rtn_ = INFEASIBLE;
    }
    rtn_ = LOADED;

    // If x0 is still basic (at value 0), pivot it out. The pivot element must
    // be a non-zero entry of x0's own row, so that row is inspected here,
    // not the objective row.
    auto iter = std::find(basic.begin(), basic.end(), cn);
    if (iter != basic.end()) {
        const size_t x0_row = iter - basic.begin() + 1;
        for (int col = 1; col < cn; col++) {
            if (std::fabs(ope_table.at(x0_row).at(col)) > 1e-10) {
                t = make_pair(x0_row, col);
                _gaussian(t);
                break;
            }
        }
    }

    // Drop the x0 column.
    for (int row = 0; row < bn; row++) {
        ope_table.at(row).pop_back();
    }

    // Restore the original objective row ...
    for (int col = 0; col < cn; col++) {
        ope_table.front().at(col) = coeff.at(col);
    }
    // ... and eliminate the basic columns from it. The multiplier must stay a
    // double: truncating it to int leaves fractional objective coefficients
    // only partly eliminated.
    for (size_t row = 1; row <= basic.size(); row++) {
        const double norm = ope_table.front().at(basic.at(row - 1));
        for (int col = 0; col < cn; col++) {
            ope_table.front().at(col) -= norm * ope_table.at(row).at(col);
        }
    }
    return rtn_;
}
// Chooses the next simplex pivot on the working tableau.
//
// Entering column: the non-basic column with the smallest objective-row
// coefficient; if that coefficient is not negative the tableau is OPTIMAL.
// Leaving row: minimum ratio rhs / entry over rows with a positive entry in
// the entering column; if there is none the problem is UNBOUNDED.
//
// On PIVOT, p holds (row, column) of the pivot element and the basis entry of
// the pivot row has already been switched to the entering column.
rtn LinSolver::_pivot(pair<size_t, size_t>& p)
{
    p = make_pair(0, 0);
    const double kInf = std::numeric_limits<double>::infinity();

    // Entering column: smallest coefficient among the non-basic columns.
    const vector<double>& coef = ope_table.front();
    double cmin = kInf;
    for (size_t col = 1; col < coef.size(); col++) {
        if (cmin > coef.at(col) && std::find(basic.begin(), basic.end(), col) == basic.end()) {
            cmin = coef.at(col);
            p.second = col;
        }
    }
    if (cmin >= 0) {
        return OPTIMAL;
    }

    // Leaving row: minimum-ratio test over strictly positive entries only, so
    // no division by a zero or negative entry takes place.
    double bmin = kInf;
    for (size_t row = 1; row < bn; row++) {
        const double entry = ope_table.at(row).at(p.second);
        if (!(entry > 0)) {
            continue;
        }
        const double ratio = ope_table.at(row).front() / entry;
        if (bmin > ratio) {
            bmin = ratio;
            p.first = row;
        }
    }
    // Row 0 is the objective row, so p.first == 0 means no row qualified.
    if (p.first == 0) {
        return UNBOUNDED;
    }

    // The basic variable of the pivot row leaves the basis.
    basic.at(p.first - 1) = p.second;
    return PIVOT;
}
// Performs one pivot on the working tableau at position p = (row, column):
// scales the pivot row so the pivot element becomes 1, subtracts multiples of
// it from every other row to clear the pivot column, and records the pivot
// column as the basic variable of the pivot row.
void LinSolver::_gaussian(pair<size_t, size_t> p)
{
    const size_t pivot_row = p.first;
    const size_t pivot_col = p.second;

    // Normalise the pivot row.
    const double scale = ope_table.at(pivot_row).at(pivot_col);
    for (double& entry : ope_table.at(pivot_row)) {
        entry /= scale;
    }

    // Eliminate the pivot column from all other rows.
    for (size_t r = 0; r < bn; ++r) {
        if (r == pivot_row) {
            continue;
        }
        const double factor = ope_table.at(r).at(pivot_col);
        if (factor == 0) {
            continue; // column already clear in this row
        }
        const size_t width = ope_table.at(pivot_row).size();
        for (size_t c = 0; c < width; ++c) {
            ope_table.at(r).at(c) = ope_table.at(r).at(c) - factor * ope_table.at(pivot_row).at(c);
        }
    }

    basic.at(pivot_row - 1) = pivot_col; // basis exchange
}