Skip to content

Commit

Permalink
Merge pull request #61 from emilybache/with_tests
Browse files Browse the repository at this point in the history
Python translation
  • Loading branch information
martinsson authored Mar 12, 2023
2 parents 1716e7d + 0fbfe8c commit 57cfc5d
Show file tree
Hide file tree
Showing 7 changed files with 360 additions and 0 deletions.
57 changes: 57 additions & 0 deletions .github/workflows/python-build.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
name: python-build

env:
PROJECT_DIR: python

on:
push:
paths:
- 'python/**'
- '.github/workflows/python-build.yml'
pull_request:
paths:
- 'python/**'
- '.github/workflows/python-build.yml'

jobs:
build:
defaults:
run:
working-directory: ./${{ env.PROJECT_DIR }}

runs-on: ubuntu-22.04

env:
DB_USER: root
DB_OLD_PASSWORD: root
DB_PASSWORD: mysql

steps:
- name: Checkout Repository
uses: actions/checkout@v2

- name: Start MYSQL and import DB
run: |
sudo systemctl start mysql
mysqladmin --user=${{ env.DB_USER }} --password=${{ env.DB_OLD_PASSWORD }} version
mysqladmin --user=${{ env.DB_USER }} --password=${{ env.DB_OLD_PASSWORD }} password ${{ env.DB_PASSWORD }}
mysql -u${{ env.DB_USER }} -p${{ env.DB_PASSWORD }} < ${GITHUB_WORKSPACE}/database/initDatabase.sql
- name: Install MySQL odbc driver
run: |
wget https://repo.mysql.com/apt/ubuntu/pool/mysql-8.0/m/mysql-community/mysql-community-client-plugins_8.0.32-1ubuntu22.04_amd64.deb
sudo dpkg -i mysql-community-client-plugins_8.0.32-1ubuntu22.04_amd64.deb
wget https://dev.mysql.com/get/Downloads/Connector-ODBC/8.0/mysql-connector-odbc_8.0.32-1ubuntu22.04_amd64.deb
sudo dpkg -i mysql-connector-odbc_8.0.32-1ubuntu22.04_amd64.deb
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'

- name: Set up dependencies
run: pip install -r requirements.txt

- name: Test
run: PYTHONPATH=src python -m pytest

3 changes: 3 additions & 0 deletions python/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*/.pytest_cache
venv
**/__pycache__
28 changes: 28 additions & 0 deletions python/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Python version of Lift Pass Pricing Kata

As with the other language versions, this exercise requires a database. There is a description in the [top level README](../README.md) of how to set up MySQL. If you don't have that, this version should fall back on sqlite3, and create a local database file 'lift_pass.db' in the directory where you run the application. Unfortunately the code doesn't actually work properly with sqlite3, so you'll have to adjust the SQL statements in prices.py.

For this python version you will also need to install the dependencies. I recommend you install them in a virtual environment like this:

python -m venv venv

