怎么用Java训练出一只不死鸟
本篇内容介绍了“怎么用Java训练出一只不死鸟”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!
网站建设哪家好,找创新互联!专注于网页设计、网站建设、微信开发、微信小程序开发、集团企业网站建设等服务项目。为回馈新老客户创新互联还提供了拜泉免费建站欢迎大家使用!
增强学习(RL)的架构
在这一节会介绍主要用到的算法以及神经网络,帮助你更好的了解如何进行训练。本项目与 DeepLearningFlappyBird 使用了类似的方法进行训练。算法整体的架构是 Q-Learning + 卷积神经网络(CNN),把游戏每一帧的状态存储起来,即小鸟采用的动作和采用动作之后的效果,这些将作为卷积神经网络的训练数据。
CNN 训练简述
CNN 的输入数据为连续的 4 帧图像,我们将这图像 stack 起来作为小鸟当前的“observation”,图像会转换成灰度图以减少所需的训练资源。图像存储的矩阵形式是 (batch size, 4 (frames), 80 (width), 80 (height))
数组里的元素就是当前帧的像素值,这些数据将输入到 CNN 后将输出 (batch size, 2)
的矩阵,矩阵的第二个维度就是小鸟 (振翅不采取动作) 对应的收益。
训练数据
在小鸟采取动作后,我们会得到 preObservation and currentObservation
即是两组 4 帧的连续的图像表示小鸟动作前和动作后的状态。然后我们将 preObservation, currentObservation, action, reward, terminal
组成的五元组作为一个 step 存进 replayBuffer 中。它是一个有限大小的训练数据集,他会随着最新的操作动态更新内容。
public void step(NDList action, boolean training) { if (action.singletonOrThrow().getInt(1) == 1) { bird.birdFlap(); } stepFrame(); NDList preObservation = currentObservation; currentObservation = createObservation(currentImg); FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(), preObservation, currentObservation, action, currentReward, currentTerminal); if (training) { replayBuffer.addStep(step); } if (gameState == GAME_OVER) { restartGame(); } }
训练的三个周期
训练分为 3 个不同的周期以更好地生成训练数据:
Observe(观察) 周期:随机产生训练数据
Explore (探索) 周期:随机与推理动作结合更新训练数据
Training (训练) 周期:推理动作主导产生新数据
通过这种训练模式,我们可以更好的达到预期效果。
处于 Explore 周期时,我们会根据权重选取随机的动作或使用模型推理出的动作来作为小鸟的动作。训练前期,随机动作的权重会非常大,因为模型的决策十分不准确 (甚至不如随机)。在训练后期时,随着模型学习的动作逐步增加,我们会不断增加模型推理动作的权重并最终使它成为主导动作。调节随机动作的参数叫做 epsilon 它会随着训练的过程不断变化。
public NDList chooseAction(RlEnv env, boolean training) { if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) { return env.getActionSpace().randomAction(); } else return baseAgent.chooseAction(env, training); }
训练逻辑
首先,我们会从 replayBuffer 中随机抽取一批数据作为作为训练集。然后将 preObservation 输入到神经网络得到所有行为的 reward(Q)作为预测值:
NDList QReward = trainer.forward(preInput); NDList Q = new NDList(QReward.singletonOrThrow() .mul(actionInput.singletonOrThrow()) .sum(new int[]{1}));
postObservation 同样会输入到神经网络,根据马尔科夫决策过程以及贝尔曼价值函数计算出所有行为的 reward(targetQ)作为真实值:
// 将 postInput 输入到神经网络中得到 targetQReward 是 (batchsize,2) 的矩阵。根据 Q-learning 的算法,每一次的 targetQ 需要根据当前环境是否结束算出不同的值,因此需要将每一个 step 的 targetQ 单独算出后再将 targetQ 堆积成 NDList。 NDList targetQReward = trainer.forward(postInput); NDArray[] targetQValue = new NDArray[batchSteps.length]; for (int i = 0; i < batchSteps.length; i++) { if (batchSteps[i].isTerminal()) { targetQValue[i] = batchSteps[i].getReward(); } else { targetQValue[i] = targetQReward.singletonOrThrow().get(i) .max() .mul(rewardDiscount) .add(rewardInput.singletonOrThrow().get(i)); } } NDList targetQBatch = new NDList(); Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value))); NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));
在训练结束时,计算 Q 和 targetQ 的损失值,并在 CNN 中更新权重。
卷积神经网络模型(CNN)
我们采用了采用了 3 个卷积层,4 个 relu 激活函数以及 2 个全连接层的神经网络架构。
layer | input shape | output shape |
---|---|---|
conv2d | (batchSize, 4, 80, 80) | (batchSize,4,20,20) |
conv2d | (batchSize, 4, 20 ,20) | (batchSize, 32, 9, 9) |
conv2d | (batchSize, 32, 9, 9) | (batchSize, 64, 7, 7) |
linear | (batchSize, 3136) | (batchSize, 512) |
linear | (batchSize, 512) | (batchSize, 2) |
训练过程
DJL 的 RL 库中提供了非常方便的用于实现强化学习的接口:(RlEnv, RlAgent, ReplayBuffer)。
实现 RlAgent 接口即可构建一个可以进行训练的智能体。
在现有的游戏环境中实现 RlEnv 接口即可生成训练所需的数据。
创建 ReplayBuffer 可以存储并动态更新训练数据。
在实现这些接口后,只需要调用 step方法:
RlEnv.step(action, training);
这个方法会将 RlAgent 决策出的动作输入到游戏环境中获得反馈。我们可以在 RlEnv 中提供的 runEnviroment
方法中调用 step 方法,然后只需要重复执行 runEnvironment
方法,即可不断地生成用于训练的数据。
public Step[] runEnvironment(RlAgent agent, boolean training) { // run the game NDList action = agent.chooseAction(this, training); step(action, training); if (training) { batchSteps = this.getBatch(); } return batchSteps; }
我们将 ReplayBuffer 可存储的 step 数量设置为 50000,在 observe 周期我们会先向 replayBuffer 中存储 1000 个使用随机动作生成的 step,这样可以使智能体更快地从随机动作中学习。
在 explore 和 training 周期,神经网络会随机从 replayBuffer 中生成训练集并将它们输入到模型中训练。我们使用 Adam 优化器和 MSE 损失函数迭代神经网络。
神经网络输入预处理
首先将图像大小 resize 成 80x80
并转为灰度图,这有助于在不丢失信息的情况下提高训练速度。
public static NDArray imgPreprocess(BufferedImage observation) { return NDImageUtils.toTensor( NDImageUtils.resize( ImageFactory.getInstance().fromImage(observation) .toNDArray(NDManager.newBaseManager(), Image.Flag.GRAYSCALE) ,80,80)); }
然后我们把连续的四帧图像作为一个输入,为了获得连续四帧的连续图像,我们维护了一个全局的图像队列保存游戏线程中的图像,每一次动作后替换掉最旧的一帧,然后把队列里的图像 stack 成一个单独的 NDArray。
public NDList createObservation(BufferedImage currentImg) { NDArray observation = GameUtil.imgPreprocess(currentImg); if (imgQueue.isEmpty()) { for (int i = 0; i < 4; i++) { imgQueue.offer(observation); } return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1)); } else { imgQueue.remove(); imgQueue.offer(observation); NDArray[] buf = new NDArray[4]; int i = 0; for (NDArray nd : imgQueue) { buf[i++] = nd; } return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1)); } }
一旦以上部分完成,我们就可以开始训练了。训练优化为了获得最佳的训练性能,我们关闭了 GUI 以加快样本生成速度。并使用 Java 多线程将训练循环和样本生成循环分别在不同的线程中运行。
List> callables = new ArrayList<>(numOfThreads); callables.add(new GeneratorCallable(game, agent, training)); if(training) { callables.add(new TrainerCallable(model, agent)); }
“怎么用Java训练出一只不死鸟”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识可以关注创新互联网站,小编将为大家输出更多高质量的实用文章!
标题名称:怎么用Java训练出一只不死鸟
当前路径:http://scyanting.com/article/iiocjd.html