运行环境适配
This commit is contained in:
25
estimator.py
25
estimator.py
@@ -105,15 +105,14 @@ class Estimator:
|
||||
class NeuralEstimator(Estimator):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
self.net = Net(input_size=self.data_mgr.get_feature(), output_size=1).to(device)
|
||||
self.net = Net(input_size=self.data_mgr.get_feature(), output_size=1).to(self.device)
|
||||
self.net_file = 'model/net_model.pth'
|
||||
if os.path.exists(self.net_file):
|
||||
try:
|
||||
self.net.load_state_dict(torch.load(self.net_file))
|
||||
except:
|
||||
warnings.warn('the parameters of neural net model load failed', UserWarning)
|
||||
try:
|
||||
self.net.load_state_dict(torch.load(self.net_file, map_location=self.device))
|
||||
except:
|
||||
warnings.warn('the parameters of neural net model load failed', UserWarning)
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.net.modules():
|
||||
@@ -128,8 +127,8 @@ class NeuralEstimator(Estimator):
|
||||
x_train = np.array(data_mgr.neural_encode(data[0][::data_mgr.get_update_round()]))
|
||||
y_train = np.array(data[1][::data_mgr.get_update_round()])
|
||||
|
||||
x_train = torch.from_numpy(x_train.reshape((-1, np.shape(x_train)[1]))).float().to(device)
|
||||
y_train = torch.from_numpy(y_train.reshape((-1, 1))).float().to(device)
|
||||
x_train = torch.from_numpy(x_train.reshape((-1, np.shape(x_train)[1]))).float().to(self.device)
|
||||
y_train = torch.from_numpy(y_train.reshape((-1, 1))).float().to(self.device)
|
||||
|
||||
optimizer = torch.optim.Adam(self.net.parameters(), lr=params.lr)
|
||||
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.1)
|
||||
@@ -161,7 +160,7 @@ class NeuralEstimator(Estimator):
|
||||
data = data_mgr.loader('opt/' + params.test_file)
|
||||
|
||||
x_test, y_test = np.array(data_mgr.neural_encode(data[0])), np.array(data[1])
|
||||
x_test = torch.from_numpy(x_test.reshape((-1, np.shape(x_test)[1]))).float().to(device)
|
||||
x_test = torch.from_numpy(x_test.reshape((-1, np.shape(x_test)[1]))).float().to(self.device)
|
||||
|
||||
self.net.eval()
|
||||
with torch.no_grad():
|
||||
@@ -171,7 +170,7 @@ class NeuralEstimator(Estimator):
|
||||
def predict(self, cp_points, cp_nozzle, board_width=None, board_height=None):
|
||||
assert board_width is not None and board_height is not None
|
||||
encoding = np.array(self.data_mgr.encode(cp_points, cp_nozzle, board_width, board_height))
|
||||
encoding = torch.from_numpy(encoding.reshape((-1, np.shape(encoding)[0]))).float().to("cuda")
|
||||
encoding = torch.from_numpy(encoding.reshape((-1, np.shape(encoding)[0]))).float().to(self.device)
|
||||
return self.net(encoding)[0, 0].item()
|
||||
|
||||
|
||||
@@ -184,6 +183,8 @@ class HeuristicEstimator(Estimator):
|
||||
if os.path.exists(self.pickle_file):
|
||||
with open(self.pickle_file, 'rb') as f:
|
||||
self.lr = pickle.load(f)
|
||||
else:
|
||||
warnings.warn('the parameters of heuristic lr model load failed', UserWarning)
|
||||
|
||||
def training(self, params):
|
||||
data = data_mgr.loader('opt/' + params.train_file)
|
||||
@@ -304,6 +305,8 @@ class ReconfigEstimator(Estimator):
|
||||
if os.path.exists(self.pickle_file):
|
||||
with open(self.pickle_file, 'rb') as f:
|
||||
self.lr = pickle.load(f)
|
||||
else:
|
||||
warnings.warn('the parameters of reconfig model load failed', UserWarning)
|
||||
|
||||
def training(self, params):
|
||||
data = data_mgr.loader('opt/' + params.train_file)
|
||||
|
||||
Reference in New Issue
Block a user