from gekko import GEKKO
import numpy as np
import matplotlib.pyplot as plt

# Local (non-remote) GEKKO model over a 101-point horizon spanning t = 0 .. 1.
m = GEKKO(remote=False)
t = np.linspace(0, 1, 101)
m.time = t

# Solver / problem-mode settings.
opts = m.options
opts.SOLVER = 3       # IPOPT
opts.IMODE = 6        # simultaneous dynamic optimization
opts.NODES = 2        # collocation nodes per time interval
opts.CV_TYPE = 1      # l1-norm CV objective (this model defines no CVs)
opts.MAX_ITER = 1000

# --- Grid 1 variables -------------------------------------------------------
p0 = 10  # initial production level
p = m.SV(p0)           # production; NOT constant -- it ramps via p.dt() == r1
s = m.Var(0.1, lb=0)   # storage inventory
stored = m.SV()        # rate of energy into storage
recovery = m.SV()      # rate of energy out of storage
vx = m.SV(lb=0)        # surplus slack (p + r - d == vx - vy); the balance
                       # equations below force vx == stored
vy = m.SV(lb=0)        # deficit slack; the balance equations below force
                       # vy == recovery

eps = 0.85  # storage efficiency

# Demand profiles are fixed inputs (not optimized), so they are Params.
# np.pi*t/12*24 == 2*pi*t: one full period over the horizon t in [0, 1].
d = m.Param((-20*np.sin(np.pi*t/12*24) + 100)/10)
d_h = m.Param((15*np.cos(np.pi*t/12*24) + 150)/10)

# --- Grid 2 (heat) variables -------------------------------------------------
# Use a numeric initial value. 1.5*p0 is consistent with the constraint
# p_h == 1.5*p at t=0, where both initial values are held fixed.
p_h = m.SV(1.5 * p0)   # heat production
s_h = m.Var(0.5, lb=0)  # heat storage inventory
stored_h = m.SV()       # rate of heat into storage
recovery_h = m.SV()     # rate of heat out of storage

# Renewable source: 2*cos(4*pi*t) + 2 (np.pi*t/6*24 == 4*pi*t), forced to
# zero on the first and last quarter of the grid points.
quarter = len(t) // 4
renewable = (20*np.cos(np.pi*t/6*24) + 20)/10
renewable[:quarter] = 0
renewable[-quarter:] = 0
r = m.Param(renewable)

# Ramp rate of grid-1 production (p.dt() == r1), adjustable by the
# optimizer within [-3, 3].
r1 = m.MV(lb=-3, ub=3)
r1.STATUS = 1

# Heat storage must end the horizon at the level it started with.
m.periodic(s_h)

# Surplus / deficit slacks for the heat balance (p_h - d_h == zx - zy).
zx = m.SV(lb=0)
zy = m.SV(lb=0)

eps_h = 0.8  # heat storage efficiency

# Net demand after renewables (plotted, not used in the constraints).
net = m.Intermediate(d - r)

# Grid 1 balance. With p + r - d == vx - vy, the 'stored' and 'recovery'
# definitions reduce to stored == vx and recovery == vy (both >= 0). The
# product constraint therefore forbids storing and recovering at once. The
# first inequality is implied by the others because eps < 1.
grid1_balance = [
    p + r + recovery/eps - stored >= d,
    p + r - d == vx - vy,
    stored == p + r - d + vy,
    recovery == d - p - r + vx,
    s.dt() == stored - recovery/eps,
    p.dt() == r1,
    stored * recovery <= 0,
]

# Grid 2 (heat) balance, mirroring grid 1. Heat production is tied to p.
grid2_balance = [
    p_h + recovery_h/eps_h - stored_h >= d_h,
    p_h - d_h == zx - zy,
    stored_h == p_h - d_h + zy,
    recovery_h == d_h - p_h + zx,
    s_h.dt() == stored_h - recovery_h/eps_h,
    stored_h * recovery_h <= 0,
    p_h == 1.5 * p,
]

m.Equations(grid1_balance + grid2_balance)
m.Minimize(p)
m.solve()

# Plot solution: one panel per group of quantities, legends outside right.
fig, axes = plt.subplots(5, 1, figsize=(5, 5.1), sharex=True)
axes = axes.ravel()

# Panel 0: grid-1 demand, production and net demand (demand minus source).
ax = axes[0]
ax.plot(t, d, 'r-', label='Demand 1 ($d_1$)')
ax.plot(t, p, 'b:', label='Production 1 ($g_1$)', lw=2)
ax.plot(t, net, 'k--', label='Net ($d_1-R_1$)')

# Panel 1: renewable source and the optimized ramp rate.
ax = axes[1]
ax.plot(t, r, 'b-', label='Source 1 ($R_1$)')
ax.plot(t, r1, 'k--', label='Ramp Rate ($r$)')

# Panel 2: grid-1 storage inventory and flows.
# Labels containing backslashes must be raw strings. In a plain literal,
# the '\t' in '\text' is a TAB character, which garbles the mathtext.
ax = axes[2]
ax.plot(t, s, 'k-', label='Storage 1 ($e_1$)')
ax.plot(t, stored, 'g--', label=r'Stored ($e_{\mathrm{in},1}$)')
ax.plot(t, recovery, 'b:', label=r'Recovered ($e_{\mathrm{out},1}$)',
        lw=2)

# Panel 3: grid-2 demand and production. Index 0 (t=0) is skipped,
# presumably because it holds only the fixed initial value -- confirm.
ax = axes[3]
ax.plot(t, d_h, 'r-', label='Demand 2 ($d_2$)')
ax.plot(t[1:], p_h.value[1:], 'b:', label='Production 2 ($g_2$)', lw=2)

# Panel 4: grid-2 storage inventory and flows (recovery also skips t=0).
ax = axes[4]
ax.plot(t, s_h, 'k-', label='Storage 2 ($e_2$)')
ax.plot(t, stored_h, 'g--', label=r'Stored ($e_{\mathrm{in},2}$)')
ax.plot(t[1:], recovery_h.value[1:], 'b:',
        label=r'Recovered ($e_{\mathrm{out},2}$)', lw=2)
ax.set_xlabel('Time')

for ax in axes:
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), frameon=False)
    ax.grid()
    ax.set_xlim(0, 1)
plt.tight_layout()
plt.savefig('grid_energy6.png', dpi=600, bbox_inches='tight')
plt.show()