Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix the error of pylint in demo directory #1900

Merged
merged 7 commits into from
Jun 8, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions demo/predict-taxi-trip-duration/script/convert_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module of covert data from system stdin"""
import sys
import time, datetime
import time

i = 0
for line in sys.stdin:
if i == 0:
i+=1
i += 1
print(line.strip())
continue
arr = line.strip().split(",")
arr[2] = str(int(time.mktime(time.strptime(arr[2], "%Y-%m-%d %H:%M:%S"))) * 1000)
arr[3] = str(int(time.mktime(time.strptime(arr[3], "%Y-%m-%d %H:%M:%S"))) * 1000)
print(",".join(arr))

29 changes: 12 additions & 17 deletions demo/predict-taxi-trip-duration/script/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module of request predict in script"""
import requests
import os
import base64
import random
import time
import hashlib

url = "http://127.0.0.1:8887/predict"
req ={"id":"id0376262",
"vendor_id":1,
"pickup_datetime":1467302350000,
"dropoff_datetime":1467304896000,
"passenger_count":2,
"pickup_longitude":-73.873093,
"pickup_latitude":40.774097,
"dropoff_longitude":-73.926704,
"dropoff_latitude":40.856739,
"store_and_fwd_flag":"N",
"trip_duration":1}
req = {"id": "id0376262",
"vendor_id": 1,
"pickup_datetime": 1467302350000,
"dropoff_datetime": 1467304896000,
"passenger_count": 2,
"pickup_longitude": -73.873093,
"pickup_latitude": 40.774097,
"dropoff_longitude": -73.926704,
"dropoff_latitude": 40.856739,
"store_and_fwd_flag": "N",
"trip_duration": 1}
r = requests.post(url, json=req)
print(r.text)
55 changes: 32 additions & 23 deletions demo/predict-taxi-trip-duration/script/predict_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,52 +14,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module of predict server"""
import numpy as np
import tornado.web
import tornado.ioloop
import json
import lightgbm as lgb
import sqlalchemy as db
import requests
import argparse

bst = None

table_schema = [
("id", "string"),
("vendor_id", "int"),
("pickup_datetime", "timestamp"),
("dropoff_datetime", "timestamp"),
("passenger_count", "int"),
("pickup_longitude", "double"),
("pickup_latitude", "double"),
("dropoff_longitude", "double"),
("dropoff_latitude", "double"),
("store_and_fwd_flag", "string"),
("trip_duration", "int"),
("id", "string"),
("vendor_id", "int"),
("pickup_datetime", "timestamp"),
("dropoff_datetime", "timestamp"),
("passenger_count", "int"),
("pickup_longitude", "double"),
("pickup_latitude", "double"),
("dropoff_longitude", "double"),
("dropoff_latitude", "double"),
("store_and_fwd_flag", "string"),
("trip_duration", "int"),
]

url = ""


def get_schema():
dict_schema = {}
dict_schema_tmp = {}
for i in table_schema:
dict_schema[i[0]] = i[1]
return dict_schema
dict_schema_tmp[i[0]] = i[1]
return dict_schema_tmp


dict_schema = get_schema()
json_schema = json.dumps(dict_schema)


def build_feature(rs):
var_Y = [rs[0]]
var_X = [rs[1:12]]
return np.array(var_X)
var_x = [rs[1:12]]
return np.array(var_x)


class SchemaHandler(tornado.web.RequestHandler):
def get(self):
self.write(json_schema)


class PredictHandler(tornado.web.RequestHandler):
"""Class of PredictHandler docstring."""
def post(self):
row = json.loads(self.request.body)
data = {}
Expand All @@ -72,7 +77,8 @@ def post(self):
row_data.append(row.get(i[0], 0))
else:
row_data.append(None)
data["input"].append(row_data)

data["input"].append(row_data)
rs = requests.post(url, json=data)
result = json.loads(rs.text)
for r in result["data"]["data"]:
Expand All @@ -81,25 +87,28 @@ def post(self):
self.write(str(ins) + "\n")
duration = bst.predict(ins)
self.write("---------------predict trip_duration -------------\n")
self.write("%s s"%str(duration[0]))
self.write(f"{str(duration[0])} s")


class MainHandler(tornado.web.RequestHandler):
def get(self):
self.write("real time execute sparksql demo")


def make_app():
return tornado.web.Application([
(r"/", MainHandler),
(r"/schema", SchemaHandler),
(r"/predict", PredictHandler),
])


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("endpoint", help="specify the endpoint of apiserver")
parser.add_argument("model_path", help="specify the model path")
parser.add_argument("endpoint", help="specify the endpoint of apiserver")
parser.add_argument("model_path", help="specify the model path")
args = parser.parse_args()
url = "http://%s/dbs/demo_db/deployments/demo" % args.endpoint
url = f"http://{args.endpoint}/dbs/demo_db/deployments/demo"
bst = lgb.Booster(model_file=args.model_path)
app = make_app()
app.listen(8887)
Expand Down
19 changes: 9 additions & 10 deletions demo/predict-taxi-trip-duration/script/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module of train and save model"""
import lightgbm as lgb
import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("feature_path", help="specify the feature path")
parser.add_argument("model_path", help="specify the model path")
parser.add_argument('feature_path', help='specify the feature path')
parser.add_argument('model_path', help='specify the model path')
args = parser.parse_args()

