aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMax Meinhold <mxmeinhold@gmail.com>2021-03-28 00:17:55 -0400
committerMax Meinhold <mxmeinhold@gmail.com>2021-04-01 22:25:08 -0400
commit73e55ac8b0f2e58de681afac55ea5d38507c609e (patch)
tree5deed7ec23f678a984cac70e3cdd1d0ca804207a
parentc6af9137de15bbe20362033b9e7e7ffecca055f1 (diff)
Add type hints and mypy
-rw-r--r--.github/workflows/python-app.yml24
-rw-r--r--README.md9
-rw-r--r--packet/__init__.py2
-rw-r--r--packet/commands.py26
-rw-r--r--packet/context_processors.py15
-rw-r--r--packet/git.py6
-rw-r--r--packet/ldap.py55
-rw-r--r--packet/log_utils.py16
-rw-r--r--packet/mail.py12
-rw-r--r--packet/models.py104
-rw-r--r--packet/notifications.py36
-rw-r--r--packet/stats.py45
-rw-r--r--packet/utils.py32
-rw-r--r--requirements.txt10
-rw-r--r--setup.cfg2
15 files changed, 236 insertions, 158 deletions
diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml
index 763c63e..3934f40 100644
--- a/.github/workflows/python-app.yml
+++ b/.github/workflows/python-app.yml
@@ -32,3 +32,27 @@ jobs:
- name: Lint with pylint
run: |
pylint packet
+
+ typecheck:
+ runs-on: ubuntu-latest
+
+ strategy:
+ matrix:
+ python-version: [3.9]
+
+ steps:
+ - name: Install ldap dependencies
+ run: sudo apt-get update && sudo apt-get install libldap2-dev libsasl2-dev
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+ - name: Typecheck with mypy
+ run: |
+ # Disabled error codes to discard errors from imports
+ mypy --disable-error-code import --disable-error-code name-defined --disallow-untyped-defs --exclude routes packet
diff --git a/README.md b/README.md
index 72355f3..fb398ba 100644
--- a/README.md
+++ b/README.md
@@ -115,13 +115,14 @@ All DB commands are from the `Flask-Migrate` library and are used to configure D
docs [here](https://flask-migrate.readthedocs.io/en/latest/) for details.
## Code standards
-This project is configured to use Pylint. Commits will be pylinted by GitHub actions and if the score drops your build will
-fail blocking you from merging. To make your life easier just run it before making a PR.
+This project is configured to use Pylint and mypy. Commits will be pylinted and typechecked by GitHub actions and if the
+score drops your build will fail blocking you from merging. To make your life easier just run it before making a PR.
-To run pylint use this command:
+To run pylint and mypy use these commands:
```bash
pylint packet/routes packet
+mypy --disable-error-code import --disable-error-code name-defined --disallow-untyped-defs --exclude routes packet
```
All python files should have a top-level docstring explaining the contents of the file and complex functions should
-have docstrings explaining any non-obvious portions.
+have docstrings explaining any non-obvious portions. Functions should have type annotations.
diff --git a/packet/__init__.py b/packet/__init__.py
index ac6abd0..898429c 100644
--- a/packet/__init__.py
+++ b/packet/__init__.py
@@ -21,7 +21,7 @@ from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration
from .git import get_version
-app = Flask(__name__)
+app: Flask = Flask(__name__)
gzip = Gzip(app)
# Load default configuration and any environment variable overrides
diff --git a/packet/commands.py b/packet/commands.py
index 32dac8e..ea3591a 100644
--- a/packet/commands.py
+++ b/packet/commands.py
@@ -5,7 +5,7 @@ Defines command-line utilities for use with packet
import sys
from secrets import token_hex
-from datetime import datetime, time
+from datetime import datetime, time, date
import csv
import click
@@ -15,7 +15,7 @@ from .utils import sync_freshman, create_new_packets, sync_with_ldap
@app.cli.command('create-secret')
-def create_secret():
+def create_secret() -> None:
"""
Generates a securely random token. Useful for creating a value for use in the "SECRET_KEY" config setting.
"""
@@ -28,13 +28,13 @@ packet_end_time = time(hour=21)
class CSVFreshman:
- def __init__(self, row):
+ def __init__(self, row: list[str]) -> None:
self.name = row[0].strip()
self.rit_username = row[3].strip()
self.onfloor = row[1].strip() == 'TRUE'
-def parse_csv(freshmen_csv):
+def parse_csv(freshmen_csv: str) -> dict[str, CSVFreshman]:
print('Parsing file...')
try:
with open(freshmen_csv, newline='') as freshmen_csv_file:
@@ -44,7 +44,7 @@ def parse_csv(freshmen_csv):
raise e
-def input_date(prompt):
+def input_date(prompt: str) -> date:
while True:
try:
date_str = input(prompt + ' (format: MM/DD/YYYY): ')
@@ -55,7 +55,7 @@ def input_date(prompt):
@app.cli.command('sync-freshmen')
@click.argument('freshmen_csv')
-def sync_freshmen(freshmen_csv):
+def sync_freshmen(freshmen_csv: str) -> None:
"""
Updates the freshmen entries in the DB to match the given CSV.
"""
@@ -68,7 +68,7 @@ def sync_freshmen(freshmen_csv):
@app.cli.command('create-packets')
@click.argument('freshmen_csv')
-def create_packets(freshmen_csv):
+def create_packets(freshmen_csv: str) -> None:
"""
Creates a new packet season for each of the freshmen in the given CSV.
"""
@@ -84,7 +84,7 @@ def create_packets(freshmen_csv):
@app.cli.command('ldap-sync')
-def ldap_sync():
+def ldap_sync() -> None:
"""
Updates the upper and misc sigs in the DB to match ldap.
"""
@@ -97,7 +97,7 @@ def ldap_sync():
help='The file to write to. If no file provided, output is sent to stdout.')
@click.option('--csv/--no-csv', 'use_csv', required=False, default=False, help='Format output as comma separated list.')
@click.option('--date', 'date_str', required=False, default='', help='Packet end date in the format MM/DD/YYYY.')
-def fetch_results(file_path, use_csv, date_str):
+def fetch_results(file_path: str, use_csv: bool, date_str: str) -> None:
"""
Fetches and prints the results from a given packet season.
"""
@@ -150,7 +150,7 @@ def fetch_results(file_path, use_csv, date_str):
@app.cli.command('extend-packet')
@click.argument('packet_id')
-def extend_packet(packet_id):
+def extend_packet(packet_id: int) -> None:
"""
Extends the given packet by setting a new end date.
"""
@@ -168,7 +168,7 @@ def extend_packet(packet_id):
print('Packet successfully extended')
-def remove_sig(packet_id, username, is_member):
+def remove_sig(packet_id: int, username: str, is_member: bool) -> None:
packet = Packet.by_id(packet_id)
if not packet.is_open():
@@ -200,7 +200,7 @@ def remove_sig(packet_id, username, is_member):
@app.cli.command('remove-member-sig')
@click.argument('packet_id')
@click.argument('member')
-def remove_member_sig(packet_id, member):
+def remove_member_sig(packet_id: int, member: str) -> None:
"""
Removes the given member's signature from the given packet.
:param member: The member's CSH username
@@ -211,7 +211,7 @@ def remove_member_sig(packet_id, member):
@app.cli.command('remove-freshman-sig')
@click.argument('packet_id')
@click.argument('freshman')
-def remove_freshman_sig(packet_id, freshman):
+def remove_freshman_sig(packet_id: int, freshman: str) -> None:
"""
Removes the given freshman's signature from the given packet.
:param freshman: The freshman's RIT username
diff --git a/packet/context_processors.py b/packet/context_processors.py
index bff75b1..93ab115 100644
--- a/packet/context_processors.py
+++ b/packet/context_processors.py
@@ -5,14 +5,15 @@ import hashlib
import urllib
from functools import lru_cache
from datetime import datetime
+from typing import Callable
-from packet.models import Freshman
+from packet.models import Freshman, UpperSignature
from packet import app, ldap
# pylint: disable=bare-except
@lru_cache(maxsize=128)
-def get_csh_name(username):
+def get_csh_name(username: str) -> str:
try:
member = ldap.get_member(username)
return member.cn + ' (' + member.uid + ')'
@@ -20,7 +21,7 @@ def get_csh_name(username):
return username
-def get_roles(sig):
+def get_roles(sig: UpperSignature) -> dict[str, str]:
"""
Converts a signature's role fields to a dict for ease of access.
:return: A dictionary of role short names to role long names
@@ -45,7 +46,7 @@ def get_roles(sig):
# pylint: disable=bare-except
@lru_cache(maxsize=256)
-def get_rit_name(username):
+def get_rit_name(username: str) -> str:
try:
freshman = Freshman.query.filter_by(rit_username=username).first()
return freshman.name + ' (' + username + ')'
@@ -55,7 +56,7 @@ def get_rit_name(username):
# pylint: disable=bare-except
@lru_cache(maxsize=256)
-def get_rit_image(username):
+def get_rit_image(username: str) -> str:
if username:
addresses = [username + '@rit.edu', username + '@g.rit.edu']
for addr in addresses:
@@ -69,7 +70,7 @@ def get_rit_image(username):
return 'https://www.gravatar.com/avatar/freshmen?d=mp&f=y'
-def log_time(label):
+def log_time(label: str) -> None:
"""
Used during debugging to log timestamps while rendering templates
"""
@@ -77,7 +78,7 @@ def log_time(label):
@app.context_processor
-def utility_processor():
+def utility_processor() -> dict[str, Callable]:
return dict(
get_csh_name=get_csh_name, get_rit_name=get_rit_name, get_rit_image=get_rit_image, log_time=log_time,
get_roles=get_roles
diff --git a/packet/git.py b/packet/git.py
index 00e4d65..506276d 100644
--- a/packet/git.py
+++ b/packet/git.py
@@ -2,7 +2,7 @@ import json
import os
import subprocess
-def get_short_sha(commit_ish: str = 'HEAD'):
+def get_short_sha(commit_ish: str = 'HEAD') -> str:
"""
Get the short hash of a commit-ish
Returns '' if unfound
@@ -14,7 +14,7 @@ def get_short_sha(commit_ish: str = 'HEAD'):
except subprocess.CalledProcessError:
return ''
-def get_tag(commit_ish: str = 'HEAD'):
+def get_tag(commit_ish: str = 'HEAD') -> str:
"""
Get the name of the tag at a given commit-ish
Returns '' if untagged
@@ -26,7 +26,7 @@ def get_tag(commit_ish: str = 'HEAD'):
except subprocess.CalledProcessError:
return ''
-def get_version(commit_ish: str = 'HEAD'):
+def get_version(commit_ish: str = 'HEAD') -> str:
"""
Get the version string of a commit-ish
diff --git a/packet/ldap.py b/packet/ldap.py
index 99b0367..f276484 100644
--- a/packet/ldap.py
+++ b/packet/ldap.py
@@ -4,8 +4,9 @@ Helper functions for working with the csh_ldap library
from functools import lru_cache
from datetime import date
+from typing import Optional, cast, Any
-from csh_ldap import CSHLDAP
+from csh_ldap import CSHLDAP, CSHMember
from packet import app
@@ -20,32 +21,32 @@ class MockMember:
self.cn = cn if cn else uid.title() # pylint: disable=invalid-name
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if type(other) is type(self):
return self.uid == other.uid
return False
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(self.uid)
- def __repr__(self):
+ def __repr__(self) -> str:
return f'MockMember(uid: {self.uid}, groups: {self.groups})'
class LDAPWrapper:
- def __init__(self, cshldap=None, mock_members=None):
+ def __init__(self, cshldap: CSHLDAP = None, mock_members: list[MockMember] = None):
self.ldap = cshldap
- self.mock_members = mock_members
+ self.mock_members = cast(list[MockMember], mock_members)
if self.ldap:
app.logger.info('LDAP configured with CSH LDAP')
else:
app.logger.info('LDAP configured with local mock')
- def _get_group_members(self, group):
+ def _get_group_members(self, group: str) -> list[CSHMember]:
"""
:return: A list of CSHMember instances
"""
@@ -55,7 +56,7 @@ class LDAPWrapper:
return list(filter(lambda member: group in member.groups, self.mock_members))
- def _is_member_of_group(self, member, group):
+ def _is_member_of_group(self, member: CSHMember, group: str) -> bool:
"""
:param member: A CSHMember instance
"""
@@ -67,7 +68,7 @@ class LDAPWrapper:
else:
return group in member.groups
- def get_groups(self, member):
+ def get_groups(self, member: CSHMember) -> list[str]:
if self.ldap:
return list(
map(
@@ -89,7 +90,7 @@ class LDAPWrapper:
# Getters
@lru_cache(maxsize=256)
- def get_member(self, username):
+ def get_member(self, username: str) -> CSHMember:
"""
:return: A CSHMember instance
"""
@@ -102,7 +103,7 @@ class LDAPWrapper:
raise KeyError('Invalid Search Name')
- def get_active_members(self):
+ def get_active_members(self) -> list[CSHMember]:
"""
Gets all current, dues-paying members
:return: A list of CSHMember instances
@@ -110,7 +111,7 @@ class LDAPWrapper:
return self._get_group_members('active')
- def get_intro_members(self):
+ def get_intro_members(self) -> list[CSHMember]:
"""
Gets all freshmen members
:return: A list of CSHMember instances
@@ -118,7 +119,7 @@ class LDAPWrapper:
return self._get_group_members('intromembers')
- def get_eboard(self):
+ def get_eboard(self) -> list[CSHMember]:
"""
Gets all voting members of eboard
:return: A list of CSHMember instances
@@ -132,7 +133,7 @@ class LDAPWrapper:
return members
- def get_live_onfloor(self):
+ def get_live_onfloor(self) -> list[CSHMember]:
"""
All upperclassmen who live on floor and are not eboard
:return: A list of CSHMember instances
@@ -146,7 +147,7 @@ class LDAPWrapper:
return members
- def get_active_rtps(self):
+ def get_active_rtps(self) -> list[CSHMember]:
"""
All active RTPs
:return: A list of CSHMember instances
@@ -154,7 +155,7 @@ class LDAPWrapper:
return [member.uid for member in self._get_group_members('active_rtp')]
- def get_3das(self):
+ def get_3das(self) -> list[CSHMember]:
"""
All 3das
:return: A list of CSHMember instances
@@ -162,7 +163,7 @@ class LDAPWrapper:
return [member.uid for member in self._get_group_members('3da')]
- def get_webmasters(self):
+ def get_webmasters(self) -> list[CSHMember]:
"""
All webmasters
:return: A list of CSHMember instances
@@ -170,14 +171,14 @@ class LDAPWrapper:
return [member.uid for member in self._get_group_members('webmaster')]
- def get_constitutional_maintainers(self):
+ def get_constitutional_maintainers(self) -> list[CSHMember]:
"""
All constitutional maintainers
:return: A list of CSHMember instances
"""
return [member.uid for member in self._get_group_members('constitutional_maintainers')]
- def get_wiki_maintainers(self):
+ def get_wiki_maintainers(self) -> list[CSHMember]:
"""
All wiki maintainers
:return: A list of CSHMember instances
@@ -185,7 +186,7 @@ class LDAPWrapper:
return [member.uid for member in self._get_group_members('wiki_maintainers')]
- def get_drink_admins(self):
+ def get_drink_admins(self) -> list[CSHMember]:
"""
All drink admins
:return: A list of CSHMember instances
@@ -193,7 +194,7 @@ class LDAPWrapper:
return [member.uid for member in self._get_group_members('drink')]
- def get_eboard_role(self, member):
+ def get_eboard_role(self, member: CSHMember) -> Optional[str]:
"""
:param member: A CSHMember instance
:return: A String or None
@@ -224,29 +225,29 @@ class LDAPWrapper:
# Status checkers
- def is_eboard(self, member):
+ def is_eboard(self, member: CSHMember) -> bool:
"""
:param member: A CSHMember instance
"""
return self._is_member_of_group(member, 'eboard')
- def is_evals(self, member):
+ def is_evals(self, member: CSHMember) -> bool:
return self._is_member_of_group(member, 'eboard-evaluations')
- def is_rtp(self, member):
+ def is_rtp(self, member: CSHMember) -> bool:
return self._is_member_of_group(member, 'rtp')
- def is_intromember(self, member):
+ def is_intromember(self, member: CSHMember) -> bool:
"""
:param member: A CSHMember instance
"""
return self._is_member_of_group(member, 'intromembers')
- def is_on_coop(self, member):
+ def is_on_coop(self, member: CSHMember) -> bool:
"""
:param member: A CSHMember instance
"""
@@ -256,7 +257,7 @@ class LDAPWrapper:
return self._is_member_of_group(member, 'spring_coop')
- def get_roomnumber(self, member): # pylint: disable=no-self-use
+ def get_roomnumber(self, member: CSHMember) -> Optional[int]: # pylint: disable=no-self-use
"""
:param member: A CSHMember instance
"""
diff --git a/packet/log_utils.py b/packet/log_utils.py
index 5481bef..2d69f16 100644
--- a/packet/log_utils.py
+++ b/packet/log_utils.py
@@ -4,18 +4,20 @@ General utilities for logging metadata
from functools import wraps
from datetime import datetime
+from typing import Any, Callable, TypeVar, cast
from packet import app, ldap
from packet.context_processors import get_rit_name
from packet.utils import is_freshman_on_floor
+WrappedFunc = TypeVar('WrappedFunc', bound=Callable)
-def log_time(func):
+def log_time(func: WrappedFunc) -> WrappedFunc:
"""
Decorator for logging the execution time of a function
"""
@wraps(func)
- def wrapped_function(*args, **kwargs):
+ def wrapped_function(*args: list, **kwargs: dict) -> Any:
start = datetime.now()
result = func(*args, **kwargs)
@@ -25,10 +27,10 @@ def log_time(func):
return result
- return wrapped_function
+ return cast(WrappedFunc, wrapped_function)
-def _format_cache(func):
+def _format_cache(func: Any) -> str:
"""
:return: The output of func.cache_info() as a compactly formatted string
"""
@@ -41,17 +43,17 @@ def _format_cache(func):
_caches = (get_rit_name, ldap.get_member, is_freshman_on_floor)
-def log_cache(func):
+def log_cache(func: WrappedFunc) -> WrappedFunc:
"""
Decorator for logging cache info
"""
@wraps(func)
- def wrapped_function(*args, **kwargs):
+ def wrapped_function(*args: list, **kwargs: dict) -> Any:
result = func(*args, **kwargs)
app.logger.info('Cache stats: ' + ', '.join(map(_format_cache, _caches)))
return result
- return wrapped_function
+ return cast(WrappedFunc, wrapped_function)
diff --git a/packet/mail.py b/packet/mail.py
index b5a4f12..5aa32f5 100644
--- a/packet/mail.py
+++ b/packet/mail.py
@@ -1,12 +1,19 @@
+from typing import TypedDict
+
from flask import render_template
from flask_mail import Mail, Message
from packet import app
+from packet.models import Packet
mail = Mail(app)
-def send_start_packet_mail(packet):
+class ReportForm(TypedDict):
+ person: str
+ report: str
+
+def send_start_packet_mail(packet: Packet) -> None:
if app.config['MAIL_PROD']:
recipients = ['<' + packet.freshman.rit_username + '@rit.edu>']
msg = Message(subject='CSH Packet Starts ' + packet.start.strftime('%A, %B %-d'),
@@ -19,8 +26,7 @@ def send_start_packet_mail(packet):
app.logger.info('Sending mail to ' + recipients[0])
mail.send(msg)
-
-def send_report_mail(form_results, reporter):
+def send_report_mail(form_results: ReportForm, reporter: str) -> None:
if app.config['MAIL_PROD']:
recipients = ['<evals@csh.rit.edu>']
msg = Message(subject='Packet Report',
diff --git a/packet/models.py b/packet/models.py
index b914d27..f22e467 100644
--- a/packet/models.py
+++ b/packet/models.py
@@ -4,6 +4,7 @@ Defines the application's database models
from datetime import datetime
from itertools import chain
+from typing import cast, Optional
from sqlalchemy import Column, Integer, String, ForeignKey, DateTime, Boolean
from sqlalchemy.orm import relationship
@@ -18,7 +19,7 @@ class SigCounts:
"""
Utility class for returning counts of signatures broken out by type
"""
- def __init__(self, upper, fresh, misc):
+ def __init__(self, upper: int, fresh: int, misc: int):
# Base fields
self.upper = upper
self.fresh = fresh
@@ -34,23 +35,23 @@ class SigCounts:
class Freshman(db.Model):
__tablename__ = 'freshman'
- rit_username = Column(String(10), primary_key=True)
- name = Column(String(64), nullable=False)
- onfloor = Column(Boolean, nullable=False)
- fresh_signatures = relationship('FreshSignature')
+ rit_username = cast(str, Column(String(10), primary_key=True))
+ name = cast(str, Column(String(64), nullable=False))
+ onfloor = cast(bool, Column(Boolean, nullable=False))
+ fresh_signatures = cast('FreshSignature', relationship('FreshSignature'))
# One freshman can have multiple packets if they repeat the intro process
- packets = relationship('Packet', order_by='desc(Packet.id)')
+ packets = cast('Packet', relationship('Packet', order_by='desc(Packet.id)'))
@classmethod
- def by_username(cls, username: str):
+ def by_username(cls, username: str) -> 'Packet':
"""
Helper method to retrieve a freshman by their RIT username
"""
return cls.query.filter_by(rit_username=username).first()
@classmethod
- def get_all(cls):
+ def get_all(cls) -> list['Packet']:
"""
Helper method to get all freshmen easily
"""
@@ -59,25 +60,26 @@ class Freshman(db.Model):
class Packet(db.Model):
__tablename__ = 'packet'
- id = Column(Integer, primary_key=True, autoincrement=True)
- freshman_username = Column(ForeignKey('freshman.rit_username'))
- start = Column(DateTime, nullable=False)
- end = Column(DateTime, nullable=False)
+ id = cast(int, Column(Integer, primary_key=True, autoincrement=True))
+ freshman_username = cast(str, Column(ForeignKey('freshman.rit_username')))
+ start = cast(datetime, Column(DateTime, nullable=False))
+ end = cast(datetime, Column(DateTime, nullable=False))
- freshman = relationship('Freshman', back_populates='packets')
+ freshman = cast(Freshman, relationship('Freshman', back_populates='packets'))
# The `lazy='subquery'` kwarg enables eager loading for signatures which makes signature calculations much faster
# See the docs here for details: https://docs.sqlalchemy.org/en/latest/orm/loading_relationships.html
- upper_signatures = relationship('UpperSignature', lazy='subquery',
- order_by='UpperSignature.signed.desc(), UpperSignature.updated')
- fresh_signatures = relationship('FreshSignature', lazy='subquery',
- order_by='FreshSignature.signed.desc(), FreshSignature.updated')
- misc_signatures = relationship('MiscSignature', lazy='subquery', order_by='MiscSignature.updated')
-
- def is_open(self):
+ upper_signatures = cast('UpperSignature', relationship('UpperSignature', lazy='subquery',
+ order_by='UpperSignature.signed.desc(), UpperSignature.updated'))
+ fresh_signatures = cast('FreshSignature', relationship('FreshSignature', lazy='subquery',
+ order_by='FreshSignature.signed.desc(), FreshSignature.updated'))
+ misc_signatures = cast('MiscSignature',
+ relationship('MiscSignature', lazy='subquery', order_by='MiscSignature.updated'))
+
+ def is_open(self) -> bool:
return self.start < datetime.now() < self.end
- def signatures_required(self):
+ def signatures_required(self) -> SigCounts:
"""
:return: A SigCounts instance with the fields set to the number of signatures received by this packet
"""
@@ -86,7 +88,7 @@ class Packet(db.Model):
return SigCounts(upper, fresh, REQUIRED_MISC_SIGNATURES)
- def signatures_received(self):
+ def signatures_received(self) -> SigCounts:
"""
:return: A SigCounts instance with the fields set to the number of required signatures for this packet
"""
@@ -95,7 +97,7 @@ class Packet(db.Model):
return SigCounts(upper, fresh, len(self.misc_signatures))
- def did_sign(self, username, is_csh):
+ def did_sign(self, username: str, is_csh: bool) -> bool:
"""
:param username: The CSH or RIT username to check for
:param is_csh: Set to True for CSH accounts and False for freshmen
@@ -114,21 +116,21 @@ class Packet(db.Model):
# The user must be a misc CSHer that hasn't signed this packet or an off-floor freshmen
return False
- def is_100(self):
+ def is_100(self) -> bool:
"""
Checks if this packet has reached 100%
"""
return self.signatures_required().total == self.signatures_received().total
@classmethod
- def open_packets(cls):
+ def open_packets(cls) -> list['Packet']:
"""
Helper method for fetching all currently open packets
"""
return cls.query.filter(cls.start < datetime.now(), cls.end > datetime.now()).all()
@classmethod
- def by_id(cls, packet_id):
+ def by_id(cls, packet_id: int) -> 'Packet':
"""
Helper method for fetching 1 packet by its id
"""
@@ -136,43 +138,43 @@ class Packet(db.Model):
class UpperSignature(db.Model):
__tablename__ = 'signature_upper'
- packet_id = Column(Integer, ForeignKey('packet.id'), primary_key=True)
- member = Column(String(36), primary_key=True)
- signed = Column(Boolean, default=False, nullable=False)
- eboard = Column(String(12), nullable=True)
- active_rtp = Column(Boolean, default=False, nullable=False)
- three_da = Column(Boolean, default=False, nullable=False)
- webmaster = Column(Boolean, default=False, nullable=False)
- c_m = Column(Boolean, default=False, nullable=False)
- w_m = Column(Boolean, default=False, nullable=False)
- drink_admin = Column(Boolean, default=False, nullable=False)
- updated = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)
+ packet_id = cast(int, Column(Integer, ForeignKey('packet.id'), primary_key=True))
+ member = cast(str, Column(String(36), primary_key=True))
+ signed = cast(bool, Column(Boolean, default=False, nullable=False))
+ eboard = cast(Optional[str], Column(String(12), nullable=True))
+ active_rtp = cast(bool, Column(Boolean, default=False, nullable=False))
+ three_da = cast(bool, Column(Boolean, default=False, nullable=False))
+ webmaster = cast(bool, Column(Boolean, default=False, nullable=False))
+ c_m = cast(bool, Column(Boolean, default=False, nullable=False))
+ w_m = cast(bool, Column(Boolean, default=False, nullable=False))
+ drink_admin = cast(bool, Column(Boolean, default=False, nullable=False))
+ updated = cast(datetime, Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False))
- packet = relationship('Packet', back_populates='upper_signatures')
+ packet = cast(Packet, relationship('Packet', back_populates='upper_signatures'))
class FreshSignature(db.Model):
__tablename__ = 'signature_fresh'
- packet_id = Column(Integer, ForeignKey('packet.id'), primary_key=True)
- freshman_username = Column(ForeignKey('freshman.rit_username'), primary_key=True)
- signed = Column(Boolean, default=False, nullable=False)
- updated = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)
+ packet_id = cast(int, Column(Integer, ForeignKey('packet.id'), primary_key=True))
+ freshman_username = cast(str, Column(ForeignKey('freshman.rit_username'), primary_key=True))
+ signed = cast(bool, Column(Boolean, default=False, nullable=False))
+ updated = cast(datetime, Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False))
- packet = relationship('Packet', back_populates='fresh_signatures')
- freshman = relationship('Freshman', back_populates='fresh_signatures')
+ packet = cast(Packet, relationship('Packet', back_populates='fresh_signatures'))
+ freshman = cast(Freshman, relationship('Freshman', back_populates='fresh_signatures'))
class MiscSignature(db.Model):
__tablename__ = 'signature_misc'
- packet_id = Column(Integer, ForeignKey('packet.id'), primary_key=True)
- member = Column(String(36), primary_key=True)
- updated = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)
+ packet_id = cast(int, Column(Integer, ForeignKey('packet.id'), primary_key=True))
+ member = cast(str, Column(String(36), primary_key=True))
+ updated = cast(datetime, Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False))
- packet = relationship('Packet', back_populates='misc_signatures')
+ packet = cast(Packet, relationship('Packet', back_populates='misc_signatures'))
class NotificationSubscription(db.Model):
__tablename__ = 'notification_subscriptions'
- member = Column(String(36), nullable=True)
- freshman_username = Column(ForeignKey('freshman.rit_username'), nullable=True)
- token = Column(String(256), primary_key=True, nullable=False)
+ member = cast(str, Column(String(36), nullable=True))
+ freshman_username = cast(str, Column(ForeignKey('freshman.rit_username'), nullable=True))
+ token = cast(str, Column(String(256), primary_key=True, nullable=False))
diff --git a/packet/notifications.py b/packet/notifications.py
index 69e1ab6..49b89fb 100644
--- a/packet/notifications.py
+++ b/packet/notifications.py
@@ -1,7 +1,10 @@
+from datetime import datetime
+from typing import Any, Callable, TypeVar, cast
+
import onesignal_sdk.client as onesignal
from packet import app, intro_onesignal_client, csh_onesignal_client
-from packet.models import NotificationSubscription
+from packet.models import NotificationSubscription, Packet
post_body = {
'contents': {'en': 'Default message'},
@@ -11,22 +14,24 @@ post_body = {
'url': app.config['PROTOCOL'] + app.config['SERVER_NAME']
}
-def require_onesignal_intro(func):
- def require_onesignal_intro_wrapper(*args, **kwargs):
+WrappedFunc = TypeVar('WrappedFunc', bound=Callable)
+
+def require_onesignal_intro(func: WrappedFunc) -> WrappedFunc:
+ def require_onesignal_intro_wrapper(*args: list, **kwargs: dict) -> Any:
if intro_onesignal_client:
return func(*args, **kwargs)
return None
- return require_onesignal_intro_wrapper
+ return cast(WrappedFunc, require_onesignal_intro_wrapper)
-def require_onesignal_csh(func):
- def require_onesignal_csh_wrapper(*args, **kwargs):
+def require_onesignal_csh(func: WrappedFunc) -> WrappedFunc:
+ def require_onesignal_csh_wrapper(*args: list, **kwargs: dict) -> Any:
if csh_onesignal_client:
return func(*args, **kwargs)
return None
- return require_onesignal_csh_wrapper
+ return cast(WrappedFunc, require_onesignal_csh_wrapper)
-def send_notification(notification_body, subscriptions, client):
+def send_notification(notification_body: dict, subscriptions: list, client: onesignal.Client) -> None:
tokens = list(map(lambda subscription: subscription.token, subscriptions))
if tokens:
notification = onesignal.Notification(post_body=notification_body)
@@ -39,7 +44,7 @@ def send_notification(notification_body, subscriptions, client):
@require_onesignal_intro
-def packet_signed_notification(packet, signer):
+def packet_signed_notification(packet: Packet, signer: str) -> None:
subscriptions = NotificationSubscription.query.filter_by(freshman_username=packet.freshman_username)
if subscriptions:
notification_body = post_body
@@ -53,9 +58,10 @@ def packet_signed_notification(packet, signer):
@require_onesignal_csh
@require_onesignal_intro
-def packet_100_percent_notification(packet):
- member_subscriptions = NotificationSubscription.query.filter(NotificationSubscription.member.isnot(None))
- intro_subscriptions = NotificationSubscription.query.filter(NotificationSubscription.freshman_username.isnot(None))
+def packet_100_percent_notification(packet: Packet) -> None:
+ member_subscriptions = NotificationSubscription.query.filter(cast(Any, NotificationSubscription.member).isnot(None))
+ intro_subscriptions = NotificationSubscription.query.filter(
+ cast(Any, NotificationSubscription.freshman_username).isnot(None))
if member_subscriptions or intro_subscriptions:
notification_body = post_body
notification_body['contents']['en'] = packet.freshman.name + ' got 💯 on packet!'
@@ -68,7 +74,7 @@ def packet_100_percent_notification(packet):
@require_onesignal_intro
-def packet_starting_notification(packet):
+def packet_starting_notification(packet: Packet) -> None:
subscriptions = NotificationSubscription.query.filter_by(freshman_username=packet.freshman_username)
if subscriptions:
notification_body = post_body
@@ -81,8 +87,8 @@ def packet_starting_notification(packet):
@require_onesignal_csh
-def packets_starting_notification(start_date):
- member_subscriptions = NotificationSubscription.query.filter(NotificationSubscription.member.isnot(None))
+def packets_starting_notification(start_date: datetime) -> None:
+ member_subscriptions = NotificationSubscription.query.filter(cast(Any, NotificationSubscription.member).isnot(None))
if member_subscriptions:
notification_body = post_body
notification_body['contents']['en'] = 'New packets have started, visit packet to see them!'
diff --git a/packet/stats.py b/packet/stats.py
index 2ac5fb2..c6d6103 100644
--- a/packet/stats.py
+++ b/packet/stats.py
@@ -1,9 +1,35 @@
-from datetime import timedelta
+from datetime import date as dateType, timedelta
+from typing import TypedDict, Union, cast, Callable
from packet.models import Packet, MiscSignature, UpperSignature
+# Types
+class Freshman(TypedDict):
+ name: str
+ rit_username: str
-def packet_stats(packet_id):
+class WhoSigned(TypedDict):
+ upper: list[str]
+ misc: list[str]
+ fresh: list[str]
+
+class PacketStats(TypedDict):
+ packet_id: int
+ freshman: Freshman
+ dates: dict[str, dict[str, list[str]]]
+
+class SimplePacket(TypedDict):
+ id: int
+ freshman_username: str
+
+class SigDict(TypedDict):
+ date: dateType
+ packet: SimplePacket
+
+Stats = dict[dateType, list[str]]
+
+
+def packet_stats(packet_id: int) -> PacketStats:
"""
Gather statistics for a packet in the form of number of signatures per day
@@ -28,17 +54,17 @@ def packet_stats(packet_id):
print(dates)
- upper_stats = {date: list() for date in dates}
+ upper_stats: Stats = {date: list() for date in dates}
for uid, date in map(lambda sig: (sig.member, sig.updated),
filter(lambda sig: sig.signed, packet.upper_signatures)):
upper_stats[date.date()].append(uid)
- fresh_stats = {date: list() for date in dates}
+ fresh_stats: Stats = {date: list() for date in dates}
for username, date in map(lambda sig: (sig.freshman_username, sig.updated),
filter(lambda sig: sig.signed, packet.fresh_signatures)):
fresh_stats[date.date()].append(username)
- misc_stats = {date: list() for date in dates}
+ misc_stats: Stats = {date: list() for date in dates}
for uid, date in map(lambda sig: (sig.member, sig.updated), packet.misc_signatures):
misc_stats[date.date()].append(uid)
@@ -60,7 +86,7 @@ def packet_stats(packet_id):
}
-def sig2dict(sig):
+def sig2dict(sig: Union[UpperSignature, MiscSignature]) -> SigDict:
"""
A utility function for upperclassman stats.
Converts an UpperSignature to a dictionary with the date and the packet.
@@ -74,8 +100,11 @@ def sig2dict(sig):
},
}
+class UpperStats(TypedDict):
+ member: str
+ signatures: dict[str, list[SimplePacket]]
-def upperclassman_stats(uid):
+def upperclassman_stats(uid: str) -> UpperStats:
"""
Gather statistics for an upperclassman's signature habits
@@ -104,7 +133,7 @@ def upperclassman_stats(uid):
'signatures': {
date.isoformat() : list(
map(lambda sd: sd['packet'],
- filter(lambda sig, d=date: sig['date'] == d,
+ filter(cast(Callable, lambda sig, d=date: sig['date'] == d),
sig_dicts
)
)
diff --git a/packet/utils.py b/packet/utils.py
index ff2f5a6..ea4693b 100644
--- a/packet/utils.py
+++ b/packet/utils.py
@@ -3,6 +3,7 @@ General utilities and decorators for supporting the Python logic
"""
from datetime import datetime, time, timedelta, date
from functools import wraps, lru_cache
+from typing import Any, Callable, TypeVar, cast
import requests
from flask import session, redirect
@@ -14,15 +15,16 @@ from packet.notifications import packets_starting_notification, packet_starting_
INTRO_REALM = 'https://sso.csh.rit.edu/auth/realms/intro'
+WrappedFunc = TypeVar('WrappedFunc', bound=Callable)
-def before_request(func):
+def before_request(func: WrappedFunc) -> WrappedFunc:
"""
Credit to Liam Middlebrook and Ram Zallan
https://github.com/liam-middlebrook/gallery
"""
@wraps(func)
- def wrapped_function(*args, **kwargs):
+ def wrapped_function(*args: list, **kwargs: dict) -> Any:
uid = str(session['userinfo'].get('preferred_username', ''))
if session['id_token']['iss'] == INTRO_REALM:
info = {
@@ -43,11 +45,11 @@ def before_request(func):
kwargs['info'] = info
return func(*args, **kwargs)
- return wrapped_function
+ return cast(WrappedFunc, wrapped_function)
@lru_cache(maxsize=128)
-def is_freshman_on_floor(rit_username):
+def is_freshman_on_floor(rit_username: str) -> bool:
"""
Checks if a freshman is on floor
"""
@@ -58,14 +60,14 @@ def is_freshman_on_floor(rit_username):
return False
-def packet_auth(func):
+def packet_auth(func: WrappedFunc) -> WrappedFunc:
"""
Decorator for easily configuring oidc
"""
@auth.oidc_auth('app')
@wraps(func)
- def wrapped_function(*args, **kwargs):
+ def wrapped_function(*args: list, **kwargs: dict) -> Any:
if app.config['REALM'] == 'csh':
username = str(session['userinfo'].get('preferred_username', ''))
if ldap.is_intromember(ldap.get_member(username)):
@@ -74,17 +76,17 @@ def packet_auth(func):
return func(*args, **kwargs)
- return wrapped_function
+ return cast(WrappedFunc, wrapped_function)
-def admin_auth(func):
+def admin_auth(func: WrappedFunc) -> WrappedFunc:
"""
Decorator for easily configuring oidc
"""
@auth.oidc_auth('app')
@wraps(func)
- def wrapped_function(*args, **kwargs):
+ def wrapped_function(*args: list, **kwargs: dict) -> Any:
if app.config['REALM'] == 'csh':
username = str(session['userinfo'].get('preferred_username', ''))
member = ldap.get_member(username)
@@ -96,10 +98,10 @@ def admin_auth(func):
return func(*args, **kwargs)
- return wrapped_function
+ return cast(WrappedFunc, wrapped_function)
-def notify_slack(name: str):
+def notify_slack(name: str) -> None:
"""
Sends a congratulate on sight decree to Slack
"""
@@ -112,7 +114,7 @@ def notify_slack(name: str):
app.logger.info('Posted 100% notification to slack for ' + name)
-def sync_freshman(freshmen_list: dict):
+def sync_freshman(freshmen_list: dict) -> None:
freshmen_in_db = {freshman.rit_username: freshman for freshman in Freshman.query.all()}
for list_freshman in freshmen_list.values():
@@ -150,7 +152,7 @@ def sync_freshman(freshmen_list: dict):
db.session.commit()
-def create_new_packets(base_date: date, freshmen_list: dict):
+def create_new_packets(base_date: date, freshmen_list: dict) -> None:
packet_start_time = time(hour=19)
packet_end_time = time(hour=21)
start = datetime.combine(base_date, packet_start_time)
@@ -173,7 +175,7 @@ def create_new_packets(base_date: date, freshmen_list: dict):
# Create the new packets and the signatures for each freshman in the given CSV
print('Creating DB entries and sending emails...')
- for freshman in Freshman.query.filter(Freshman.rit_username.in_(freshmen_list)).all():
+ for freshman in Freshman.query.filter(cast(Any, Freshman.rit_username).in_(freshmen_list)).all():
packet = Packet(freshman=freshman, start=start, end=end)
db.session.add(packet)
send_start_packet_mail(packet)
@@ -197,7 +199,7 @@ def create_new_packets(base_date: date, freshmen_list: dict):
db.session.commit()
-def sync_with_ldap():
+def sync_with_ldap() -> None:
print('Fetching data from LDAP...')
all_upper = {member.uid: member for member in filter(
lambda member: not ldap.is_intromember(member) and not ldap.is_on_coop(member), ldap.get_active_members())}
diff --git a/requirements.txt b/requirements.txt
index cda0b3b..dcb44ae 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,16 @@
-csh_ldap~=2.3.1
-ddtrace
-Flask~=1.1.2
Flask-Gzip~=0.2
Flask-Mail~=0.9.1
Flask-Migrate~=2.7.0
Flask-pyoidc~=3.7.0
+Flask~=1.1.2
+csh_ldap~=2.3.1
+ddtrace
flask_sqlalchemy~=2.4.4
gunicorn~=20.0.4
+mypy
onesignal-sdk~=2.0.0
psycopg2-binary~=2.8.6
-pylint~=2.7.2
pylint-quotes~=0.2.1
+pylint~=2.7.2
sentry-sdk~=1.0.0
+sqlalchemy[mypy]
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..37543e3
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,2 @@
+[mypy]
+plugins=sqlalchemy.ext.mypy.plugin