@Barden9

Что делать если агент не использует ray perception sensor 3d?

У меня есть сцена с трассой и машиной, и я хочу чтобы машина научилась сама проходить трассу. Всё работает правильно, только машина странно обучается. Такое чувство, что она не использует ray perception sensor 3d. Также в сцене есть чекпоинты, если машина проходит их в правильном направлении, то она получает награду +0.01 , а если в неправильном то -0.01. Настройки агента: 64c26336184cd009858745.png
64c2634e02d49414070584.png Скрипт агента:
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using System.Runtime.CompilerServices;
using UnityEngine.UI;
using TMPro;
using System;
using Random = UnityEngine.Random;
 
public class Car : Agent
{
    [SerializeField] private TrackCheckpoints trackCheckpoints;
    [SerializeField] private Transform spwanPosition;
 
    private MSVehicleControllerFree carDriver;
    private MSSceneControllerFree sceneController;
    public GameObject scene;
 
    public Text distance_text;
 
    double distanceTravelled = 0;
    Vector3 lastPosition;
    
    //в корутине даём награду за расстоянии которое проехала машина
    IEnumerator distance_reward()
    {
        while (true)
        {
            yield return new WaitForSeconds(1.0f);
            AddReward((float)(0.001 * distanceTravelled));
            distanceTravelled = 0;
        }
    }
 
    public void Awake()
    {
        carDriver = GetComponent<MSVehicleControllerFree>();
        lastPosition = transform.position;
        sceneController = scene.GetComponent<MSSceneControllerFree>();
        StartCoroutine(distance_reward());
    }
 
 
 
    private void FixedUpdate()
    {
        //штрафуем машину за низкую скорость
        if (sceneController.car_speed < 10) {
            AddReward(-0.00001f);
        }
 
        //изменяем дистацию которая проехала машина
        if (sceneController.car_speed >= 0) {
            distanceTravelled += Vector3.Distance(transform.position, lastPosition);
        }
        else if (sceneController.car_speed < 0) {
            distanceTravelled -= Vector3.Distance(transform.position, lastPosition);
        }
        lastPosition = transform.position;
        distanceTravelled = Math.Round(distanceTravelled, 2);
        distance_text.text = "Distance: " + distanceTravelled;
 
        //награждаем машину за скорость
        AddReward((float)(0.000002 * sceneController.car_speed));
    }
 
    public override void OnEpisodeBegin()
    {
        //спавн машины
        transform.position = spwanPosition.position + new Vector3 (Random.Range(-1f,+1f), 0, Random.Range(-1f, +1f));
        transform.forward = spwanPosition.forward;
        trackCheckpoints.ResetCheckpoints(transform);
        carDriver.GetComponent<Rigidbody>().velocity = Vector3.zero;
        
        //сбрасываем дистанцию которая прошла машина
        lastPosition = transform.position;
        distanceTravelled = 0;
        StartCoroutine(distance_reward());
    }
    public override void CollectObservations(VectorSensor sensor)
    {
        Vector3 checkpointForward = trackCheckpoints.GetNextCheckpoint(transform).transform.forward;
        float directionDot = Vector3.Dot(transform.forward, checkpointForward);
        
        //передаем на вход ИИ направление следующего чекпоинта, скорости машины, дистанции которую прошла в течении 1 секунды
        //sensor.AddObservation(directionDot);
        sensor.AddObservation(sceneController.car_speed);
        sensor.AddObservation((float)distanceTravelled);
    }
    public override void OnActionReceived(ActionBuffers actions)
    {
        float forwardAmount = 0f;
        float turnAmount = 0f;
 
        switch (actions.DiscreteActions[0])
        {
            case 0: forwardAmount = 0f; break;
            case 1: forwardAmount = +1f; break;
            case 2: forwardAmount = -1f; break;
        }
        switch (actions.DiscreteActions[1])
        {
            case 0: turnAmount = 0f; break;
            case 1: turnAmount = -1f; break;
            case 2: turnAmount = +1f; break;
        }
        sceneController.Update_Controls(forwardAmount, turnAmount);
        carDriver.GetComponent<MSVehicleControllerFree>().Update_Controls(forwardAmount, turnAmount);
        
        //штрафуем машину каждый шаг
        AddReward(-0.000007f);
    }
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        int forwardAction = 0;
        if(Input.GetKey(KeyCode.Keypad8)) forwardAction = 1;
        if (Input.GetKey(KeyCode.Keypad5)) forwardAction = 2;
 
        int turnAction = 0;
        if (Input.GetKey(KeyCode.Keypad4)) turnAction = 1;
        if (Input.GetKey(KeyCode.Keypad6)) turnAction = 2;
 
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;
        discreteActions[0] = forwardAction;
        discreteActions[1] = turnAction;
    }
    
    //даём отрицательную награду если машина врезалась в стену
    private void OnTriggerEnter(Collider other)
    {
        if (other.gameObject.TryGetComponent<Wall>(out Wall wall))
        {
            AddReward(-0.01f);
            EndEpisode();
        }
    }
}
  • Вопрос задан
  • 44 просмотра
Пригласить эксперта
Ваш ответ на вопрос

Войдите, чтобы написать ответ

Похожие вопросы