Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set single candle color #451

Merged
merged 13 commits into from
Dec 14, 2021
11 changes: 11 additions & 0 deletions src/mplfinance/_arg_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,17 @@ def _valid_mav(value, is_period=True):
return True
return False

def _colors_validator(value):
if not isinstance(value, list):
AurumnPegasus marked this conversation as resolved.
Show resolved Hide resolved
return False

for v in value:
if v:
if not (isinstance(v, dict) or isinstance(v, str)):
return False
DanielGoldfarb marked this conversation as resolved.
Show resolved Hide resolved

return True


def _hlines_validator(value):
if isinstance(value,dict):
Expand Down
144 changes: 113 additions & 31 deletions src/mplfinance/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from six.moves import zip

def _check_input(opens, closes, highs, lows):
def _check_input(opens, closes, highs, lows, colors=None):
DanielGoldfarb marked this conversation as resolved.
Show resolved Hide resolved
"""Checks that *opens*, *highs*, *lows* and *closes* have the same length.
NOTE: this code assumes if any value open, high, low, close is
missing (*-1*) they all are missing
Expand All @@ -46,6 +46,10 @@ def _check_input(opens, closes, highs, lows):
if not same_length:
raise ValueError('O,H,L,C must have the same length!')

if colors:
if len(opens) != len(colors):
raise ValueError('O,H,L,C and Colors must have the same length!')

o = np.where(np.isnan(opens))[0]
h = np.where(np.isnan(highs))[0]
l = np.where(np.isnan(lows))[0]
Expand Down Expand Up @@ -85,19 +89,19 @@ def _check_and_convert_xlim_configuration(data, config):
return xlim


def _construct_mpf_collections(ptype,dates,xdates,opens,highs,lows,closes,volumes,config,style):
def _construct_mpf_collections(ptype,dates,xdates,opens,highs,lows,closes,volumes,config,style,colors):
collections = None
if ptype == 'candle' or ptype == 'candlestick':
collections = _construct_candlestick_collections(xdates, opens, highs, lows, closes,
marketcolors=style['marketcolors'],config=config )
marketcolors=style['marketcolors'],config=config, colors=colors )

elif ptype =='hollow_and_filled':
collections = _construct_hollow_candlestick_collections(xdates, opens, highs, lows, closes,
marketcolors=style['marketcolors'],config=config )
marketcolors=style['marketcolors'],config=config, colors=colors )

elif ptype == 'ohlc' or ptype == 'bars' or ptype == 'ohlc_bars':
collections = _construct_ohlc_collections(xdates, opens, highs, lows, closes,
marketcolors=style['marketcolors'],config=config )
marketcolors=style['marketcolors'],config=config, colors=colors )
elif ptype == 'renko':
collections = _construct_renko_collections(
dates, highs, lows, volumes, config['renko_params'], closes, marketcolors=style['marketcolors'])
Expand Down Expand Up @@ -176,16 +180,45 @@ def coalesce_volume_dates(in_volumes, in_dates, indexes):
return volumes, dates


def _updown_colors(upcolor,downcolor,opens,closes,use_prev_close=False):
if upcolor == downcolor:
return upcolor
cmap = {True : upcolor, False : downcolor}
if not use_prev_close:
return [ cmap[opn < cls] for opn,cls in zip(opens,closes) ]
def _updown_colors(upcolor,downcolor,opens,closes,use_prev_close=False,colors=None):
if not colors:
if upcolor == downcolor:
return upcolor
cmap = {True : upcolor, False : downcolor}
if not use_prev_close:
return [ cmap[opn < cls] for opn,cls in zip(opens,closes) ]
else:
first = cmap[opens[0] < closes[0]]
_list = [ cmap[pre < cls] for cls,pre in zip(closes[1:], closes) ]
return [first] + _list
else:
first = cmap[opens[0] < closes[0]]
_list = [ cmap[pre < cls] for cls,pre in zip(closes[1:], closes) ]
return [first] + _list
cmap = {True: 'up', False: 'down'}
default = {'up': upcolor, 'down': downcolor}
custom = []
if not use_prev_close:
for i in range(len(opens)):
opn = opens[i]
cls = closes[i]
if colors[i]:
custom.append(colors[i][cmap[opn < cls]])
else:
custom.append(default[cmap[opn < cls]])
else:
if color[0]:
custom.append(colors[0][cmap[opens[0] < closes[0]]])
else:
custom.append(default[cmap[opens[0] < closes[0]]])

for i in range(len(closes) - 1):
pre = closes[1:][i]
cls = closes[i]
if colors[i]:
custom.append(colors[i][cmap[pre < cls]])
else:
custom.append(default[cmap[pre < cls]])

return custom



def _updownhollow_colors(upcolor,downcolor,hollowcolor,opens,closes):
Expand Down Expand Up @@ -447,7 +480,7 @@ def _valid_lines_kwargs():
return vkwargs


def _construct_ohlc_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None):
def _construct_ohlc_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None, colors=None):
"""Represent the time, open, high, low, close as a vertical line
ranging from low to high. The left tick is the open and the right
tick is the close.
Expand All @@ -472,8 +505,8 @@ def _construct_ohlc_collections(dates, opens, highs, lows, closes, marketcolors=
ret : list
a list or tuple of matplotlib collections to be added to the axes
"""

_check_input(opens, highs, lows, closes)
_check_input(opens, highs, lows, closes, colors)

if marketcolors is None:
mktcolors = _get_mpfstyle('classic')['marketcolors']['ohlc']
Expand All @@ -497,13 +530,25 @@ def _construct_ohlc_collections(dates, opens, highs, lows, closes, marketcolors=
# we'll translate these to the date, close location
closeSegments = [((dt, close), (dt+ticksize, close)) for dt, close in zip(dates, closes)]

if mktcolors['up'] == mktcolors['down']:
colors = mktcolors['up']
else:
colorup = mcolors.to_rgba(mktcolors['up'])
colordown = mcolors.to_rgba(mktcolors['down'])
colord = {True: colorup, False: colordown}
colors = [colord[open < close] for open, close in zip(opens, closes)]
bar_c = None
if colors:
bar_c = []
for color in colors:
if color:
bar_up = color['ohlc']['up']
bar_down = color['ohlc']['down']
if bar_up == 'k':
bar_up = mktcolors['up']
if bar_down == 'k':
bar_down = mktcolors['down']
DanielGoldfarb marked this conversation as resolved.
Show resolved Hide resolved

bar_c.append({'up': mcolors.to_rgba(bar_up, 1), 'down': mcolors.to_rgba(bar_down, 1)})
else:
bar_c.append(None)

uc = mcolors.to_rgba(mktcolors['up'])
dc = mcolors.to_rgba(mktcolors['down'])
colors = _updown_colors(uc, dc, opens, closes, colors=bar_c)

lw = config['_width_config']['ohlc_linewidth']

Expand All @@ -525,7 +570,7 @@ def _construct_ohlc_collections(dates, opens, highs, lows, closes, marketcolors=
return [rangeCollection, openCollection, closeCollection]


def _construct_candlestick_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None):
def _construct_candlestick_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None, colors=None):
"""Represent the open, close as a bar line and high low range as a
vertical line.

Expand All @@ -552,8 +597,8 @@ def _construct_candlestick_collections(dates, opens, highs, lows, closes, market
ret : list
(lineCollection, barCollection)
"""
_check_input(opens, highs, lows, closes)

_check_input(opens, highs, lows, closes, colors)

if marketcolors is None:
marketcolors = _get_mpfstyle('classic')['marketcolors']
Expand Down Expand Up @@ -581,17 +626,54 @@ def _construct_candlestick_collections(dates, opens, highs, lows, closes, market

alpha = marketcolors['alpha']

candle_c = None
wick_c = None
edge_c = None
if colors:
candle_c = []
wick_c = []
edge_c = []
for color in colors:
if color:
candle_up = color['candle']['up']
candle_down = color['candle']['down']
edge_up = color['edge']['up']
edge_down = color['edge']['down']
wick_up = color['wick']['up']
wick_down = color['wick']['down']
if candle_up == 'w':
candle_up = marketcolors['candle']['up']
if candle_down == 'k':
candle_down = marketcolors['candle']['down']
if edge_up == 'k':
edge_up = candle_up
if edge_down == 'k':
edge_down = candle_down
if wick_up == 'k':
wick_up = candle_up
if wick_down == 'k':
wick_down = candle_down
DanielGoldfarb marked this conversation as resolved.
Show resolved Hide resolved

candle_c.append({'up': mcolors.to_rgba(candle_up, alpha), 'down': mcolors.to_rgba(candle_down, alpha)})
edge_c.append({'up': mcolors.to_rgba(edge_up, 1), 'down': mcolors.to_rgba(edge_down, 1)})
wick_c.append({'up': mcolors.to_rgba(wick_up, 1), 'down': mcolors.to_rgba(wick_down, 1)})

else:
candle_c.append(None)
wick_c.append(None)
edge_c.append(None)

uc = mcolors.to_rgba(marketcolors['candle'][ 'up' ], alpha)
dc = mcolors.to_rgba(marketcolors['candle']['down'], alpha)
colors = _updown_colors(uc, dc, opens, closes)
colors = _updown_colors(uc, dc, opens, closes, colors=candle_c)

uc = mcolors.to_rgba(marketcolors['edge'][ 'up' ], 1.0)
dc = mcolors.to_rgba(marketcolors['edge']['down'], 1.0)
edgecolor = _updown_colors(uc, dc, opens, closes)
edgecolor = _updown_colors(uc, dc, opens, closes, colors=edge_c)

uc = mcolors.to_rgba(marketcolors['wick'][ 'up' ], 1.0)
dc = mcolors.to_rgba(marketcolors['wick']['down'], 1.0)
wickcolor = _updown_colors(uc, dc, opens, closes)
wickcolor = _updown_colors(uc, dc, opens, closes, colors=wick_c)

lw = config['_width_config']['candle_linewidth']

Expand All @@ -609,7 +691,7 @@ def _construct_candlestick_collections(dates, opens, highs, lows, closes, market
return [rangeCollection, barCollection]


def _construct_hollow_candlestick_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None):
def _construct_hollow_candlestick_collections(dates, opens, highs, lows, closes, marketcolors=None, config=None, colors=None):
"""Represent today's open to close as a "bar" line (candle body)
and high low range as a vertical line (candle wick)

Expand Down
19 changes: 16 additions & 3 deletions src/mplfinance/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from mplfinance._arg_validators import _scale_padding_validator, _yscale_validator
from mplfinance._arg_validators import _valid_panel_id, _check_for_external_axes
from mplfinance._arg_validators import _xlim_validator
from mplfinance._arg_validators import _colors_validator

from mplfinance._panels import _build_panels
from mplfinance._panels import _set_ticks_on_bottom_panel_only
Expand All @@ -49,6 +50,8 @@
from mplfinance._helpers import _num_or_seq_of_num
from mplfinance._helpers import _adjust_color_brightness

from mplfinance._styles import make_marketcolors

VALID_PMOVE_TYPES = ['renko', 'pnf']

DEFAULT_FIGRATIO = (8.00,5.75)
Expand Down Expand Up @@ -125,6 +128,9 @@ def _valid_plot_kwargs():

'marketcolors' : { 'Default' : None, # use 'style' for default, instead.
'Validator' : lambda value: isinstance(value,dict) },

'colors' : { 'Default' : None, # use default style instead.
'Validator' : lambda value: _colors_validator(value) },
DanielGoldfarb marked this conversation as resolved.
Show resolved Hide resolved

'no_xgaps' : { 'Default' : True, # None means follow default logic below:
'Validator' : lambda value: _warn_no_xgaps_deprecated(value) },
Expand Down Expand Up @@ -391,14 +397,21 @@ def plot( data, **kwargs ):
rwc = config['return_width_config']
if isinstance(rwc,dict) and len(rwc)==0:
config['return_width_config'].update(config['_width_config'])


if config['colors']:
colors = config['colors']
for c in range(len(colors)):
if isinstance(colors[c], str):
config['colors'][c] = make_marketcolors(up=colors[c], down=colors[c], edge=colors[c], wick=colors[c], ohlc=colors[c], volume=colors[c])
else:
config['colors'] = None

collections = None
if ptype == 'line':
lw = config['_width_config']['line_width']
axA1.plot(xdates, closes, color=config['linecolor'], linewidth=lw)
else:
collections =_construct_mpf_collections(ptype,dates,xdates,opens,highs,lows,closes,volumes,config,style)
collections =_construct_mpf_collections(ptype,dates,xdates,opens,highs,lows,closes,volumes,config,style,config['colors'])

if ptype in VALID_PMOVE_TYPES:
collections, calculated_values = collections
Expand Down Expand Up @@ -858,7 +871,7 @@ def _addplot_collections(panid,panels,apdict,xdates,config):
if not isinstance(apdata,pd.DataFrame):
raise TypeError('addplot type "'+aptype+'" MUST be accompanied by addplot data of type `pd.DataFrame`')
d,o,h,l,c,v = _check_and_prepare_data(apdata,config)
collections = _construct_mpf_collections(aptype,d,xdates,o,h,l,c,v,config,config['style'])
collections = _construct_mpf_collections(aptype,d,xdates,o,h,l,c,v,config,config['style'],config['colors'])

if not external_axes_mode:
lo = math.log(max(math.fabs(np.nanmin(l)),1e-7),10) - 0.5
Expand Down