summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGalen Guyer <galen@galenguyer.com>2023-03-22 22:12:07 -0400
committerGalen Guyer <galen@galenguyer.com>2023-03-22 22:12:07 -0400
commit4b012474dcab1a2573605eb114b643c4d7aa781c (patch)
treefd9f3e9c0d239d7a310de24f0e238cbde7a91a2a
quick flask api
-rw-r--r--.gitignore5
-rw-r--r--app.py1
-rw-r--r--config.py21
-rw-r--r--dysphoria/__init__.py9
-rw-r--r--dysphoria/errors.py12
-rw-r--r--dysphoria/routes.py53
-rw-r--r--dysphoria/whisper.py9
-rw-r--r--requirements.txt5
-rw-r--r--transcribe.py32
-rw-r--r--wsgi.py8
10 files changed, 155 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..39bbdaa
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+__pycache__/
+.vscode/
+audio/
+models/
+venv/
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..f2b470a
--- /dev/null
+++ b/app.py
@@ -0,0 +1 @@
+from dysphoria import app
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..9475d5e
--- /dev/null
+++ b/config.py
@@ -0,0 +1,21 @@
+import os
+import secrets
+from os.path import join, dirname
+from dotenv import load_dotenv
+
+basedir = os.path.abspath(os.path.dirname(__file__))
+dotenv_path = join(dirname(__file__), ".env")
+load_dotenv(dotenv_path)
+
+
+class Config(object):
+ IP = os.environ.get("DYSPHORIA_IP") or "0.0.0.0"
+ PORT = os.environ.get("DYSPHORIA_PORT") or 5000
+ SECRET_KEY = os.environ.get("DYSPHORIA_SECRET_KEY") or "".join(
+ secrets.token_hex(16)
+ )
+ API_TOKEN = os.environ.get("DYSPHORIA_API_TOKEN") or "".join(secrets.token_hex(16))
+ WHISPER_MODEL_PATH = (
+ os.environ.get("DYSPHORIA_WHISPER_MODEL_PATH") or "models/whisper-large-v2-ct2/"
+ )
+ WHISPER_DEVICE = os.environ.get("DYSPHORIA_WHISPER_DEVICE") or "cpu"
diff --git a/dysphoria/__init__.py b/dysphoria/__init__.py
new file mode 100644
index 0000000..f63fc2f
--- /dev/null
+++ b/dysphoria/__init__.py
@@ -0,0 +1,9 @@
+from flask import Flask
+from config import Config
+
+app = Flask(__name__)
+app.config.from_object(Config)
+
+print(f"api token: {app.config['API_TOKEN']}")
+
+from dysphoria import routes, errors
diff --git a/dysphoria/errors.py b/dysphoria/errors.py
new file mode 100644
index 0000000..f0de77d
--- /dev/null
+++ b/dysphoria/errors.py
@@ -0,0 +1,12 @@
+from dysphoria import app
+from flask import request, jsonify
+
+
+@app.errorhandler(400)
+def _error_400(error):
+ return jsonify({"error": "bad request", "details": str(error)}), 400
+
+
+@app.errorhandler(404)
+def _error_404(error):
+ return jsonify({"error": "not found"}), 404
diff --git a/dysphoria/routes.py b/dysphoria/routes.py
new file mode 100644
index 0000000..a070497
--- /dev/null
+++ b/dysphoria/routes.py
@@ -0,0 +1,53 @@
+from dysphoria import app, whisper
+from flask import jsonify, request
+
+
+@app.route("/api/v1/health", methods=["GET"])
+def _get_health_v1():
+ return jsonify({"status": "ok"}), 200
+
+
+@app.route("/api/v1/transcribe", methods=["POST"])
+def _post_transcribe_v1():
+ # get the api token from the auth header
+ if (
+ not request.headers.get("Authorization")
+ or request.headers.get("Authorization") != f"Bearer {app.config['API_TOKEN']}"
+ ):
+ return jsonify({"error": "unauthorized"}), 401
+
+ if request.form.get("prompt"):
+ prompt = request.form.get("prompt")
+ else:
+ prompt = None
+
+ try:
+ segments, info = whisper.whisper.transcribe(
+ request.files["file"], language="en", initial_prompt=prompt
+ )
+ except:
+ return jsonify({"error": "bad request"}), 400
+
+ segments = [
+ {
+ "id": i,
+ "start": segment.start,
+ "end": segment.end,
+ "text": segment.text.strip(),
+ }
+ for i, segment in enumerate(segments)
+ ]
+ text = " ".join([segment["text"] for segment in segments])
+
+ # return the transcription
+ return (
+ jsonify(
+ {
+ "text": text,
+ "language": info.language,
+ "duration": info.duration,
+ "segments": segments,
+ }
+ ),
+ 200,
+ )
diff --git a/dysphoria/whisper.py b/dysphoria/whisper.py
new file mode 100644
index 0000000..c95916f
--- /dev/null
+++ b/dysphoria/whisper.py
@@ -0,0 +1,9 @@
+from faster_whisper import WhisperModel
+from dysphoria import app
+
+whisper = WhisperModel(
+ app.config["WHISPER_MODEL_PATH"],
+ device=app.config["WHISPER_DEVICE"],
+ compute_type="default",
+ num_workers=4,
+)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..ebfa436
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+faster-whisper[conversion]
+
+flask
+gunicorn
+python-dotenv
diff --git a/transcribe.py b/transcribe.py
new file mode 100644
index 0000000..f7e1e9d
--- /dev/null
+++ b/transcribe.py
@@ -0,0 +1,32 @@
+import time
+from faster_whisper import WhisperModel
+
+model_path = "models/whisper-tiny-en-ct2/"
+audio_files = [
+ "1670546132-2421.m4a",
+ "1670362801-1654.m4a",
+ "1675533069-1077.m4a",
+ "1675540057-3070.m4a",
+ "1677179024-2421.m4a",
+ "1677179036-2421.m4a",
+ "1677289865-3070.m4a",
+ "1677289881-3070.m4a",
+ "us.ny.monroe-1654-1678728684.m4a",
+ "us.ny.monroe-1704-1677730882.m4a",
+]
+
+t0 = time.perf_counter()
+model = WhisperModel(model_path, device="cpu", compute_type="default", num_workers=4)
+t1 = time.perf_counter()
+print(f"loaded model in {t1 - t0:0.2f} seconds")
+
+for file in audio_files:
+ t1 = time.perf_counter()
+ segments, info = model.transcribe(f"audio/{file}", language="en")
+
+ for segment in segments:
+ print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
+ pass
+
+ t2 = time.perf_counter()
+ print(f"transcription finished in {t2 - t1:0.2f} seconds")
diff --git a/wsgi.py b/wsgi.py
new file mode 100644
index 0000000..42482be
--- /dev/null
+++ b/wsgi.py
@@ -0,0 +1,8 @@
+"""
+Primary entry point for the app
+"""
+
+from dysphoria import app
+
+if __name__ == "__main__":
+ app.run(host=app.config["IP"], port=int(app.config["PORT"]))