diff options
author | Max Meinhold <mxmeinhold@gmail.com> | 2021-03-28 00:17:55 -0400 |
---|---|---|
committer | Max Meinhold <mxmeinhold@gmail.com> | 2021-04-01 22:25:08 -0400 |
commit | 73e55ac8b0f2e58de681afac55ea5d38507c609e (patch) | |
tree | 5deed7ec23f678a984cac70e3cdd1d0ca804207a | |
parent | c6af9137de15bbe20362033b9e7e7ffecca055f1 (diff) |
Add type hints and mypy
-rw-r--r-- | .github/workflows/python-app.yml | 24 | ||||
-rw-r--r-- | README.md | 9 | ||||
-rw-r--r-- | packet/__init__.py | 2 | ||||
-rw-r--r-- | packet/commands.py | 26 | ||||
-rw-r--r-- | packet/context_processors.py | 15 | ||||
-rw-r--r-- | packet/git.py | 6 | ||||
-rw-r--r-- | packet/ldap.py | 55 | ||||
-rw-r--r-- | packet/log_utils.py | 16 | ||||
-rw-r--r-- | packet/mail.py | 12 | ||||
-rw-r--r-- | packet/models.py | 104 | ||||
-rw-r--r-- | packet/notifications.py | 36 | ||||
-rw-r--r-- | packet/stats.py | 45 | ||||
-rw-r--r-- | packet/utils.py | 32 | ||||
-rw-r--r-- | requirements.txt | 10 | ||||
-rw-r--r-- | setup.cfg | 2 |
15 files changed, 236 insertions, 158 deletions
diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 763c63e..3934f40 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -32,3 +32,27 @@ jobs: - name: Lint with pylint run: | pylint packet + + typecheck: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: [3.9] + + steps: + - name: Install ldap dependencies + run: sudo apt-get update && sudo apt-get install libldap2-dev libsasl2-dev + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Typecheck with mypy + run: | + # Disabled error codes to discard errors from imports + mypy --disable-error-code import --disable-error-code name-defined --disallow-untyped-defs --exclude routes packet @@ -115,13 +115,14 @@ All DB commands are from the `Flask-Migrate` library and are used to configure D docs [here](https://flask-migrate.readthedocs.io/en/latest/) for details. ## Code standards -This project is configured to use Pylint. Commits will be pylinted by GitHub actions and if the score drops your build will -fail blocking you from merging. To make your life easier just run it before making a PR. +This project is configured to use Pylint and mypy. Commits will be pylinted and typechecked by GitHub actions and if the +score drops your build will fail blocking you from merging. To make your life easier just run it before making a PR. -To run pylint use this command: +To run pylint and mypy use these commands: ```bash pylint packet/routes packet +mypy --disable-error-code import --disable-error-code name-defined --disallow-untyped-defs --exclude routes packet ``` All python files should have a top-level docstring explaining the contents of the file and complex functions should -have docstrings explaining any non-obvious portions. +have docstrings explaining any non-obvious portions. Functions should have type annotations. diff --git a/packet/__init__.py b/packet/__init__.py index ac6abd0..898429c 100644 --- a/packet/__init__.py +++ b/packet/__init__.py @@ -21,7 +21,7 @@ from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration from .git import get_version -app = Flask(__name__) +app: Flask = Flask(__name__) gzip = Gzip(app) # Load default configuration and any environment variable overrides diff --git a/packet/commands.py b/packet/commands.py index 32dac8e..ea3591a 100644 --- a/packet/commands.py +++ b/packet/commands.py @@ -5,7 +5,7 @@ Defines command-line utilities for use with packet import sys from secrets import token_hex -from datetime import datetime, time +from datetime import datetime, time, date import csv import click @@ -15,7 +15,7 @@ from .utils import sync_freshman, create_new_packets, sync_with_ldap @app.cli.command('create-secret') -def create_secret(): +def create_secret() -> None: """ Generates a securely random token. Useful for creating a value for use in the "SECRET_KEY" config setting. """ @@ -28,13 +28,13 @@ packet_end_time = time(hour=21) class CSVFreshman: - def __init__(self, row): + def __init__(self, row: list[str]) -> None: self.name = row[0].strip() self.rit_username = row[3].strip() self.onfloor = row[1].strip() == 'TRUE' -def parse_csv(freshmen_csv): +def parse_csv(freshmen_csv: str) -> dict[str, CSVFreshman]: print('Parsing file...') try: with open(freshmen_csv, newline='') as freshmen_csv_file: @@ -44,7 +44,7 @@ def parse_csv(freshmen_csv): raise e -def input_date(prompt): +def input_date(prompt: str) -> date: while True: try: date_str = input(prompt + ' (format: MM/DD/YYYY): ') @@ -55,7 +55,7 @@ def input_date(prompt): @app.cli.command('sync-freshmen') @click.argument('freshmen_csv') -def sync_freshmen(freshmen_csv): +def sync_freshmen(freshmen_csv: str) -> None: """ Updates the freshmen entries in the DB to match the given CSV. """ @@ -68,7 +68,7 @@ def sync_freshmen(freshmen_csv): @app.cli.command('create-packets') @click.argument('freshmen_csv') -def create_packets(freshmen_csv): +def create_packets(freshmen_csv: str) -> None: """ Creates a new packet season for each of the freshmen in the given CSV. """ @@ -84,7 +84,7 @@ def create_packets(freshmen_csv): @app.cli.command('ldap-sync') -def ldap_sync(): +def ldap_sync() -> None: """ Updates the upper and misc sigs in the DB to match ldap. """ @@ -97,7 +97,7 @@ def ldap_sync(): help='The file to write to. If no file provided, output is sent to stdout.') @click.option('--csv/--no-csv', 'use_csv', required=False, default=False, help='Format output as comma separated list.') @click.option('--date', 'date_str', required=False, default='', help='Packet end date in the format MM/DD/YYYY.') -def fetch_results(file_path, use_csv, date_str): +def fetch_results(file_path: str, use_csv: bool, date_str: str) -> None: """ Fetches and prints the results from a given packet season. """ @@ -150,7 +150,7 @@ def fetch_results(file_path, use_csv, date_str): @app.cli.command('extend-packet') @click.argument('packet_id') -def extend_packet(packet_id): +def extend_packet(packet_id: int) -> None: """ Extends the given packet by setting a new end date. """ @@ -168,7 +168,7 @@ def extend_packet(packet_id): print('Packet successfully extended') -def remove_sig(packet_id, username, is_member): +def remove_sig(packet_id: int, username: str, is_member: bool) -> None: packet = Packet.by_id(packet_id) if not packet.is_open(): @@ -200,7 +200,7 @@ def remove_sig(packet_id, username, is_member): @app.cli.command('remove-member-sig') @click.argument('packet_id') @click.argument('member') -def remove_member_sig(packet_id, member): +def remove_member_sig(packet_id: int, member: str) -> None: """ Removes the given member's signature from the given packet. :param member: The member's CSH username @@ -211,7 +211,7 @@ def remove_member_sig(packet_id, member): @app.cli.command('remove-freshman-sig') @click.argument('packet_id') @click.argument('freshman') -def remove_freshman_sig(packet_id, freshman): +def remove_freshman_sig(packet_id: int, freshman: str) -> None: """ Removes the given freshman's signature from the given packet. :param freshman: The freshman's RIT username diff --git a/packet/context_processors.py b/packet/context_processors.py index bff75b1..93ab115 100644 --- a/packet/context_processors.py +++ b/packet/context_processors.py @@ -5,14 +5,15 @@ import hashlib import urllib from functools import lru_cache from datetime import datetime +from typing import Callable -from packet.models import Freshman +from packet.models import Freshman, UpperSignature from packet import app, ldap # pylint: disable=bare-except @lru_cache(maxsize=128) -def get_csh_name(username): +def get_csh_name(username: str) -> str: try: member = ldap.get_member(username) return member.cn + ' (' + member.uid + ')' @@ -20,7 +21,7 @@ def get_csh_name(username): return username -def get_roles(sig): +def get_roles(sig: UpperSignature) -> dict[str, str]: """ Converts a signature's role fields to a dict for ease of access. :return: A dictionary of role short names to role long names @@ -45,7 +46,7 @@ def get_roles(sig): # pylint: disable=bare-except @lru_cache(maxsize=256) -def get_rit_name(username): +def get_rit_name(username: str) -> str: try: freshman = Freshman.query.filter_by(rit_username=username).first() return freshman.name + ' (' + username + ')' @@ -55,7 +56,7 @@ def get_rit_name(username): # pylint: disable=bare-except @lru_cache(maxsize=256) -def get_rit_image(username): +def get_rit_image(username: str) -> str: if username: addresses = [username + '@rit.edu', username + '@g.rit.edu'] for addr in addresses: @@ -69,7 +70,7 @@ def get_rit_image(username): return 'https://www.gravatar.com/avatar/freshmen?d=mp&f=y' -def log_time(label): +def log_time(label: str) -> None: """ Used during debugging to log timestamps while rendering templates """ @@ -77,7 +78,7 @@ def log_time(label): @app.context_processor -def utility_processor(): +def utility_processor() -> dict[str, Callable]: return dict( get_csh_name=get_csh_name, get_rit_name=get_rit_name, get_rit_image=get_rit_image, log_time=log_time, get_roles=get_roles diff --git a/packet/git.py b/packet/git.py index 00e4d65..506276d 100644 --- a/packet/git.py +++ b/packet/git.py @@ -2,7 +2,7 @@ import json import os import subprocess -def get_short_sha(commit_ish: str = 'HEAD'): +def get_short_sha(commit_ish: str = 'HEAD') -> str: """ Get the short hash of a commit-ish Returns '' if unfound @@ -14,7 +14,7 @@ def get_short_sha(commit_ish: str = 'HEAD'): except subprocess.CalledProcessError: return '' -def get_tag(commit_ish: str = 'HEAD'): +def get_tag(commit_ish: str = 'HEAD') -> str: """ Get the name of the tag at a given commit-ish Returns '' if untagged @@ -26,7 +26,7 @@ def get_tag(commit_ish: str = 'HEAD'): except subprocess.CalledProcessError: return '' -def get_version(commit_ish: str = 'HEAD'): +def get_version(commit_ish: str = 'HEAD') -> str: """ Get the version string of a commit-ish diff --git a/packet/ldap.py b/packet/ldap.py index 99b0367..f276484 100644 --- a/packet/ldap.py +++ b/packet/ldap.py @@ -4,8 +4,9 @@ Helper functions for working with the csh_ldap library from functools import lru_cache from datetime import date +from typing import Optional, cast, Any -from csh_ldap import CSHLDAP +from csh_ldap import CSHLDAP, CSHMember from packet import app @@ -20,32 +21,32 @@ class MockMember: self.cn = cn if cn else uid.title() # pylint: disable=invalid-name - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if type(other) is type(self): return self.uid == other.uid return False - def __hash__(self): + def __hash__(self) -> int: return hash(self.uid) - def __repr__(self): + def __repr__(self) -> str: return f'MockMember(uid: {self.uid}, groups: {self.groups})' class LDAPWrapper: - def __init__(self, cshldap=None, mock_members=None): + def __init__(self, cshldap: CSHLDAP = None, mock_members: list[MockMember] = None): self.ldap = cshldap - self.mock_members = mock_members + self.mock_members = cast(list[MockMember], mock_members) if self.ldap: app.logger.info('LDAP configured with CSH LDAP') else: app.logger.info('LDAP configured with local mock') - def _get_group_members(self, group): + def _get_group_members(self, group: str) -> list[CSHMember]: """ :return: A list of CSHMember instances """ @@ -55,7 +56,7 @@ class LDAPWrapper: return list(filter(lambda member: group in member.groups, self.mock_members)) - def _is_member_of_group(self, member, group): + def _is_member_of_group(self, member: CSHMember, group: str) -> bool: """ :param member: A CSHMember instance """ @@ -67,7 +68,7 @@ class LDAPWrapper: else: return group in member.groups - def get_groups(self, member): + def get_groups(self, member: CSHMember) -> list[str]: if self.ldap: return list( map( @@ -89,7 +90,7 @@ class LDAPWrapper: # Getters @lru_cache(maxsize=256) - def get_member(self, username): + def get_member(self, username: str) -> CSHMember: """ :return: A CSHMember instance """ @@ -102,7 +103,7 @@ class LDAPWrapper: raise KeyError('Invalid Search Name') - def get_active_members(self): + def get_active_members(self) -> list[CSHMember]: """ Gets all current, dues-paying members :return: A list of CSHMember instances @@ -110,7 +111,7 @@ class LDAPWrapper: return self._get_group_members('active') - def get_intro_members(self): + def get_intro_members(self) -> list[CSHMember]: """ Gets all freshmen members :return: A list of CSHMember instances @@ -118,7 +119,7 @@ class LDAPWrapper: return self._get_group_members('intromembers') - def get_eboard(self): + def get_eboard(self) -> list[CSHMember]: """ Gets all voting members of eboard :return: A list of CSHMember instances @@ -132,7 +133,7 @@ class LDAPWrapper: return members - def get_live_onfloor(self): + def get_live_onfloor(self) -> list[CSHMember]: """ All upperclassmen who live on floor and are not eboard :return: A list of CSHMember instances @@ -146,7 +147,7 @@ class LDAPWrapper: return members - def get_active_rtps(self): + def get_active_rtps(self) -> list[CSHMember]: """ All active RTPs :return: A list of CSHMember instances @@ -154,7 +155,7 @@ class LDAPWrapper: return [member.uid for member in self._get_group_members('active_rtp')] - def get_3das(self): + def get_3das(self) -> list[CSHMember]: """ All 3das :return: A list of CSHMember instances @@ -162,7 +163,7 @@ class LDAPWrapper: return [member.uid for member in self._get_group_members('3da')] - def get_webmasters(self): + def get_webmasters(self) -> list[CSHMember]: """ All webmasters :return: A list of CSHMember instances @@ -170,14 +171,14 @@ class LDAPWrapper: return [member.uid for member in self._get_group_members('webmaster')] - def get_constitutional_maintainers(self): + def get_constitutional_maintainers(self) -> list[CSHMember]: """ All constitutional maintainers :return: A list of CSHMember instances """ return [member.uid for member in self._get_group_members('constitutional_maintainers')] - def get_wiki_maintainers(self): + def get_wiki_maintainers(self) -> list[CSHMember]: """ All wiki maintainers :return: A list of CSHMember instances @@ -185,7 +186,7 @@ class LDAPWrapper: return [member.uid for member in self._get_group_members('wiki_maintainers')] - def get_drink_admins(self): + def get_drink_admins(self) -> list[CSHMember]: """ All drink admins :return: A list of CSHMember instances @@ -193,7 +194,7 @@ class LDAPWrapper: return [member.uid for member in self._get_group_members('drink')] - def get_eboard_role(self, member): + def get_eboard_role(self, member: CSHMember) -> Optional[str]: """ :param member: A CSHMember instance :return: A String or None @@ -224,29 +225,29 @@ class LDAPWrapper: # Status checkers - def is_eboard(self, member): + def is_eboard(self, member: CSHMember) -> bool: """ :param member: A CSHMember instance """ return self._is_member_of_group(member, 'eboard') - def is_evals(self, member): + def is_evals(self, member: CSHMember) -> bool: return self._is_member_of_group(member, 'eboard-evaluations') - def is_rtp(self, member): + def is_rtp(self, member: CSHMember) -> bool: return self._is_member_of_group(member, 'rtp') - def is_intromember(self, member): + def is_intromember(self, member: CSHMember) -> bool: """ :param member: A CSHMember instance """ return self._is_member_of_group(member, 'intromembers') - def is_on_coop(self, member): + def is_on_coop(self, member: CSHMember) -> bool: """ :param member: A CSHMember instance """ @@ -256,7 +257,7 @@ class LDAPWrapper: return self._is_member_of_group(member, 'spring_coop') - def get_roomnumber(self, member): # pylint: disable=no-self-use + def get_roomnumber(self, member: CSHMember) -> Optional[int]: # pylint: disable=no-self-use """ :param member: A CSHMember instance """ diff --git a/packet/log_utils.py b/packet/log_utils.py index 5481bef..2d69f16 100644 --- a/packet/log_utils.py +++ b/packet/log_utils.py @@ -4,18 +4,20 @@ General utilities for logging metadata from functools import wraps from datetime import datetime +from typing import Any, Callable, TypeVar, cast from packet import app, ldap from packet.context_processors import get_rit_name from packet.utils import is_freshman_on_floor +WrappedFunc = TypeVar('WrappedFunc', bound=Callable) -def log_time(func): +def log_time(func: WrappedFunc) -> WrappedFunc: """ Decorator for logging the execution time of a function """ @wraps(func) - def wrapped_function(*args, **kwargs): + def wrapped_function(*args: list, **kwargs: dict) -> Any: start = datetime.now() result = func(*args, **kwargs) @@ -25,10 +27,10 @@ def log_time(func): return result - return wrapped_function + return cast(WrappedFunc, wrapped_function) -def _format_cache(func): +def _format_cache(func: Any) -> str: """ :return: The output of func.cache_info() as a compactly formatted string """ @@ -41,17 +43,17 @@ def _format_cache(func): _caches = (get_rit_name, ldap.get_member, is_freshman_on_floor) -def log_cache(func): +def log_cache(func: WrappedFunc) -> WrappedFunc: """ Decorator for logging cache info """ @wraps(func) - def wrapped_function(*args, **kwargs): + def wrapped_function(*args: list, **kwargs: dict) -> Any: result = func(*args, **kwargs) app.logger.info('Cache stats: ' + ', '.join(map(_format_cache, _caches))) return result - return wrapped_function + return cast(WrappedFunc, wrapped_function) diff --git a/packet/mail.py b/packet/mail.py index b5a4f12..5aa32f5 100644 --- a/packet/mail.py +++ b/packet/mail.py @@ -1,12 +1,19 @@ +from typing import TypedDict + from flask import render_template from flask_mail import Mail, Message from packet import app +from packet.models import Packet mail = Mail(app) -def send_start_packet_mail(packet): +class ReportForm(TypedDict): + person: str + report: str + +def send_start_packet_mail(packet: Packet) -> None: if app.config['MAIL_PROD']: recipients = ['<' + packet.freshman.rit_username + '@rit.edu>'] msg = Message(subject='CSH Packet Starts ' + packet.start.strftime('%A, %B %-d'), @@ -19,8 +26,7 @@ def send_start_packet_mail(packet): app.logger.info('Sending mail to ' + recipients[0]) mail.send(msg) - -def send_report_mail(form_results, reporter): +def send_report_mail(form_results: ReportForm, reporter: str) -> None: if app.config['MAIL_PROD']: recipients = ['<evals@csh.rit.edu>'] msg = Message(subject='Packet Report', diff --git a/packet/models.py b/packet/models.py index b914d27..f22e467 100644 --- a/packet/models.py +++ b/packet/models.py @@ -4,6 +4,7 @@ Defines the application's database models from datetime import datetime from itertools import chain +from typing import cast, Optional from sqlalchemy import Column, Integer, String, ForeignKey, DateTime, Boolean from sqlalchemy.orm import relationship @@ -18,7 +19,7 @@ class SigCounts: """ Utility class for returning counts of signatures broken out by type """ - def __init__(self, upper, fresh, misc): + def __init__(self, upper: int, fresh: int, misc: int): # Base fields self.upper = upper self.fresh = fresh @@ -34,23 +35,23 @@ class SigCounts: class Freshman(db.Model): __tablename__ = 'freshman' - rit_username = Column(String(10), primary_key=True) - name = Column(String(64), nullable=False) - onfloor = Column(Boolean, nullable=False) - fresh_signatures = relationship('FreshSignature') + rit_username = cast(str, Column(String(10), primary_key=True)) + name = cast(str, Column(String(64), nullable=False)) + onfloor = cast(bool, Column(Boolean, nullable=False)) + fresh_signatures = cast('FreshSignature', relationship('FreshSignature')) # One freshman can have multiple packets if they repeat the intro process - packets = relationship('Packet', order_by='desc(Packet.id)') + packets = cast('Packet', relationship('Packet', order_by='desc(Packet.id)')) @classmethod - def by_username(cls, username: str): + def by_username(cls, username: str) -> 'Packet': """ Helper method to retrieve a freshman by their RIT username """ return cls.query.filter_by(rit_username=username).first() @classmethod - def get_all(cls): + def get_all(cls) -> list['Packet']: """ Helper method to get all freshmen easily """ @@ -59,25 +60,26 @@ class Freshman(db.Model): class Packet(db.Model): __tablename__ = 'packet' - id = Column(Integer, primary_key=True, autoincrement=True) - freshman_username = Column(ForeignKey('freshman.rit_username')) - start = Column(DateTime, nullable=False) - end = Column(DateTime, nullable=False) + id = cast(int, Column(Integer, primary_key=True, autoincrement=True)) + freshman_username = cast(str, Column(ForeignKey('freshman.rit_username'))) + start = cast(datetime, Column(DateTime, nullable=False)) + end = cast(datetime, Column(DateTime, nullable=False)) - freshman = relationship('Freshman', back_populates='packets') + freshman = cast(Freshman, relationship('Freshman', back_populates='packets')) # The `lazy='subquery'` kwarg enables eager loading for signatures which makes signature calculations much faster # See the docs here for details: https://docs.sqlalchemy.org/en/latest/orm/loading_relationships.html - upper_signatures = relationship('UpperSignature', lazy='subquery', - order_by='UpperSignature.signed.desc(), UpperSignature.updated') - fresh_signatures = relationship('FreshSignature', lazy='subquery', - order_by='FreshSignature.signed.desc(), FreshSignature.updated') - misc_signatures = relationship('MiscSignature', lazy='subquery', order_by='MiscSignature.updated') - - def is_open(self): + upper_signatures = cast('UpperSignature', relationship('UpperSignature', lazy='subquery', + order_by='UpperSignature.signed.desc(), UpperSignature.updated')) + fresh_signatures = cast('FreshSignature', relationship('FreshSignature', lazy='subquery', + order_by='FreshSignature.signed.desc(), FreshSignature.updated')) + misc_signatures = cast('MiscSignature', + relationship('MiscSignature', lazy='subquery', order_by='MiscSignature.updated')) + + def is_open(self) -> bool: return self.start < datetime.now() < self.end - def signatures_required(self): + def signatures_required(self) -> SigCounts: """ :return: A SigCounts instance with the fields set to the number of signatures received by this packet """ @@ -86,7 +88,7 @@ class Packet(db.Model): return SigCounts(upper, fresh, REQUIRED_MISC_SIGNATURES) - def signatures_received(self): + def signatures_received(self) -> SigCounts: """ :return: A SigCounts instance with the fields set to the number of required signatures for this packet """ @@ -95,7 +97,7 @@ class Packet(db.Model): return SigCounts(upper, fresh, len(self.misc_signatures)) - def did_sign(self, username, is_csh): + def did_sign(self, username: str, is_csh: bool) -> bool: """ :param username: The CSH or RIT username to check for :param is_csh: Set to True for CSH accounts and False for freshmen @@ -114,21 +116,21 @@ class Packet(db.Model): # The user must be a misc CSHer that hasn't signed this packet or an off-floor freshmen return False - def is_100(self): + def is_100(self) -> bool: """ Checks if this packet has reached 100% """ return self.signatures_required().total == self.signatures_received().total @classmethod - def open_packets(cls): + def open_packets(cls) -> list['Packet']: """ Helper method for fetching all currently open packets """ return cls.query.filter(cls.start < datetime.now(), cls.end > datetime.now()).all() @classmethod - def by_id(cls, packet_id): + def by_id(cls, packet_id: int) -> 'Packet': """ Helper method for fetching 1 packet by its id """ @@ -136,43 +138,43 @@ class Packet(db.Model): class UpperSignature(db.Model): __tablename__ = 'signature_upper' - packet_id = Column(Integer, ForeignKey('packet.id'), primary_key=True) - member = Column(String(36), primary_key=True) - signed = Column(Boolean, default=False, nullable=False) - eboard = Column(String(12), nullable=True) - active_rtp = Column(Boolean, default=False, nullable=False) - three_da = Column(Boolean, default=False, nullable=False) - webmaster = Column(Boolean, default=False, nullable=False) - c_m = Column(Boolean, default=False, nullable=False) - w_m = Column(Boolean, default=False, nullable=False) - drink_admin = Column(Boolean, default=False, nullable=False) - updated = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False) + packet_id = cast(int, Column(Integer, ForeignKey('packet.id'), primary_key=True)) + member = cast(str, Column(String(36), primary_key=True)) + signed = cast(bool, Column(Boolean, default=False, nullable=False)) + eboard = cast(Optional[str], Column(String(12), nullable=True)) + active_rtp = cast(bool, Column(Boolean, default=False, nullable=False)) + three_da = cast(bool, Column(Boolean, default=False, nullable=False)) + webmaster = cast(bool, Column(Boolean, default=False, nullable=False)) + c_m = cast(bool, Column(Boolean, default=False, nullable=False)) + w_m = cast(bool, Column(Boolean, default=False, nullable=False)) + drink_admin = cast(bool, Column(Boolean, default=False, nullable=False)) + updated = cast(datetime, Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)) - packet = relationship('Packet', back_populates='upper_signatures') + packet = cast(Packet, relationship('Packet', back_populates='upper_signatures')) class FreshSignature(db.Model): __tablename__ = 'signature_fresh' - packet_id = Column(Integer, ForeignKey('packet.id'), primary_key=True) - freshman_username = Column(ForeignKey('freshman.rit_username'), primary_key=True) - signed = Column(Boolean, default=False, nullable=False) - updated = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False) + packet_id = cast(int, Column(Integer, ForeignKey('packet.id'), primary_key=True)) + freshman_username = cast(str, Column(ForeignKey('freshman.rit_username'), primary_key=True)) + signed = cast(bool, Column(Boolean, default=False, nullable=False)) + updated = cast(datetime, Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)) - packet = relationship('Packet', back_populates='fresh_signatures') - freshman = relationship('Freshman', back_populates='fresh_signatures') + packet = cast(Packet, relationship('Packet', back_populates='fresh_signatures')) + freshman = cast(Freshman, relationship('Freshman', back_populates='fresh_signatures')) class MiscSignature(db.Model): __tablename__ = 'signature_misc' - packet_id = Column(Integer, ForeignKey('packet.id'), primary_key=True) - member = Column(String(36), primary_key=True) - updated = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False) + packet_id = cast(int, Column(Integer, ForeignKey('packet.id'), primary_key=True)) + member = cast(str, Column(String(36), primary_key=True)) + updated = cast(datetime, Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)) - packet = relationship('Packet', back_populates='misc_signatures') + packet = cast(Packet, relationship('Packet', back_populates='misc_signatures')) class NotificationSubscription(db.Model): __tablename__ = 'notification_subscriptions' - member = Column(String(36), nullable=True) - freshman_username = Column(ForeignKey('freshman.rit_username'), nullable=True) - token = Column(String(256), primary_key=True, nullable=False) + member = cast(str, Column(String(36), nullable=True)) + freshman_username = cast(str, Column(ForeignKey('freshman.rit_username'), nullable=True)) + token = cast(str, Column(String(256), primary_key=True, nullable=False)) diff --git a/packet/notifications.py b/packet/notifications.py index 69e1ab6..49b89fb 100644 --- a/packet/notifications.py +++ b/packet/notifications.py @@ -1,7 +1,10 @@ +from datetime import datetime +from typing import Any, Callable, TypeVar, cast + import onesignal_sdk.client as onesignal from packet import app, intro_onesignal_client, csh_onesignal_client -from packet.models import NotificationSubscription +from packet.models import NotificationSubscription, Packet post_body = { 'contents': {'en': 'Default message'}, @@ -11,22 +14,24 @@ post_body = { 'url': app.config['PROTOCOL'] + app.config['SERVER_NAME'] } -def require_onesignal_intro(func): - def require_onesignal_intro_wrapper(*args, **kwargs): +WrappedFunc = TypeVar('WrappedFunc', bound=Callable) + +def require_onesignal_intro(func: WrappedFunc) -> WrappedFunc: + def require_onesignal_intro_wrapper(*args: list, **kwargs: dict) -> Any: if intro_onesignal_client: return func(*args, **kwargs) return None - return require_onesignal_intro_wrapper + return cast(WrappedFunc, require_onesignal_intro_wrapper) -def require_onesignal_csh(func): - def require_onesignal_csh_wrapper(*args, **kwargs): +def require_onesignal_csh(func: WrappedFunc) -> WrappedFunc: + def require_onesignal_csh_wrapper(*args: list, **kwargs: dict) -> Any: if csh_onesignal_client: return func(*args, **kwargs) return None - return require_onesignal_csh_wrapper + return cast(WrappedFunc, require_onesignal_csh_wrapper) -def send_notification(notification_body, subscriptions, client): +def send_notification(notification_body: dict, subscriptions: list, client: onesignal.Client) -> None: tokens = list(map(lambda subscription: subscription.token, subscriptions)) if tokens: notification = onesignal.Notification(post_body=notification_body) @@ -39,7 +44,7 @@ def send_notification(notification_body, subscriptions, client): @require_onesignal_intro -def packet_signed_notification(packet, signer): +def packet_signed_notification(packet: Packet, signer: str) -> None: subscriptions = NotificationSubscription.query.filter_by(freshman_username=packet.freshman_username) if subscriptions: notification_body = post_body @@ -53,9 +58,10 @@ def packet_signed_notification(packet, signer): @require_onesignal_csh @require_onesignal_intro -def packet_100_percent_notification(packet): - member_subscriptions = NotificationSubscription.query.filter(NotificationSubscription.member.isnot(None)) - intro_subscriptions = NotificationSubscription.query.filter(NotificationSubscription.freshman_username.isnot(None)) +def packet_100_percent_notification(packet: Packet) -> None: + member_subscriptions = NotificationSubscription.query.filter(cast(Any, NotificationSubscription.member).isnot(None)) + intro_subscriptions = NotificationSubscription.query.filter( + cast(Any, NotificationSubscription.freshman_username).isnot(None)) if member_subscriptions or intro_subscriptions: notification_body = post_body notification_body['contents']['en'] = packet.freshman.name + ' got 💯 on packet!' @@ -68,7 +74,7 @@ def packet_100_percent_notification(packet): @require_onesignal_intro -def packet_starting_notification(packet): +def packet_starting_notification(packet: Packet) -> None: subscriptions = NotificationSubscription.query.filter_by(freshman_username=packet.freshman_username) if subscriptions: notification_body = post_body @@ -81,8 +87,8 @@ def packet_starting_notification(packet): @require_onesignal_csh -def packets_starting_notification(start_date): - member_subscriptions = NotificationSubscription.query.filter(NotificationSubscription.member.isnot(None)) +def packets_starting_notification(start_date: datetime) -> None: + member_subscriptions = NotificationSubscription.query.filter(cast(Any, NotificationSubscription.member).isnot(None)) if member_subscriptions: notification_body = post_body notification_body['contents']['en'] = 'New packets have started, visit packet to see them!' diff --git a/packet/stats.py b/packet/stats.py index 2ac5fb2..c6d6103 100644 --- a/packet/stats.py +++ b/packet/stats.py @@ -1,9 +1,35 @@ -from datetime import timedelta +from datetime import date as dateType, timedelta +from typing import TypedDict, Union, cast, Callable from packet.models import Packet, MiscSignature, UpperSignature +# Types +class Freshman(TypedDict): + name: str + rit_username: str -def packet_stats(packet_id): +class WhoSigned(TypedDict): + upper: list[str] + misc: list[str] + fresh: list[str] + +class PacketStats(TypedDict): + packet_id: int + freshman: Freshman + dates: dict[str, dict[str, list[str]]] + +class SimplePacket(TypedDict): + id: int + freshman_username: str + +class SigDict(TypedDict): + date: dateType + packet: SimplePacket + +Stats = dict[dateType, list[str]] + + +def packet_stats(packet_id: int) -> PacketStats: """ Gather statistics for a packet in the form of number of signatures per day @@ -28,17 +54,17 @@ def packet_stats(packet_id): print(dates) - upper_stats = {date: list() for date in dates} + upper_stats: Stats = {date: list() for date in dates} for uid, date in map(lambda sig: (sig.member, sig.updated), filter(lambda sig: sig.signed, packet.upper_signatures)): upper_stats[date.date()].append(uid) - fresh_stats = {date: list() for date in dates} + fresh_stats: Stats = {date: list() for date in dates} for username, date in map(lambda sig: (sig.freshman_username, sig.updated), filter(lambda sig: sig.signed, packet.fresh_signatures)): fresh_stats[date.date()].append(username) - misc_stats = {date: list() for date in dates} + misc_stats: Stats = {date: list() for date in dates} for uid, date in map(lambda sig: (sig.member, sig.updated), packet.misc_signatures): misc_stats[date.date()].append(uid) @@ -60,7 +86,7 @@ def packet_stats(packet_id): } -def sig2dict(sig): +def sig2dict(sig: Union[UpperSignature, MiscSignature]) -> SigDict: """ A utility function for upperclassman stats. Converts an UpperSignature to a dictionary with the date and the packet. @@ -74,8 +100,11 @@ def sig2dict(sig): }, } +class UpperStats(TypedDict): + member: str + signatures: dict[str, list[SimplePacket]] -def upperclassman_stats(uid): +def upperclassman_stats(uid: str) -> UpperStats: """ Gather statistics for an upperclassman's signature habits @@ -104,7 +133,7 @@ def upperclassman_stats(uid): 'signatures': { date.isoformat() : list( map(lambda sd: sd['packet'], - filter(lambda sig, d=date: sig['date'] == d, + filter(cast(Callable, lambda sig, d=date: sig['date'] == d), sig_dicts ) ) diff --git a/packet/utils.py b/packet/utils.py index ff2f5a6..ea4693b 100644 --- a/packet/utils.py +++ b/packet/utils.py @@ -3,6 +3,7 @@ General utilities and decorators for supporting the Python logic """ from datetime import datetime, time, timedelta, date from functools import wraps, lru_cache +from typing import Any, Callable, TypeVar, cast import requests from flask import session, redirect @@ -14,15 +15,16 @@ from packet.notifications import packets_starting_notification, packet_starting_ INTRO_REALM = 'https://sso.csh.rit.edu/auth/realms/intro' +WrappedFunc = TypeVar('WrappedFunc', bound=Callable) -def before_request(func): +def before_request(func: WrappedFunc) -> WrappedFunc: """ Credit to Liam Middlebrook and Ram Zallan https://github.com/liam-middlebrook/gallery """ @wraps(func) - def wrapped_function(*args, **kwargs): + def wrapped_function(*args: list, **kwargs: dict) -> Any: uid = str(session['userinfo'].get('preferred_username', '')) if session['id_token']['iss'] == INTRO_REALM: info = { @@ -43,11 +45,11 @@ def before_request(func): kwargs['info'] = info return func(*args, **kwargs) - return wrapped_function + return cast(WrappedFunc, wrapped_function) @lru_cache(maxsize=128) -def is_freshman_on_floor(rit_username): +def is_freshman_on_floor(rit_username: str) -> bool: """ Checks if a freshman is on floor """ @@ -58,14 +60,14 @@ def is_freshman_on_floor(rit_username): return False -def packet_auth(func): +def packet_auth(func: WrappedFunc) -> WrappedFunc: """ Decorator for easily configuring oidc """ @auth.oidc_auth('app') @wraps(func) - def wrapped_function(*args, **kwargs): + def wrapped_function(*args: list, **kwargs: dict) -> Any: if app.config['REALM'] == 'csh': username = str(session['userinfo'].get('preferred_username', '')) if ldap.is_intromember(ldap.get_member(username)): @@ -74,17 +76,17 @@ def packet_auth(func): return func(*args, **kwargs) - return wrapped_function + return cast(WrappedFunc, wrapped_function) -def admin_auth(func): +def admin_auth(func: WrappedFunc) -> WrappedFunc: """ Decorator for easily configuring oidc """ @auth.oidc_auth('app') @wraps(func) - def wrapped_function(*args, **kwargs): + def wrapped_function(*args: list, **kwargs: dict) -> Any: if app.config['REALM'] == 'csh': username = str(session['userinfo'].get('preferred_username', '')) member = ldap.get_member(username) @@ -96,10 +98,10 @@ def admin_auth(func): return func(*args, **kwargs) - return wrapped_function + return cast(WrappedFunc, wrapped_function) -def notify_slack(name: str): +def notify_slack(name: str) -> None: """ Sends a congratulate on sight decree to Slack """ @@ -112,7 +114,7 @@ def notify_slack(name: str): app.logger.info('Posted 100% notification to slack for ' + name) -def sync_freshman(freshmen_list: dict): +def sync_freshman(freshmen_list: dict) -> None: freshmen_in_db = {freshman.rit_username: freshman for freshman in Freshman.query.all()} for list_freshman in freshmen_list.values(): @@ -150,7 +152,7 @@ def sync_freshman(freshmen_list: dict): db.session.commit() -def create_new_packets(base_date: date, freshmen_list: dict): +def create_new_packets(base_date: date, freshmen_list: dict) -> None: packet_start_time = time(hour=19) packet_end_time = time(hour=21) start = datetime.combine(base_date, packet_start_time) @@ -173,7 +175,7 @@ def create_new_packets(base_date: date, freshmen_list: dict): # Create the new packets and the signatures for each freshman in the given CSV print('Creating DB entries and sending emails...') - for freshman in Freshman.query.filter(Freshman.rit_username.in_(freshmen_list)).all(): + for freshman in Freshman.query.filter(cast(Any, Freshman.rit_username).in_(freshmen_list)).all(): packet = Packet(freshman=freshman, start=start, end=end) db.session.add(packet) send_start_packet_mail(packet) @@ -197,7 +199,7 @@ def create_new_packets(base_date: date, freshmen_list: dict): db.session.commit() -def sync_with_ldap(): +def sync_with_ldap() -> None: print('Fetching data from LDAP...') all_upper = {member.uid: member for member in filter( lambda member: not ldap.is_intromember(member) and not ldap.is_on_coop(member), ldap.get_active_members())} diff --git a/requirements.txt b/requirements.txt index cda0b3b..dcb44ae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,16 @@ -csh_ldap~=2.3.1 -ddtrace -Flask~=1.1.2 Flask-Gzip~=0.2 Flask-Mail~=0.9.1 Flask-Migrate~=2.7.0 Flask-pyoidc~=3.7.0 +Flask~=1.1.2 +csh_ldap~=2.3.1 +ddtrace flask_sqlalchemy~=2.4.4 gunicorn~=20.0.4 +mypy onesignal-sdk~=2.0.0 psycopg2-binary~=2.8.6 -pylint~=2.7.2 pylint-quotes~=0.2.1 +pylint~=2.7.2 sentry-sdk~=1.0.0 +sqlalchemy[mypy] diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..37543e3 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[mypy] +plugins=sqlalchemy.ext.mypy.plugin |