import torch
import torch.nn as nn


# Actor Network: maps state -> action (within [-max_action, max_action])
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.max_action = max_action
        # Simple MLP with two hidden layers
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )

    def forward(self, state):
        # Output raw action, then scale to range [-max_action, max_action]
        raw_action = self.net(state)
        # Bound output between -1 and 1 via tanh, then scale by max_action
        action = self.max_action * torch.tanh(raw_action)
        return action


# Critic Network: maps (state, action) -> Q-value
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        # Q-network takes state and action concatenated
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, state, action):
        # Ensure the action has a batch-compatible shape before concatenation
        if action.dim() == 1:
            action = action.unsqueeze(1)
        x = torch.cat([state, action], dim=1)
        Q = self.net(x)
        return Q
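

# Minimal usage sketch (not part of the original source): the dimensions below
# are illustrative assumptions, e.g. roughly matching a Pendulum-style task
# with a 3-dim state and a 1-dim action bounded to [-2, 2].
if __name__ == "__main__":
    state_dim, action_dim, max_action = 3, 1, 2.0  # assumed example values

    actor = Actor(state_dim, action_dim, max_action)
    critic = Critic(state_dim, action_dim)

    # Batch of 4 random states; the actor outputs actions already scaled via tanh.
    state = torch.randn(4, state_dim)
    action = actor(state)            # shape: (4, action_dim)
    q_value = critic(state, action)  # shape: (4, 1)
    print(action.shape, q_value.shape)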