import torch
import torch.nn as nn


class Actor(nn.Module):
    """Three-layer MLP that maps a state to a bounded action.

    The raw network output is squashed with ``tanh`` and multiplied by
    ``max_action``, so every output component lies in
    ``[-max_action, max_action]``.

    Args:
        state_dim: Size of the last dimension of the state input.
        action_dim: Size of the last dimension of the produced action.
        max_action: Scale applied to the ``tanh``-squashed output.
        hidden_dim: Width of the two hidden layers (defaults to 128, the
            previously hard-coded value).
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_action: float,
        *,
        hidden_dim: int = 128,
    ) -> None:
        super().__init__()
        # Kept as a plain attribute (not a buffer) so the state_dict layout
        # and ``actor.max_action`` access stay exactly as before.
        self.max_action = max_action
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return ``max_action * tanh(net(state))``.

        Args:
            state: Tensor whose last dimension has size ``state_dim``.

        Returns:
            Tensor with the same leading dimensions as ``state`` and a last
            dimension of size ``action_dim``.
        """
        return self.max_action * torch.tanh(self.net(state))


class Critic(nn.Module):
    """Three-layer MLP that scores a (state, action) pair with one scalar.

    The state and action are concatenated along their last dimension and
    passed through the network, which ends in a single output unit.
    NOTE(review): presumably used as a Q-value estimate -- confirm with the
    training code.

    Args:
        state_dim: Size of the last dimension of the state input.
        action_dim: Size of the last dimension of the action input.
        hidden_dim: Width of the two hidden layers (defaults to 128, the
            previously hard-coded value).
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        *,
        hidden_dim: int = 128,
    ) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return the network output for the concatenated state and action.

        Args:
            state: Tensor whose last dimension has size ``state_dim``.
            action: Tensor whose last dimension has size ``action_dim`` and
                whose leading dimensions match those of ``state``.

        Returns:
            Tensor with the inputs' leading dimensions and a last dimension
            of size 1.
        """
        # Concatenate on the feature (last) axis rather than a hard-coded
        # dim=1: identical for 2-D (batch, features) inputs, and it also
        # accepts unbatched 1-D inputs or extra leading batch dimensions.
        return self.net(torch.cat([state, action], dim=-1))