Closes #1689 - Adds support for DataFrame to generic attach (#1798)
* Rebase for json argument updates

* DataFrame Generic Attach implementation

* Removing unused imports and flake8 fixes

* Updating based on review comments and segarray updates

Co-authored-by: joshmarshall1 <[email protected]>
joshmarshall1 authored Oct 13, 2022
1 parent 12b5055 commit 27df401
Showing 5 changed files with 261 additions and 13 deletions.
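
In short: once a DataFrame has been registered under a single name, the generic ak.util.attach entry point can now rebuild it in one call. A minimal sketch of the new workflow, modeled on the test added in tests/util_test.py (column values are illustrative, and a reachable arkouda server is assumed):

import arkouda as ak

ak.connect()  # assumes a running arkouda server

df = ak.DataFrame({
    "user_ID": ak.array([111, 222, 111]),
    "amount": ak.array([0.5, 0.6, 1.1]),
})
df.register("df_test")           # registers the index, column labels, and each column

df2 = ak.util.attach("df_test")  # generic attach now detects and rebuilds the DataFrame
assert df.to_pandas().equals(df2.to_pandas())
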
107 changes: 101 additions & 6 deletions arkouda/dataframe.py
@@ -1891,7 +1891,7 @@ def register(self, user_defined_name: str) -> DataFrame:
array(self.columns).register(f"df_columns_{user_defined_name}")

for col, data in self.data.items():
data.register(f"df_data_{col}_{data.objtype}_{user_defined_name}")
data.register(f"df_data_{data.objtype}_{col}_{user_defined_name}")

self.name = user_defined_name
return self
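
The only change to register is the layout of the per-column names: the object type now precedes the column name (df_data_<objtype>_<col>_<name> instead of df_data_<col>_<objtype>_<name>), so the attach side can key on a fixed set of type tokens. Roughly, registering a DataFrame as "df_test" leaves entries like the following in the registry (a hypothetical listing; the index and column-label entries are unchanged by this commit):

# Hypothetical list_registry() contents after df.register("df_test"); the objtype
# tokens follow the server-side pattern (pdarray, str, SegArray, Categorical).
registered_names = [
    "df_columns_df_test",              # column labels (Strings)
    "df_index_df_test_key",            # the index
    "df_data_pdarray_amount_df_test",  # previously df_data_amount_pdarray_df_test
    "df_data_str_username_df_test",
]
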
@@ -2003,8 +2003,8 @@ def attach(user_defined_name: str) -> DataFrame:
columns = dict.fromkeys(json.loads(col_resp))
matches = []
regEx = compile(
f"^df_data_[a-zA-Z0-9]+_({pdarray.objtype}|{Strings.objtype}|"
f"{Categorical.objtype}|{SegArray.objtype})_{user_defined_name}"
f"^df_data_({pdarray.objtype}|{Strings.objtype}|"
f"{Categorical.objtype}|{SegArray.objtype})_.*_{user_defined_name}"
)
# Using the regex, cycle through the registered items and find all the columns in the DataFrame
for name in list_registry():
@@ -2018,7 +2018,7 @@ def attach(user_defined_name: str) -> DataFrame:
# Remove duplicates caused by multiple components in Categorical or SegArray and
# loop through
for name in set(matches):
colName = name.split("_")[2]
colName = name.split("_")[3]
if f"_{Strings.objtype}_" in name or f"_{pdarray.objtype}_" in name:
cols_resp = cast(str, generic_msg(cmd="attach", args={"name": name}))
dtype = cols_resp.split()[2]
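
On the attach side, the regex is updated to the new layout and the column name moves from the third to the fourth "_"-separated token. A standalone sketch of the same matching with Python's re module (type alternatives copied from the server-side pattern in RegistrationMsg.chpl; registry names are illustrative):

import re

pattern = re.compile(r"^df_data_(pdarray|str|SegArray|Categorical)_.*_df_test")
registry = [
    "df_columns_df_test",
    "df_index_df_test_key",
    "df_data_pdarray_amount_df_test",
    "df_data_str_username_df_test",
]
matches = [name for name in registry if pattern.match(name)]
# The column name is now the fourth "_"-separated token.
print([name.split("_")[3] for name in matches])  # ['amount', 'username']
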
@@ -2072,8 +2072,8 @@ def unregister_dataframe_by_name(user_defined_name: str) -> None:

matches = []
regEx = compile(
f"^df_data_[a-zA-Z0-9]+_({pdarray.objtype}|{Strings.objtype}|"
f"{Categorical.objtype}|{SegArray.objtype})_{user_defined_name}"
f"^df_data_({pdarray.objtype}|{Strings.objtype}|"
f"{Categorical.objtype}|{SegArray.objtype})_.*_{user_defined_name}"
)
# Using the regex, cycle through the registered items and find all the columns in the DataFrame
for name in list_registry():
@@ -2102,6 +2102,101 @@ def unregister_dataframe_by_name(user_defined_name: str) -> None:
unregister_pdarray_by_name(f"df_index_{user_defined_name}_key")
Strings.unregister_strings_by_name(f"df_columns_{user_defined_name}")

@staticmethod
def _parse_col_name(entryName, dfName):
"""
Helper method used by from_return_msg to parse the registered name of the data component
and pull out the column type and column name
Parameters
----------
entryName : string
The full registered name of the data component
dfName : string
The name of the DataFrame
Returns
-------
Tuple (columnName, columnType)
"""
regName = entryName.split(" ")[1]
colParts = regName.split("_")
colType = colParts[2]

# Case of '_' in the column or dataframe name
if len(colParts) > 5:
nameInd = regName.rindex(dfName) - 1
startInd = len(colType) + 9
return regName[startInd:nameInd], colType
else:
return colParts[3], colType

@staticmethod
def from_return_msg(repMsg):
"""
Creates and returns a DataFrame based on return components from ak.util.attach
Parameters
----------
repMsg : string
A '+' delimited string of the DataFrame components to parse.
Returns
-------
DataFrame
A DataFrame representing a set of DataFrame components on the server
Raises
------
RuntimeError
Raised if a server-side error is thrown in the process of creating
the DataFrame instance
"""
parts = repMsg.split("+")
dfName = parts[1]
cols = dict.fromkeys(json.loads(parts[2][4:]))

# index could be a pdarray or a Strings
idxType = parts[3].split()[2]
if idxType == Strings.objtype:
idx = Index.factory(Strings.from_return_msg(f"{parts[3]}+{parts[4]}"))
i = 5
else: # pdarray
idx = Index.factory(create_pdarray(parts[3]))
i = 4

# Column parsing
while i < len(parts):
if parts[i][:7] == "created":
colName, colType = DataFrame._parse_col_name(parts[i], dfName)
if colType == "pdarray":
cols[colName] = create_pdarray(parts[i])
else:
cols[colName] = Strings.from_return_msg(f"{parts[i]}+{parts[i+1]}")
i += 1

elif parts[i] == "categorical":
colName = DataFrame._parse_col_name(parts[i + 1], dfName)[0]
catMsg = (
f"{parts[i]}+{parts[i+1]}+{parts[i+2]}+{parts[i+3]}+"
f"{parts[i+4]}+{parts[i+5]}+{parts[i+6]}"
)
cols[colName] = Categorical.from_return_msg(catMsg)
i += 6

elif parts[i] == "segarray":
colName = DataFrame._parse_col_name(parts[i + 1], dfName)[0]
segMsg = f"{parts[i]}+{parts[i+1]}+{parts[i+2]}+{parts[i+3]}"
cols[colName] = SegArray._from_attach_return_msg(segMsg)
i += 3

i += 1

df = DataFrame(cols, idx)
df.name = dfName
return df
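
Putting the pieces together: the reply consumed by from_return_msg is a "+"-delimited string assembled by the server from the literal tag "dataframe", the DataFrame name, a json list of column labels, the index attach message, and one attach message per column. A worked example with a single pdarray column (field values such as "float64 6 1 (6) 8" are illustrative of the server's "created ..." reply format):

import json
from arkouda.dataframe import DataFrame

rep_msg = (
    "dataframe+df_test"
    '+json ["amount"]'
    "+created df_index_df_test_key int64 6 1 (6) 8"               # index (pdarray)
    "+created df_data_pdarray_amount_df_test float64 6 1 (6) 8"   # one column
)
parts = rep_msg.split("+")
print(parts[1])                                       # df_test
print(dict.fromkeys(json.loads(parts[2][4:])))        # {'amount': None}
print(parts[3].split()[2])                            # int64 -> pdarray index branch
print(DataFrame._parse_col_name(parts[4], parts[1]))  # ('amount', 'pdarray')
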


def sorted(df, column=False):
"""
11 changes: 8 additions & 3 deletions arkouda/util.py
@@ -282,15 +282,20 @@ def attach(name: str, dtype: str = "infer"):
to pull the corresponding parts, otherwise the server will try to infer the type
"""
repMsg = cast(str, generic_msg(cmd="genericAttach", args={"dtype": dtype, "name": name}))
repType = repMsg.split("+")[0]

if repMsg.split("+")[0] == "categorical":
if repType == "categorical":
return Categorical.from_return_msg(repMsg)
elif repMsg.split("+")[0] == "segarray":
elif repType == "segarray":
return SegArray._from_attach_return_msg(repMsg)
elif repMsg.split("+")[0] == "series":
elif repType == "series":
from arkouda.series import Series

return Series.from_return_msg(repMsg)
elif repType == "dataframe":
from arkouda.dataframe import DataFrame

return DataFrame.from_return_msg(repMsg)
else:
dtype = repMsg.split()[2]

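
In util.attach, the first "+"-delimited token of the reply now routes "dataframe" replies to DataFrame.from_return_msg alongside the existing categorical, segarray, and series branches. Per the docstring above, the type can also be supplied up front instead of inferred; a hedged example, assuming the server accepts the same "dataframe" string when it is passed explicitly:

import arkouda as ak

ak.connect()  # assumes a running server with "df_test" already registered

df_inferred = ak.util.attach("df_test")                     # server infers "dataframe"
df_explicit = ak.util.attach("df_test", dtype="dataframe")  # skips inference (assumed accepted)
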
6 changes: 4 additions & 2 deletions src/MultiTypeSymbolTable.chpl
@@ -483,8 +483,10 @@ module MultiTypeSymbolTable
var regex = compile(pattern);
var infoStr = "";
forall name in tab.keysToArray() with (+ reduce infoStr) {
if regex.match(name) {
infoStr += name + "+";
var match = regex.match(name);
if match.matched {
var end : int = (match.byteOffset: int) + match.numBytes;
infoStr += name[match.byteOffset..#end] + "+";
}
}
return infoStr.strip("+").split("+");
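
The findAll tweak matters for multi-component columns: Categorical and SegArray register several entries per column, and returning only the matched span of each name (which the caller then de-duplicates with a set) collapses them to one base name per column. A rough Python equivalent of that behavior (the ".categories"/".codes" component suffixes are assumed for illustration):

import re

pattern = re.compile(r"df_data_(pdarray|str|SegArray|Categorical)_.*_df_test")
registry = [
    "df_data_pdarray_amount_df_test",
    "df_data_Categorical_username_df_test.categories",  # hypothetical component names
    "df_data_Categorical_username_df_test.codes",
]
# Keep only the matched span of each name, then de-duplicate, so both
# Categorical components collapse to one base name.
matched = {m.group(0) for name in registry if (m := pattern.search(name))}
print(sorted(matched))
# ['df_data_Categorical_username_df_test', 'df_data_pdarray_amount_df_test']
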
121 changes: 119 additions & 2 deletions src/RegistrationMsg.chpl
@@ -9,11 +9,13 @@ module RegistrationMsg
use Logging;
use Message;
use List;
use Set;

use MultiTypeSymbolTable;
use MultiTypeSymEntry;
use ServerErrorStrings;
use SegmentedString;
use SegmentedMsg;

private config const logLevel = ServerConfig.logLevel;
const regLogger = new Logger(logLevel);
@@ -278,6 +280,115 @@ module RegistrationMsg
return new MsgTuple(repMsg, MsgType.NORMAL);
}

/*
Compile the component parts of a DataFrame attach message
:arg cmd: calling command
:type cmd: string
:arg payload: name of SymTab element
:type payload: string
:arg argSize: number of arguments in payload
:type argSize: int
:arg st: SymTab to act on
:type st: borrowed SymTab
:returns: MsgTuple response message
*/
proc attachDataFrameMsg(cmd: string, payload: string, argSize: int,
st: borrowed SymTab): MsgTuple throws {
var msgArgs = parseMessageArgs(payload, argSize);
const name = msgArgs.getValueOf("name");
var colName = "df_columns_%s".format(name);
var repMsg = "dataframe+%s".format(name);

regLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
"%s: Collecting DataFrame components for '%s'".format(cmd, name));

var jsonParam = new ParameterObj("name", colName, ObjectType.VALUE, "str");
var json: [0..#1] string = [jsonParam.getJSON()];

// Add columns as a json list
var cols = stringsToJSONMsg(cmd, "%jt".format(json), json.size, st).msg;
repMsg += "+json %s".format(cols);

// Get index
var indParam = new ParameterObj("name", "df_index_%s_key".format(name), ObjectType.VALUE, "");
var indJSON: [0..#1] string = [indParam.getJSON()];
var ind = attachMsg(cmd, "%jt".format(indJSON), indJSON.size, st).msg;
if ind.startsWith("Error:") {
var errorMsg = ind;
regLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
repMsg += "+%s".format(ind);

// Get column data
var nameList = st.findAll("df_data_(pdarray|str|SegArray|Categorical)_.*_%s".format(name));

if nameList.size == 1 && nameList[0] == "" {
var errorMsg = "No data values found for DataFrame %s".format(name);
regLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}

// Convert nameList to a Set to get unique values
var u : set(string) = new set(string, nameList);

regLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
"%s: Data components found for dataframe: '%jt'".format(cmd, u));

// Use existing attach functionality to build the response message based on the objType of each data column
forall regName in u with (+ reduce repMsg) {
var parts = regName.split("_");
var objtype: string = parts[2];
var msg: string;
select (objtype){
when ("pdarray") {
var attParam = new ParameterObj("name", regName, ObjectType.VALUE, "");
var attJSON: [0..#1] string = [attParam.getJSON()];
msg = attachMsg(cmd, "%jt".format(attJSON), attJSON.size, st).msg;
}
when ("str") {
var attParam = new ParameterObj("name", regName, ObjectType.VALUE, "");
var attJSON: [0..#1] string = [attParam.getJSON()];
msg = attachMsg(cmd, "%jt".format(attJSON), attJSON.size, st).msg;
}
when ("SegArray") {
msg = attachSegArrayMsg(cmd, regName, st).msg;
}
when ("Categorical") {
msg = attachCategoricalMsg(cmd, regName, st).msg;
}
otherwise {
regLogger.warn(getModuleName(),getRoutineName(),getLineNumber(),
"Unsupported column type found in DataFrame: '%s'. \
Supported types are: pdarray, str, Categorical, and SegArray".format(objtype));

throw getErrorWithContext(
msg="Unknown column type (%s) found in DataFrame: %s".format(objtype, name),
lineNumber=getLineNumber(),
routineName=getRoutineName(),
moduleName=getModuleName(),
errorClass="ValueError"
);
}
}

if (msg.startsWith("Error:")) {
regLogger.error(getModuleName(),getRoutineName(),getLineNumber(),msg);
repMsg = msg;
} else {
repMsg += "+%s".format(msg);
}
}

var msgType = if repMsg.startsWith("Error:") then MsgType.ERROR else MsgType.NORMAL;
return new MsgTuple(repMsg, msgType);
}

/*
Attempt to determine the type of object based on a given name
@@ -307,6 +418,8 @@ module RegistrationMsg
dtype = "segarray";
} else if st.contains("%s_value".format(name)) && (st.contains("%s_key".format(name)) || st.contains("%s_key_0".format(name))) {
dtype = "series";
} else if st.contains("df_columns_%s".format(name)) && (st.contains("df_index_%s_key".format(name))) {
dtype = "dataframe";
} else {
throw getErrorWithContext(
msg="Unable to determine type for given name: %s".format(name),
@@ -351,9 +464,10 @@ module RegistrationMsg
dtype = "simple";
}

var json: [0..#1] string = [msgArgs.get("name").getJSON()];

select (dtype.toLower()) {
when ("simple") {
var json: [0..#1] string = [msgArgs.get("name").getJSON()];
// pdarray and strings can use the attachMsg method
return attachMsg(cmd, "%jt".format(json), json.size, st);
}
@@ -366,9 +480,12 @@ module RegistrationMsg
when ("series") {
return attachSeriesMsg(cmd, name, st);
}
when ("dataframe") {
return attachDataFrameMsg(cmd, "%jt".format(json), json.size, st);
}
otherwise {
regLogger.warn(getModuleName(),getRoutineName(),getLineNumber(),
"Unsupported type provided: '%s'. Supported types are: pdarray, strings, categorical, segarray, and series".format(dtype));
"Unsupported type provided: '%s'. Supported types are: pdarray, strings, categorical, segarray, series, and dataframe".format(dtype));

throw getErrorWithContext(
msg="Unknown type (%s) supplied for given name: %s".format(dtype, name),
29 changes: 29 additions & 0 deletions tests/util_test.py
@@ -82,6 +82,35 @@ def test_series_attach(self):
s2Eq = s2_attach.index == s2.index
self.assertTrue(all(s2Eq.to_ndarray()))

def test_dataframe_attach(self):
ind = ak.Index.factory(ak.array(["a", "b", "c", "1", "2", "3"]))
item = ak.array([0, 0, 1, 1, 2, 0])
amount = ak.array([0.5, 0.6, 1.1, 1.2, 4.3, 0.6])
userid = ak.array([111, 222, 111, 333, 222, 111])
username = ak.array(["Alice", "Bob", "Alice", "Carol", "Bob", "Alice"])
cat = ak.Categorical(username)

a = [2, 1, 2, 3]
b = [2, 2, 3]
c = [2, 1, 2]
d = [2, 1, 3]
e = [1, 2, 3]
f = [2, 3, 1]

flat = a + b + c + d + e + f
akflat = ak.array(flat)
segments = ak.array([0, 4, 7, 10, 13, 16])
segarr = ak.segarray(segments, akflat)

df = ak.DataFrame(
{"username": cat, "user_ID": userid, "item": item, "amount": amount, "visits": segarr}, ind
)
df.register("df_test")

dfAtt = ak.util.attach("df_test")

self.assertTrue(df.to_pandas().equals(dfAtt.to_pandas()))

def test_unregister_by_name(self):
# Register the four supported object types
# pdarray
