Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple plots #232

Merged
merged 13 commits into from
Aug 1, 2023
125 changes: 91 additions & 34 deletions package/scripts/prmon_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,64 @@ def get_multiplier(label, unit):
return MULTIPLIERS[ALLOWEDUNITS[label][0].upper()] / MULTIPLIERS[unit]


# Function for checking the input file exists
def check_input_file(file):
if not os.path.exists(file):
print(f"ERROR:: Input file {file} does not exist")
sys.exit(-1)


# Function for loading the data
def load_data(file):
data = pd.read_csv(file, sep="\t")
data["Time"] = pd.to_datetime(data["Time"], unit="s")
return data


# Function for checking whether the variables are in the data
def check_variables(data, var, ylist):
if var not in list(data):
print(f"ERROR:: Variable {var} is not available in one of the data sets")
sys.exit(-1)
for carg in ylist:
if carg not in list(data):
print(f"ERROR:: Variable {carg} is not available in one of the data sets")


# This function creates the final data set of y-values
def make_list(ylist, data, args, xmult, ymult, xlabel):
ydlist = []
for carg in ylist:
if args.diff:
num = np.array(data[carg].diff()) * ymult
denom = np.array(data[xlabel].diff()) * xmult
ratio = np.where(denom != 0, num / denom, np.nan)
ydlist.append(ratio)
else:
ydlist.append(np.array(data[carg]) * ymult)
return ydlist


# Graph plotting functions
def draw_stacked_graph(xdata, ydlist, ylist):
ydata = np.vstack(ydlist)
plt.stackplot(
xdata, ydata, lw=2, labels=[LEGENDNAMES[val] for val in ylist], alpha=0.6
)


def draw_line_graph(xdata, ydlist, ylist, sty, inputs, count):
# This is a list of the matplotlib default colours
colours = plt.rcParams["axes.prop_cycle"].by_key()["color"]
for cidx, cdata in enumerate(ydlist):
if len(inputs) == 1:
lbl = LEGENDNAMES[ylist[cidx]]
plt.plot(xdata, cdata, lw=2, label=lbl, color=colours[cidx], linestyle=sty)
else:
lbl = f"{LEGENDNAMES[ylist[cidx]]} ({inputs[count]})"
plt.plot(xdata, cdata, lw=2, label=lbl, color=colours[cidx], linestyle=sty)


def main():
"""prmon plotting main function"""

Expand All @@ -134,7 +192,8 @@ def main():
"--input",
type=str,
default="prmon.txt",
help="PrMon TXT output that will be used as input",
help="PrMon TXT output(s) that will be used as input(s)"
" (comma separated list is accepted)",
)
parser.add_argument(
"--output",
Expand All @@ -160,7 +219,7 @@ def main():
type=str,
default=default_yvar,
help="name(s) of the variable(s) to be plotted in the y-axis"
" (comma seperated list is accepted)",
" (comma separated list is accepted)",
)
parser.add_argument(
"--yunit",
Expand Down Expand Up @@ -193,24 +252,14 @@ def main():
parser.set_defaults(diff=False)
args = parser.parse_args()

# Check the input file exists
if not os.path.exists(args.input):
print(f"ERROR:: Input file {args.input} does not exists")
sys.exit(-1)

# Load the data
data = pd.read_csv(args.input, sep="\t")
data["Time"] = pd.to_datetime(data["Time"], unit="s")

# Check the variables are in data
if args.xvar not in list(data):
print(f"ERROR:: Variable {args.xvar} is not available in data")
sys.exit(-1)
inputs = args.input.split(",")
ylist = args.yvar.split(",")
for carg in ylist:
if carg not in list(data):
print(f"ERROR:: Variable {carg} is not available in data")
sys.exit(-1)
data = []

for i in range(len(inputs)):
check_input_file(inputs[i])
data.append(load_data(inputs[i]))
check_variables(data[i], args.xvar, ylist)

# Check the consistency of variables and units
# If they don't match, reset the units to defaults
Expand Down Expand Up @@ -257,24 +306,31 @@ def main():

# Here comes the figure and data extraction
fig, ax1 = plt.subplots()
xdata = np.array(data[xlabel]) * xmultiplier

xdata = []
ydlist = []
for carg in ylist:
if args.diff:
num = np.array(data[carg].diff()) * ymultiplier
denom = np.array(data[xlabel].diff()) * xmultiplier
ratio = np.where(denom != 0, num / denom, np.nan)
ydlist.append(ratio)
else:
ydlist.append(np.array(data[carg]) * ymultiplier)

for i in range(len(data)):
xdata.append(np.array(data[i][xlabel] * xmultiplier))
ydlist.append(make_list(ylist, data[i], args, xmultiplier, ymultiplier, xlabel))

# Plot the graphs
line_styles = list(mpl.lines.lineStyles.keys())

if args.stacked:
ydata = np.vstack(ydlist)
plt.stackplot(
xdata, ydata, lw=2, labels=[LEGENDNAMES[val] for val in ylist], alpha=0.6
)
if len(inputs) == 1:
for i in range(len(xdata)):
draw_stacked_graph(xdata[i], ydlist[i], ylist)
else:
print("ERROR:: Stacked graphs are not supported for more than one data set")
sys.exit(-1)
else:
for cidx, cdata in enumerate(ydlist):
plt.plot(xdata, cdata, lw=2, label=LEGENDNAMES[ylist[cidx]])
for i in range(len(xdata)):
draw_line_graph(
xdata[i], ydlist[i], ylist, line_styles[i % len(line_styles)], inputs, i
)

# Create the key
plt.legend(loc=0)
if "Time" in xlabel:
formatter = mpl.dates.DateFormatter("%H:%M:%S")
Expand All @@ -290,6 +346,7 @@ def main():
else:
fylabel = get_axis_label(ylist[0])
fyunit = args.yunit

plt.title("Plot of {} vs {}".format(fxlabel, fylabel), y=1.05)
plt.xlabel((fxlabel + " [" + fxunit + "]") if fxunit != "1" else fxlabel)
plt.ylabel((fylabel + " [" + fyunit + "]") if fyunit != "1" else fylabel)
Expand Down