ML-Agent를 이용해서 FlappyBird를 학습시켜 보았다.

 

AddReward (보상) 설정은 파이브(벽) 또는 땅에 닿으면 -1 이 되고

 

파이프 중심을 기준으로 일자로 투명벽(reward wall)이 있는데, 

 

투명벽을 지나치면  +1이 되는 방식이다.

 

그리고 10점이 되면 End Episode가 되도록 설정했다.

 

환경 설정

 

학습 과정 1

 

학습 과정 2

 

 

학습 결과 및 해당 모델로 실행

 

 

 

using System.Collections;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine;
using static UnityEngine.RuleTile.TilingRuleOutput;
using UnityEngine.SocialPlatforms.Impl;
using TMPro;

public class PlayerController : Agent
{
    public float jumpForce = 2f;
    private Rigidbody2D rb;
    public bool isStarted = false;
    public bool isLose = false;
    public TMP_Text scoreText;
    private int score;
    bool jumpRequested = false;


    public BackgroundLoop[] bg;

    private void Start()
    {
        isStarted = true;
    }
    void Update()
    {
        if (Input.GetKeyDown(KeyCode.Space) || Input.GetMouseButtonDown(0))
        {
            if (!isStarted)
            {
                isStarted = true;
            }
            else
            {
                jumpRequested = true;
            }
        }
    }

    public override void Initialize()
    {
        rb = GetComponent<Rigidbody2D>();
        score = 0;
        UpdateScoreText();
        rb.gravityScale = 0;
    }

    public override void OnEpisodeBegin()
    {
        rb.velocity = Vector3.zero;
        rb.gravityScale = 0;
        score = 0;
        UpdateScoreText();
        transform.localPosition = new Vector3(0, 0, 0);


        foreach (var b in bg)
        {
            b.transform.position = b.startPos;
        }

    }

    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(transform.localPosition.y);
        sensor.AddObservation(rb.velocity.y);
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = 0;

        if (jumpRequested)
        {
            discreteActionsOut[0] = 1;
            jumpRequested = false;
        }
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        int jumpAction = actions.DiscreteActions[0];
        if (jumpAction == 1 && isStarted)
        {
            rb.gravityScale = 1;
            rb.velocity = Vector3.zero;
            rb.AddForce(new Vector2(0, jumpForce),ForceMode2D.Impulse);

            // 점수가 10점이 되면 에피소드 종료
            if (score >= 10)
            {
                EndEpisode();
            }
        }
    }




    private void OnTriggerEnter2D(Collider2D collision)
    {
        if (collision.gameObject.CompareTag("ground") || collision.gameObject.CompareTag("wall"))
        {
            rb.gravityScale = 0;
            rb.velocity = Vector3.zero;
            isLose = true;
            AddReward(-1f);
            score = -1;
            UpdateScoreText();


            foreach (var b in bg)
            {
                b.transform.position = b.startPos;
            }

            EndEpisode();
        }

        if (collision.CompareTag("reward"))
        {
            AddReward(1f);
            score++;
            UpdateScoreText();
        }
    }

    private void UpdateScoreText()
    {
        if (scoreText != null)
            scoreText.text = "Score: " + score.ToString();
    }
}

PlayerController.cs

 

using System.Collections;
using UnityEngine;

public class BackgroundLoop : MonoBehaviour
{
    private float width; // 배경의 가로 길이
    public Transform companionBackground; // 동반 배경 오브젝트의 Transform
    public float speed = 10f; // 이동 속도
    public PlayerController player;
    public Vector2 startPos;


    private void Awake()
    {
        // 가로 길이를 측정하는 처리
        BoxCollider2D boxCollider = GetComponent<BoxCollider2D>();
        width = boxCollider.size.x;
        this.transform.position = startPos;
        StartCoroutine("CoMove");
        //StopCoroutine("CoStop");
    }


    // 위치를 리셋하는 메서드
    private void Reposition()
    {
        // 동반 배경 오브젝트의 오른쪽에 위치하도록 설정

        Vector2 offset = new Vector2(width * 2, 0);
        transform.position = (Vector2)companionBackground.position + offset;
    }

    public void RestartBackground()
    {
        StopAllCoroutines(); // 모든 코루틴을 중지합니다.
        this.transform.position = startPos;
        StartCoroutine(CoMove()); // 배경 이동 코루틴을 다시 시작합니다.
    }

    IEnumerator CoMove()
    {
        while (true)
        {
            transform.Translate(Vector3.left * speed * Time.deltaTime);
            if (transform.position.x <= -width*3)
            {
                Reposition();
            }
            yield return null;
        }
    }




    private void OnCollisionEnter2D(Collision2D collision)
    {
        if (collision.gameObject.CompareTag("Player"))
        {
            //StopCoroutine(CoMove());
            RestartBackground();
        }
    }

}

BackgroundLoop.cs

반응형

+ Recent posts