Source code for cntk.contrib.deeprl.tests.agent_test

# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

import unittest

import cntk.contrib.deeprl.tests.spaces as spaces
from cntk.contrib.deeprl.agent.agent import AgentBaseClass


class FakeAgentBaseClass(AgentBaseClass):
    """Concrete AgentBaseClass stand-in used only by the unit tests below.

    Every method overridden here does nothing and returns None, so that
    instances can be constructed and the base-class initialisation inspected
    without any real agent behaviour getting in the way.
    """

    def start(self, state):
        """Do nothing; placeholder for the episode-start hook."""

    def step(self, reward, next_state):
        """Do nothing; placeholder for the per-step hook."""

    def end(self, reward, next_state):
        """Do nothing; placeholder for the episode-end hook."""

    def save(self, filename):
        """Do nothing; placeholder for model saving."""

    def save_parameter_settings(self, filename):
        """Do nothing; placeholder for saving parameter settings."""

    def set_as_best_model(self):
        """Do nothing; placeholder for marking the current model as best."""

    def _choose_action(self, state):
        """Do nothing; placeholder for action selection."""
class AgentBaseClassTest(unittest.TestCase):
    """Unit tests for AgentBaseClass."""

    def _assert_no_space_processing(self, agent):
        """Check that the agent has neither a space discretizer nor a preprocessor."""
        self.assertIsNone(agent._space_discretizer)
        self.assertIsNone(agent._preprocessor)

    def test_init_unsupported_action_space(self):
        # A Box action space must be rejected by the constructor.
        action_space = spaces.Box(0, 1, (1,))
        observation_space = spaces.Discrete(3)
        with self.assertRaises(ValueError):
            FakeAgentBaseClass(observation_space, action_space)

    def test_init_unsupported_observation_space(self):
        # A Tuple observation space must be rejected by the constructor.
        action_space = spaces.Discrete(2)
        observation_space = spaces.Tuple(
            [spaces.Discrete(3), spaces.Discrete(3)])
        with self.assertRaises(ValueError):
            FakeAgentBaseClass(observation_space, action_space)

    def test_init_discrete_observation_space(self):
        # Discrete(n) observations: n states, inputs of shape (n,), flagged
        # as a discrete observation space.
        action_space = spaces.Discrete(2)
        observation_space = spaces.Discrete(3)
        agent = FakeAgentBaseClass(observation_space, action_space)
        self.assertEqual(agent._num_actions, 2)
        self.assertEqual(agent._num_states, 3)
        self.assertEqual(agent._shape_of_inputs, (3,))
        self.assertTrue(agent._discrete_observation_space)
        self._assert_no_space_processing(agent)

    def test_init_multibinary_observation_space(self):
        # MultiBinary(n) observations: no state count, inputs of shape (n,),
        # not flagged as discrete.
        action_space = spaces.Discrete(2)
        observation_space = spaces.MultiBinary(3)
        agent = FakeAgentBaseClass(observation_space, action_space)
        self.assertEqual(agent._num_actions, 2)
        self.assertIsNone(agent._num_states)
        self.assertEqual(agent._shape_of_inputs, (3,))
        self.assertFalse(agent._discrete_observation_space)
        self._assert_no_space_processing(agent)

    def test_init_box_observation_space(self):
        # Box observations: no state count, inputs take the Box's shape,
        # not flagged as discrete.
        action_space = spaces.Discrete(2)
        observation_space = spaces.Box(0, 1, (1,))
        agent = FakeAgentBaseClass(observation_space, action_space)
        self.assertEqual(agent._num_actions, 2)
        self.assertIsNone(agent._num_states)
        self.assertEqual(agent._shape_of_inputs, (1,))
        self.assertFalse(agent._discrete_observation_space)
        self._assert_no_space_processing(agent)