"""Experience replay buffer for off-policy reinforcement learning."""

import random
from collections import deque
from typing import Any, Tuple

import numpy as np
import torch


class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Transitions are kept in a ``deque(maxlen=capacity)``: once ``capacity``
    transitions are stored, each new ``add`` silently evicts the oldest one.
    """

    def __init__(self, capacity: int = 10000) -> None:
        """Create an empty buffer holding at most ``capacity`` transitions.

        Args:
            capacity: Maximum number of transitions retained.
        """
        self.buffer = deque(maxlen=capacity)

    def add(
        self,
        state: Any,
        action: Any,
        reward: float,
        next_state: Any,
        done: bool,
    ) -> None:
        """Append one ``(state, action, reward, next_state, done)`` transition.

        Items are stored as given; conversion to tensors happens in ``sample``.
        """
        self.buffer.append((state, action, reward, next_state, done))

    @staticmethod
    def _as_float32(items: Tuple[Any, ...]) -> torch.Tensor:
        """Stack a tuple of per-transition items into one float32 tensor."""
        # Stacking through NumPy first avoids torch's slow element-wise path
        # (and its UserWarning) when the items are numpy arrays.
        # np.asarray on a tuple always allocates a fresh array, so sharing
        # memory via from_numpy is safe.
        return torch.from_numpy(np.asarray(items, dtype=np.float32))

    def sample(self, batch_size: int) -> Tuple[torch.Tensor, ...]:
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Args:
            batch_size: Number of transitions to draw (without replacement).

        Returns:
            A tuple ``(states, actions, rewards, next_states, dones)`` of
            float32 tensors, each stacked along a new leading batch dimension.
            ``rewards`` and ``dones`` get an extra trailing dimension of size 1
            (shape ``(batch_size, 1)``, assuming scalar rewards/done flags);
            ``dones`` is 1.0 where ``done`` was true, else 0.0.

        Raises:
            ValueError: If ``batch_size`` is negative or exceeds ``len(self)``
                (raised by ``random.sample``).
        """
        # NOTE: random.sample indexes the deque at random positions, which is
        # not O(1) for a deque; acceptable for typical capacities.
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            self._as_float32(states),
            self._as_float32(actions),
            self._as_float32(rewards).unsqueeze(1),
            self._as_float32(next_states),
            self._as_float32(dones).unsqueeze(1),
        )

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)