修改生成数据方式和网络训练方式

This commit is contained in:
2024-04-06 13:44:05 +08:00
parent bae7e4e2c3
commit 6fa1f53f69
8 changed files with 194 additions and 96 deletions

View File

@ -283,7 +283,6 @@ def assemblyline_optimizer_genetic(pcb_data, component_data, machine_number):
# the number of generation: 500
crossover_rate, mutation_rate = 0.8, 0.1
population_size, n_generations = 200, 500
# population_size, n_generations = 30, 50
# the number of placement points, the number of available feeders, and nozzle type of component respectively
component_points, component_feeders, component_nozzle = defaultdict(int), defaultdict(int), defaultdict(str)
@ -300,7 +299,7 @@ def assemblyline_optimizer_genetic(pcb_data, component_data, machine_number):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net(input_size=data_mgr.get_feature(), output_size=1).to(device)
net.load_state_dict(torch.load('model_state.pth'))
net.load_state_dict(torch.load('model/net_model.pth'))
# optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
# optimizer.load_state_dict(torch.load('optimizer_state.pth'))
@ -361,6 +360,7 @@ def assemblyline_optimizer_genetic(pcb_data, component_data, machine_number):
best_individual = population[np.argmax(pop_val)]
val, assignment_result = cal_individual_val(component_points, component_feeders, component_nozzle, machine_number,
best_individual, data_mgr, net)
print('final value: ', val)
# available feeder check
for part_index, data in component_data.iterrows():