import torch
import torch.nn as nn
import torch.optim as optim


class Actor(nn.Module):
    """Feed-forward network mapping a state to a scaled action.

    The network is three linear layers (hidden width 128) with ReLU between
    them and a sigmoid on the output; ``forward`` multiplies that output by
    ``max_action``.
    """

    def __init__(self, state_dim: int, action_dim: int, max_action: float) -> None:
        """Build the network.

        Args:
            state_dim: Number of input features of the first linear layer.
            action_dim: Number of output features of the last linear layer.
            max_action: Factor the sigmoid output is multiplied by.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim),
            nn.Sigmoid(),  # Output between 0 and 1
        )
        self.max_action = max_action

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return ``max_action`` times the sigmoid output of the network."""
        return self.max_action * self.net(state)


class Critic(nn.Module):
    """Feed-forward network mapping a (state, action) pair to one value.

    The state and action are concatenated along dimension 1 and passed
    through three linear layers (hidden width 128) with ReLU between them.
    """

    def __init__(self, state_dim: int, action_dim: int) -> None:
        """Build the network.

        Args:
            state_dim: Number of state features per sample.
            action_dim: Number of action features per sample.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return the network output for the concatenated state and action.

        Args:
            state: Tensor concatenated with ``action`` along dimension 1.
            action: Action tensor. A 1-D tensor is treated as one scalar
                action per sample and reshaped to a column.

        Returns:
            The output of the final linear layer (one feature per sample).
        """
        if action.dim() == 1:
            # A 1-D action of length N becomes an (N, 1) column so it can be
            # concatenated along dim 1. This only matches the first linear
            # layer's input width when action_dim == 1.
            action = action.unsqueeze(1)
        return self.net(torch.cat([state, action], dim=1))