Check the [Python documentation](https://docs.python.org/3/library/venv.html) for how to activate this environment on your platform. Then install the requirements:

python -m pip install -r requirements.txt

You can start the application like this:

cd src
python -m prices

Note there is no webpage on the default url - try this url as an example to check it's running: http://127.0.0.1:3005/prices?type=1jour

You can run the tests with pytest:

PYTHONPATH=src python -m pytest

or on Windows Powershell:

$env:PYTHONPATH='src'; python -m pytest


5 changes: 5 additions & 0 deletions python/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Flask
pytest
requests
pyodbc
PyMySQL
94 changes: 94 additions & 0 deletions python/src/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from pathlib import Path


def create_lift_pass_db_connection(connection_options):
connection_functions = [
try_to_connect_with_odbc,
try_to_connect_with_pymysql,
try_to_connect_with_sqlite3,
]
for fun in connection_functions:
try:
connection = fun(connection_options)
if connection is not None:
return connection
except Exception as e:
print(f"unable to connect to db with {fun}")
raise RuntimeError("Unable to connect to the database.")


def try_to_connect_with_sqlite3(connection_options):
import sqlite3
connection = sqlite3.connect("lift_pass.db")
create_statements = [
"""CREATE TABLE IF NOT EXISTS base_price (
pass_id INTEGER PRIMARY KEY AUTOINCREMENT,
type VARCHAR(255) NOT NULL,
cost INTEGER NOT NULL
);""",
"""INSERT INTO base_price (type, cost) VALUES ('1jour', 35);""",
"""INSERT INTO base_price (type, cost) VALUES ('night', 19);""",
"""CREATE TABLE IF NOT EXISTS holidays (
holiday DATE NOT NULL,
description VARCHAR(255) NOT NULL
);""",
"INSERT INTO holidays (holiday, description) VALUES ('2019-02-18', 'winter');",
"INSERT INTO holidays (holiday, description) VALUES ('2019-02-25', 'winter');",
"INSERT INTO holidays (holiday, description) VALUES ('2019-03-04', 'winter');",
]
for statement in create_statements:
connection.execute(statement)

return connection


def try_to_connect_with_pymysql(connection_options):
import pymysql.cursors

class PyMySQLCursorWrapper(pymysql.cursors.Cursor):
"""
The pymysql.cursors.Cursor class very nearly works the same as the odbc equivalent. Unfortunately it doesn't
understand the '?' in a SQL statement as an argument placeholder, and instead uses '%s'. This wrapper fixes that.
"""
def mogrify(self, query: str, args: object = ...) -> str:
query = query.replace('?', '%s')
return super().mogrify(query, args)

connection = pymysql.connect(host=connection_options["host"],
user=connection_options["user"],
password=connection_options["password"],
database=connection_options["database"],
cursorclass=PyMySQLCursorWrapper)

return connection


def try_to_connect_with_odbc(connection_options):
driver = get_mariadb_driver()
if driver:
import pyodbc
connection_string = make_connection_string_template(driver) % (
connection_options["host"],
connection_options["user"],
connection_options["database"],
connection_options["password"],
)
return pyodbc.connect(connection_string)
return None


def get_mariadb_driver():
import pyodbc
drivers = []
for driver in pyodbc.drivers():
if driver.startswith("MySQL") or driver.startswith("MariaDB"):
drivers.append(driver)

if drivers:
return max(drivers)
else:
return None


def make_connection_string_template(driver):
return 'DRIVER={' + driver + '};SERVER=%s;USER=%s;OPTION=3;DATABASE=%s;PASSWORD=%s'
81 changes: 81 additions & 0 deletions python/src/prices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import math

from flask import Flask
from flask import request
from datetime import datetime
from db import create_lift_pass_db_connection

app = Flask("lift-pass-pricing")

connection_options = {
"host": 'localhost',
"user": 'root',
"database": 'lift_pass',
"password": 'mysql'}

connection = None

@app.route("/prices", methods=['GET', 'PUT'])
def prices():
res = {}
global connection
if connection is None:
connection = create_lift_pass_db_connection(connection_options)
if request.method == 'PUT':
lift_pass_cost = request.args["cost"]
lift_pass_type = request.args["type"]
cursor = connection.cursor()
cursor.execute('INSERT INTO `base_price` (type, cost) VALUES (?, ?) ' +
'ON DUPLICATE KEY UPDATE cost = ?', (lift_pass_type, lift_pass_cost, lift_pass_cost))
return {}
elif request.method == 'GET':
cursor = connection.cursor()
cursor.execute(f'SELECT cost FROM base_price '
+ 'WHERE type = ? ', (request.args['type'],))
row = cursor.fetchone()
result = {"cost": row[0]}
if 'age' in request.args and request.args.get('age', type=int) < 6:
res["cost"] = 0
else:
if "type" in request.args and request.args["type"] != "night":
cursor = connection.cursor()
cursor.execute('SELECT * FROM holidays')
is_holiday = False
reduction = 0
for row in cursor.fetchall():
holiday = row[0]
if "date" in request.args:
d = datetime.fromisoformat(request.args["date"])
if d.year == holiday.year and d.month == holiday.month and holiday.day == d.day:
is_holiday = True
if not is_holiday and "date" in request.args and datetime.fromisoformat(request.args["date"]).weekday() == 0:
reduction = 35

# TODO: apply reduction for others
if 'age' in request.args and request.args.get('age', type=int) < 15:
res['cost'] = math.ceil(result["cost"]*.7)
else:
if 'age' not in request.args:
cost = result['cost'] * (1 - reduction/100)
res['cost'] = math.ceil(cost)
else:
if 'age' in request.args and request.args.get('age', type=int) > 64:
cost = result['cost'] * .75 * (1 - reduction / 100)
res['cost'] = math.ceil(cost)
elif 'age' in request.args:
cost = result['cost'] * (1 - reduction / 100)
res['cost'] = math.ceil(cost)
else:
if 'age' in request.args and request.args.get('age', type=int) >= 6:
if request.args.get('age', type=int) > 64:
res['cost'] = math.ceil(result['cost'] * .4)
else:
res.update(result)
else:
res['cost'] = 0

return res


if __name__ == "__main__":
app.run(port=3005)
92 changes: 92 additions & 0 deletions python/test/test_prices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import multiprocessing

import pytest
import requests
from datetime import datetime
import time

from prices import app

TEST_PORT = 3006


def server(port):
app.run(port=port)


def wait_for_server_to_start(server_url):
started = False
while not started:
try:
requests.get(server_url)
started = True
except Exception as e:
time.sleep(0.2)


@pytest.fixture(autouse=True, scope="session")
def lift_pass_pricing_app():
""" starts the lift pass pricing flask app running on localhost """
p = multiprocessing.Process(target=server, args=(TEST_PORT,))
p.start()
server_url = f"http://127.0.0.1:{TEST_PORT}"
wait_for_server_to_start(server_url)
yield server_url
p.terminate()


def test_put_1jour_price(lift_pass_pricing_app):
response = requests.put(lift_pass_pricing_app + '/prices', params={'type': '1jour', 'cost': 35})
assert response.status_code == 200


def test_put_night_price(lift_pass_pricing_app):
response = requests.put(lift_pass_pricing_app + '/prices', params={'type': 'night', 'cost': 19})
assert response.status_code == 200


def test_default_cost(lift_pass_pricing_app):
response = requests.get(lift_pass_pricing_app + "/prices", params={'type': '1jour'})
assert response.json() == {'cost': 35}


@pytest.mark.parametrize(
"age,expectedCost", [
(5, 0),
(6, 25),
(14, 25),
(15, 35),
(25, 35),
(64, 35),
(65, 27),
])
def test_works_for_all_ages(lift_pass_pricing_app, age, expectedCost):
response = requests.get(lift_pass_pricing_app + "/prices", params={'type': '1jour', 'age': age})
assert response.json() == {'cost': expectedCost}


@pytest.mark.parametrize(
"age,expectedCost", [
(5, 0),
(6, 19),
(25, 19),
(64, 19),
(65, 8),
])
def test_works_for_night_passes(lift_pass_pricing_app, age, expectedCost):
response = requests.get(lift_pass_pricing_app + "/prices", params={'type': 'night', 'age': age})
assert response.json() == {'cost': expectedCost}


@pytest.mark.parametrize(
"age,expectedCost,ski_date", [
(15, 35, datetime.fromisoformat('2019-02-22')),
(15, 35, datetime.fromisoformat('2019-02-25')), # monday, holiday
(15, 23, datetime.fromisoformat('2019-03-11')), # monday
(65, 18, datetime.fromisoformat('2019-03-11')), # monday
])
def test_works_for_monday_deals(lift_pass_pricing_app, age, expectedCost, ski_date):
response = requests.get(lift_pass_pricing_app + "/prices", params={'type': '1jour', 'age': age, 'date': ski_date})
assert response.json() == {'cost': expectedCost}

# TODO 2-4, and 5, 6 day pass

0 comments on commit 57cfc5d

Please sign in to comment.