feature_path = args.feature_path
# merge file
if os.path.isdir(feature_path):
path_list = os.listdir(feature_path)
new_file = "/tmp/merged_feature.csv"
with open(new_file, 'w') as wf:
new_file = '/tmp/merged_feature.csv'
with open(new_file, 'w', encoding='utf-8') as wf:
has_write_header = False
for filename in path_list:
if filename == "_SUCCESS" or filename.startswith('.'):
if filename == '_SUCCESS' or filename.startswith('.'):
continue
with open(os.path.join(feature_path, filename), 'r') as f:
with open(os.path.join(feature_path, filename), 'r', encoding='utf-8') as f:
first_line = True
for line in f.readlines():
if first_line is True:
Expand All @@ -50,7 +49,7 @@
feature_path = new_file

# run batch sql and get instances
df = pd.read_csv(feature_path);
df = pd.read_csv(feature_path)
train_set, predict_set = train_test_split(df, test_size=0.2)
y_train = train_set['trip_duration']
x_train = train_set.drop(columns=['trip_duration'])
Expand Down Expand Up @@ -83,4 +82,4 @@
early_stopping_rounds=5)

gbm.save_model(args.model_path)
print("save model.txt done")
print('save model.txt done')
26 changes: 13 additions & 13 deletions demo/predict-taxi-trip-duration/test/import.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
"""
"""Module of insert data to table"""
import sqlalchemy as db


import sys
import datetime

ddl="""
ddl = """
create table t1(
id string,
vendor_id int,
Expand All @@ -42,26 +39,29 @@
engine = db.create_engine('openmldb:///db_test?zk=127.0.0.1:2181&zkPath=/openmldb')
connection = engine.connect()
try:
connection.execute("create database db_test;");
connection.execute('create database db_test;')
except Exception as e:
aceforeverd marked this conversation as resolved.
Show resolved Hide resolved
print(e)
try:
connection.execute(ddl);
connection.execute(ddl)
except Exception as e:
print(e)


def insert_row(line):
row = line.split(',')
row[2] = '%dl'%int(datetime.datetime.strptime(row[2], '%Y-%m-%d %H:%M:%S').timestamp() * 1000)
row[3] = '%dl'%int(datetime.datetime.strptime(row[3], '%Y-%m-%d %H:%M:%S').timestamp() * 1000)
insert = "insert into t1 values('%s', %s, %s, %s, %s, %s, %s, %s, %s, '%s', %s);"% tuple(row)
row[2] = f"{int(datetime.datetime.strptime(row[2], '%Y-%m-%d %H:%M:%S').timestamp() * 1000)}l"
row[3] = f"{int(datetime.datetime.strptime(row[3], '%Y-%m-%d %H:%M:%S').timestamp() * 1000)}l"
insert = f"insert into t1 values('{row[0]}', {row[1]}, {row[2]}, {row[3]}, {row[4]}, {row[5]}, " \
f"{row[6]}, {row[7]}, {row[8]}, '{row[9]}', {row[10]});"
connection.execute(insert)

with open('data/taxi_tour_table_train_simple.csv', 'r') as fd:

with open('data/taxi_tour_table_train_simple.csv', 'r', encoding='utf-8') as fd:
idx = 0
for line in fd:
for csv_line in fd:
if idx == 0:
idx = idx + 1
continue
insert_row(line.replace('\n', ''))
insert_row(csv_line.replace('\n', ''))
idx = idx + 1
2 changes: 1 addition & 1 deletion demo/talkingdata-adtracking-fraud-detection/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module of request predict in talkingdata-adtracking-fraud-detection"""
import requests

url = "http://127.0.0.1:8881/predict"
Expand Down
Loading