#
#  Predator-prey program using the
#  Lotka-Volterra difference equations:
#   H = number of hares, L = number of lynx
#   H_br = hare birth rate; H_dr = hare death rate per lynx (multiplies H*L)
#   L_br = lynx birth rate per hare (multiplies H*L); L_dr = lynx death rate
#   H[t+1] = H[t] + dt*( H_br*H[t] - H_dr*H[t]*L[t] )
#   L[t+1] = L[t] + dt*( L_br*H[t]*L[t] - L_dr*L[t] )
#
import numpy as np
import matplotlib.pyplot as plt
def compute_populations(hare_init, lynx_init, dt, t, *, rates=None):
    """
    Compute lynx and hare populations at each time step.

    Advances the Lotka-Volterra predator-prey model with the explicit
    (forward Euler) difference equations

        H[i+1] = H[i] + dt * (hare_br * H[i] - hare_dr * H[i] * L[i])
        L[i+1] = L[i] + dt * (lynx_br * H[i] * L[i] - lynx_dr * L[i])

    Parameters
    ----------
    hare_init, lynx_init : float
        Populations stored at sample index 0.
    dt : float
        Time step between consecutive samples.
    t : int
        Number of samples to produce, including the initial one; the model
        therefore advances ``t - 1`` steps.
    rates : tuple of 4 floats, optional (keyword-only)
        ``(hare_br, hare_dr, lynx_br, lynx_dr)``. When omitted, the
        module-level variables with those names are used, which keeps the
        original four-argument call working.

    Returns
    -------
    (lynx_pop, hare_pop) : tuple of numpy.ndarray
        Float arrays of length ``t``. Note that the lynx array comes first.

    Raises
    ------
    ValueError
        If ``t`` is less than 1 (no room for the initial state).
    """
    if t < 1:
        raise ValueError(f"t must be at least 1, got {t}")
    if rates is None:
        # Fall back to the module-level model constants. They are only read
        # here, so no `global` statement is needed.
        rates = (hare_br, hare_dr, lynx_br, lynx_dr)
    h_br, h_dr, l_br, l_dr = rates

    hare_pop = np.zeros(t)  # Hare population at each sample
    lynx_pop = np.zeros(t)  # Lynx population at each sample
    hare_pop[0] = hare_init
    lynx_pop[0] = lynx_init
    for i in range(t - 1):
        hare_now = hare_pop[i]
        lynx_now = lynx_pop[i]
        # Both updates use the step-i values (not the freshly updated hare
        # count), exactly as in the difference equations above.
        hare_pop[i + 1] = hare_now + dt * (
            h_br * hare_now - h_dr * hare_now * lynx_now
        )
        lynx_pop[i + 1] = lynx_now + dt * (
            l_br * hare_now * lynx_now - l_dr * lynx_now
        )
    return lynx_pop, hare_pop
#  Key model variables:
hare_br   = .5               # hare birth rate
hare_dr   = .02              # hare death rate per lynx (multiplies H*L)
lynx_br   = .25 * hare_dr    # lynx birth rate per hare (multiplies H*L)
lynx_dr   = .75              # lynx death rate
hare_init = 500              # Initial number of hare
lynx_init = 50               # Initial number of lynx
#  Time step variables
dt        = .1               # Time step (fraction of month)
months    = 48               # Total number of months to model
# Number of samples, including the initial state at time 0. t samples span
# only t - 1 steps, so one extra sample is needed for the run to reach
# `months`. round() instead of int() so floating-point error in months/dt
# (e.g. 479.9999...) cannot silently drop a step.
t = round(months / dt) + 1
#
# Compute arrays for hare and lynx populations at each time step
#
lynx, hare = compute_populations(hare_init, lynx_init, dt, t)
#
#  Graph lynx and hare populations versus time
#
# Convert sample indices into elapsed time in months (samples are dt months
# apart), so the x axis shows time rather than the raw step number.
x = np.arange(t) * dt
plt.plot(x, hare, 'r-', label='Hare')
plt.plot(x, lynx, 'b-', label='Lynx')
plt.xlabel('Time (months)')
plt.ylabel('Population')
plt.legend(loc='upper left')
plt.title(' Predator_prey model (lynx vs hare)')
plt.grid(True)
plt.savefig("lynx_hare.pdf")
plt.show()

