"""This cell defines the plot_candles function"""
def plot_candles(pricing, title=None,
                 volume_bars=False,
                 color_function=None,
                 overlays=None,
                 technicals=None,
                 technicals_titles=None):
    """Plots a candlestick chart using quantopian pricing data.

    Author: Daniel Treiman

    Args:
        pricing: A pandas dataframe with columns
            ['open_price', 'close_price', 'high', 'low', 'volume'].
            'volume' is only read when ``volume_bars`` is True. The index
            entries must support ``strftime`` (they are used as tick labels).
        title: An optional title for the chart.
        volume_bars: If True, plots volume bars in a subplot below the candles.
        color_function: A function called as
            ``color_function(i, open_price, close_price, low, high)`` where
            ``i`` is the 0-based row position and the other arguments are the
            full price series; it returns the color for candle ``i``.
            Defaults to red when open > close, green otherwise.
        overlays: A list of additional data series to overlay on top of
            pricing. Must be the same length as pricing.
        technicals: A list of additional data series to display as subplots
            (shown last, below the optional volume subplot).
        technicals_titles: A list of titles to display for each technical
            indicator; technicals without a matching title get none.

    Returns:
        None. The chart is drawn on a newly created matplotlib figure.
    """
    def default_color(index, open_price, close_price, low, high):
        # ``index`` is a row position, so use .iloc rather than label lookup.
        return 'r' if open_price.iloc[index] > close_price.iloc[index] else 'g'

    color_function = color_function or default_color
    overlays = overlays or []
    technicals = technicals or []
    technicals_titles = technicals_titles or []

    open_price = pricing['open_price']
    close_price = pricing['close_price']
    low = pricing['low']
    high = pricing['high']

    # Candle body spans from min(open, close) to max(open, close) per row.
    open_close = pd.concat([open_price, close_price], axis=1)
    oc_min = open_close.min(axis=1)
    oc_max = open_close.max(axis=1)

    # One subplot for the candles, plus one for volume and one per technical.
    subplot_count = 1
    if volume_bars:
        subplot_count = 2
    if technicals:
        subplot_count += len(technicals)

    if subplot_count == 1:
        fig, ax1 = plt.subplots(1, 1)
    else:
        # The candle chart is three times as tall as every other subplot.
        ratios = np.insert(np.full(subplot_count - 1, 1), 0, 3)
        fig, subplots = plt.subplots(subplot_count, 1, sharex=True,
                                     gridspec_kw={'height_ratios': ratios})
        ax1 = subplots[0]

    if title:
        ax1.set_title(title)

    x = np.arange(len(pricing))
    candle_colors = [color_function(i, open_price, close_price, low, high)
                     for i in x]
    # Alignment is set explicitly so the body and the wick (drawn at x) are
    # both centered on the tick, regardless of bar()'s default alignment.
    ax1.bar(x, oc_max - oc_min, bottom=oc_min, color=candle_colors,
            linewidth=0, align='center')
    ax1.vlines(x, low, high, color=candle_colors, linewidth=1)
    ax1.xaxis.grid(False)
    ax1.xaxis.set_tick_params(which='major', length=3.0, direction='in',
                              top=False)

    # Assume minute frequency if first two bars are in the same day.
    # With fewer than two rows there is nothing to compare; fall back to 'day'.
    frequency = 'day'
    if len(pricing) > 1 and (pricing.index[1] - pricing.index[0]).days == 0:
        frequency = 'minute'
    time_format = '%H:%M' if frequency == 'minute' else '%d-%m-%Y'

    # Set X axis tick labels (plt.xticks acts on the current axes).
    plt.xticks(x, [date.strftime(time_format) for date in pricing.index],
               rotation='vertical')

    for overlay in overlays:
        ax1.plot(x, overlay)

    # Plot volume bars if needed
    if volume_bars:
        ax2 = subplots[1]
        volume = pricing['volume']
        volume_scale = None
        scaled_volume = volume
        # Rescale large volumes so the axis labels stay short.
        if volume.max() > 1000000:
            volume_scale = 'M'
            scaled_volume = volume / 1000000
        elif volume.max() > 1000:
            volume_scale = 'K'
            scaled_volume = volume / 1000
        ax2.bar(x, scaled_volume, color=candle_colors, align='center')
        volume_title = 'Volume'
        if volume_scale:
            volume_title = 'Volume (%s)' % volume_scale
        ax2.set_title(volume_title)
        ax2.xaxis.grid(False)

    # Plot additional technical indicators
    for (i, technical) in enumerate(technicals):
        # Negative index: technical indicator plots are shown last.
        ax = subplots[i - len(technicals)]
        ax.plot(x, technical)
        if i < len(technicals_titles):
            ax.set_title(technicals_titles[i])