class NonStreamable(Exception):
    """Raised when an item cannot be streamed or downloaded from the source."""


class MissingCredentials(Exception):
    """Raised when login is attempted without the required credentials configured."""


class AuthenticationError(Exception):
    """Raised when the source rejects the provided credentials."""
import binascii
import hashlib

import deezer
from Cryptodome.Cipher import AES

from .client import AuthenticationError, Client, MissingCredentials, NonStreamable
from .config import Config
from .downloadable import DeezerDownloadable


class DeezerClient(Client):
    """Streamrip client for the Deezer source, backed by the `deezer` package."""

    source = "deezer"
    max_quality = 2  # 0: MP3_128, 1: MP3_320, 2: FLAC

    def __init__(self, config: Config):
        self.global_config = config
        self.client = deezer.Deezer()
        self.logged_in = False
        # Per-session Deezer settings (arl, quality, ...).
        self.config = config.session.deezer

    async def login(self):
        """Log in with the ARL cookie from the config.

        :raises MissingCredentials: when no ARL is configured
        :raises AuthenticationError: when Deezer rejects the ARL
        """
        arl = self.config.arl
        # NOTE(review): the patch's hunks elide a few lines of this method
        # (gap between hunk contexts); reconstructed from the visible control
        # flow — confirm against the full file.
        if not arl:
            raise MissingCredentials
        success = self.client.login_via_arl(arl)
        if not success:
            raise AuthenticationError
        self.logged_in = True

    async def get_metadata(self, info: dict, media_type: str) -> dict:
        """Fetch metadata for a single item.

        :param info: must contain the Deezer item "id"
        :param media_type: one of "track", "album", "playlist", "artist"
        :raises KeyError: on an unsupported media_type
        """
        request_functions = {
            "track": self.client.api.get_track,
            "album": self.client.api.get_album,
            "playlist": self.client.api.get_playlist,
            "artist": self.client.api.get_artist,
        }

        get_item = request_functions[media_type]
        item = get_item(info["id"])
        if media_type in ("album", "playlist"):
            # Fetch the full tracklist in one request (limit=-1 == no limit).
            tracks = getattr(self.client.api, f"get_{media_type}_tracks")(
                info["id"], limit=-1
            )
            item["tracks"] = tracks["data"]
            item["track_total"] = len(tracks["data"])
        elif media_type == "artist":
            albums = self.client.api.get_artist_albums(info["id"])
            item["albums"] = albums["data"]

        return item

    async def search(self, query: str, media_type: str, limit: int = 200):
        """Search Deezer.

        :param media_type: a regular media type, or "featured" to browse
            editorial selections (query names the selection, e.g. "releases")
        """
        # TODO: use limit parameter
        if media_type == "featured":
            try:
                if query:
                    search_function = getattr(
                        self.client.api, f"get_editorial_{query}"
                    )
                else:
                    search_function = self.client.api.get_editorial_releases
            except AttributeError:
                # Fix: typo "editorical" -> "editorial" in the error message.
                raise Exception(f'Invalid editorial selection "{query}"')
        else:
            try:
                search_function = getattr(self.client.api, f"search_{media_type}")
            except AttributeError:
                raise Exception(f"Invalid media type {media_type}")

        response = search_function(query, limit=limit)  # type: ignore
        return response

    async def get_downloadable(
        self, info: dict, quality: int = 2
    ) -> DeezerDownloadable:
        """Resolve a track's download URL for the requested quality.

        :param info: must contain the Deezer track "id"
        :param quality: 0 (MP3_128), 1 (MP3_320) or 2 (FLAC)
        :raises NonStreamable: when the subscription does not allow the format
        """
        item_id = info["id"]
        # TODO: optimize such that all of the ids are requested at once
        dl_info: dict = {"quality": quality, "id": item_id}

        track_info = self.client.gw.get_track(item_id)
        dl_info["fallback_id"] = track_info["FALLBACK"]["SNG_ID"]

        # Index i corresponds to quality level i; first element of each pair
        # is Deezer's internal format id.
        quality_map = [
            (9, "MP3_128"),
            (3, "MP3_320"),
            (1, "FLAC"),
        ]
        _, format_str = quality_map[quality]

        token = track_info["TRACK_TOKEN"]
        try:
            url = self.client.get_track_url(token, format_str)
        except deezer.WrongLicense:
            raise NonStreamable(
                "The requested quality is not available with your subscription. "
                "Deezer HiFi is required for quality 2. Otherwise, the maximum "
                "quality allowed is 1."
            )

        if url is None:
            # Fall back to the legacy encrypted-media CDN URL.
            url = self._get_encrypted_file_url(
                item_id, track_info["MD5_ORIGIN"], track_info["MEDIA_VERSION"]
            )

        dl_info["url"] = url
        # Fix: DeezerDownloadable takes (session, info); the previous version
        # omitted the session. NOTE(review): assumes the Client base class
        # supplies an aiohttp session as `self.session` — confirm.
        return DeezerDownloadable(self.session, dl_info)

    def _get_encrypted_file_url(
        self, meta_id: str, track_hash: str, media_version: str
    ):
        """Build the legacy encrypted CDN URL for a track (AES-ECB path token)."""
        format_number = 1

        url_bytes = b"\xa4".join(
            (
                track_hash.encode(),
                str(format_number).encode(),
                str(meta_id).encode(),
                str(media_version).encode(),
            )
        )
        url_hash = hashlib.md5(url_bytes).hexdigest()
        info_bytes = bytearray(url_hash.encode())
        info_bytes.extend(b"\xa4")
        info_bytes.extend(url_bytes)
        info_bytes.extend(b"\xa4")
        # Pad the bytes so that len(info_bytes) % 16 == 0 (AES block size).
        padding_len = 16 - (len(info_bytes) % 16)
        info_bytes.extend(b"." * padding_len)

        path = binascii.hexlify(
            AES.new("jo6aey6haid2Teih".encode(), AES.MODE_ECB).encrypt(info_bytes)
        ).decode("utf-8")

        return f"https://e-cdns-proxy-{track_hash[0]}.dzcdn.net/mobile/1/{path}"
# NOTE(review): reconstructed from a whitespace-mangled patch. Third-party and
# project imports referenced lazily below (aiohttp, aiofiles, m3u8,
# Cryptodome.Cipher.Blowfish, .converter, .client.NonStreamable) live in the
# file's import block.
import asyncio
import functools
import hashlib
import itertools
import json
import os
import re
import shutil
import subprocess
import tempfile
import time
from abc import ABC, abstractmethod
from typing import Callable, Optional


def generate_temp_path(url: str):
    """Return a unique temp-file path for an in-progress download of *url*."""
    return os.path.join(
        tempfile.gettempdir(), f"__streamrip_{hash(url)}_{time.time()}.download"
    )


class Downloadable(ABC):
    """A remote audio resource that can be downloaded to a local path."""

    session: "aiohttp.ClientSession"
    url: str
    extension: str
    chunk_size = 1024
    _size: Optional[int] = None  # cached Content-Length

    async def download(self, path: str, callback: Callable[[int], None]):
        """Download to *path*, reporting progress through *callback*.

        Downloads to a temp file first so a partially written file never
        appears at the final path.
        """
        tmp = generate_temp_path(self.url)
        await self._download(tmp, callback)
        shutil.move(tmp, path)

    async def size(self) -> int:
        """Return (and cache) the resource's Content-Length; 0 when absent.

        NOTE(review): the method header is elided between hunks in the patch;
        signature reconstructed from the visible body — confirm.
        """
        if self._size is not None:
            return self._size
        async with self.session.head(self.url) as response:
            response.raise_for_status()
            content_length = response.headers.get("Content-Length", 0)
            self._size = int(content_length)
            return self._size

    @abstractmethod
    async def _download(self, path: str, callback: Callable[[int], None]):
        # Fix: `raise NotImplemented` raises TypeError (NotImplemented is not
        # an exception); NotImplementedError is the correct sentinel.
        raise NotImplementedError


class BasicDownloadable(Downloadable):
    """Just downloads a URL. No quality, no download types."""

    def __init__(self, session: "aiohttp.ClientSession", url: str):
        self.session = session
        self.url = url
        # TODO: verify that this is correct (breaks for URLs with query strings)
        self.extension = url.split(".")[-1]

    async def _download(self, path: str, callback: Callable[[int], None]):
        # Fix: aiohttp's get() has no `stream` kwarg (that is a `requests`
        # idiom) and passing it raised TypeError; aiohttp always streams.
        async with self.session.get(self.url, allow_redirects=True) as response:
            response.raise_for_status()
            async with aiofiles.open(path, "wb") as file:
                async for chunk in response.content.iter_chunked(self.chunk_size):
                    await file.write(chunk)
                    # Report actual bytes written (last chunk may be short).
                    callback(len(chunk))


class DeezerDownloadable(Downloadable):
    """Deezer media; streams may be Blowfish-CBC encrypted per 2048-byte block."""

    is_encrypted = re.compile("/m(?:obile|edia)/")
    chunk_size = 2048 * 3  # multiple of the 2048-byte encryption block

    def __init__(self, session: "aiohttp.ClientSession", info: dict):
        self.session = session
        self.url = info["url"]
        self.fallback_id = info["fallback_id"]
        self.quality = info["quality"]
        # Qualities 0 and 1 are MP3; quality 2 is FLAC.
        self.extension = "mp3" if self.quality <= 1 else "flac"
        self.id = info["id"]

    async def _download(self, path: str, callback):
        # Fix: removed the invalid `stream=True` kwarg (not aiohttp API).
        async with self.session.get(self.url, allow_redirects=True) as resp:
            resp.raise_for_status()
            self._size = int(resp.headers.get("Content-Length", 0))
            # A tiny non-image response is an error document, not audio.
            if self._size < 20000 and not self.url.endswith(".jpg"):
                try:
                    info = await resp.json()
                except json.JSONDecodeError:
                    raise NonStreamable("File not found.")
                if "error" in info and "message" in info:
                    # Usually happens with deezloader downloads
                    raise NonStreamable(f"{info['error']} - {info['message']}")
                raise NonStreamable(info)

            encrypted = self.is_encrypted.search(self.url) is not None
            blowfish_key = (
                self._generate_blowfish_key(self.id) if encrypted else None
            )
            async with aiofiles.open(path, "wb") as file:
                async for chunk in resp.content.iter_chunked(self.chunk_size):
                    if encrypted and len(chunk) >= 2048:
                        # Only the first 2048 bytes of each chunk are encrypted.
                        chunk = (
                            self._decrypt_chunk(blowfish_key, chunk[:2048])
                            + chunk[2048:]
                        )
                    await file.write(chunk)
                    # typically a bar.update(); report real bytes written
                    callback(len(chunk))

    @staticmethod
    def _decrypt_chunk(key, data):
        """Decrypt one 2048-byte block of a Deezer stream (Blowfish-CBC).

        :param key: per-track Blowfish key
        :param data: one encrypted 2048-byte block
        """
        return Blowfish.new(
            key,
            Blowfish.MODE_CBC,
            b"\x00\x01\x02\x03\x04\x05\x06\x07",
        ).decrypt(data)

    @staticmethod
    def _generate_blowfish_key(track_id: str) -> bytes:
        """Derive the per-track Blowfish key (md5 halves XOR a fixed secret).

        :param track_id:
        :type track_id: str
        """
        SECRET = "g4el58wc0zvf9na1"
        md5_hash = hashlib.md5(track_id.encode()).hexdigest()
        # XOR corresponding bytes of the two md5 halves and the secret.
        return "".join(
            chr(functools.reduce(lambda x, y: x ^ y, map(ord, t)))
            for t in zip(md5_hash[:16], md5_hash[16:], SECRET)
        ).encode()


class TidalDownloadable(Downloadable):
    """A wrapper around BasicDownloadable that includes Tidal-specific
    error messages."""

    def __init__(self, session: "aiohttp.ClientSession", info: dict):
        self.session = session
        url = info.get("url")
        # Fix: the previous version tested `self.url` before ever assigning
        # it, so every construction raised AttributeError instead of the
        # intended diagnostics below.
        if url is None:
            if restrictions := info["restrictions"]:
                # Turn CamelCase code into a readable sentence
                words = re.findall(r"([A-Z][a-z]+)", restrictions[0]["code"])
                raise NonStreamable(
                    words[0] + " " + " ".join(map(str.lower, words[1:])) + "."
                )
            raise NonStreamable(f"Tidal download: dl_info = {info}")

        assert isinstance(url, str)
        self.url = url  # required by Downloadable.size()
        self.downloadable = BasicDownloadable(session, url)
        self.extension = self.downloadable.extension

    async def _download(self, path: str, callback):
        await self.downloadable._download(path, callback)


class SoundcloudDownloadable(Downloadable):
    """Soundcloud media: a direct "original" file or an HLS mp3 playlist."""

    def __init__(self, session, info: dict):
        self.session = session
        self.file_type = info["type"]
        if self.file_type == "mp3":
            self.extension = "mp3"
        elif self.file_type == "original":
            self.extension = "flac"
        else:
            raise Exception(f"Invalid file type: {self.file_type}")
        self.url = info["url"]

    async def _download(self, path, callback):
        if self.file_type == "mp3":
            await self._download_mp3(path, callback)
        else:
            await self._download_original(path, callback)

    async def _download_original(self, path: str, callback):
        """Download the original upload, then re-encode it to FLAC in place."""
        downloader = BasicDownloadable(self.session, self.url)
        await downloader.download(path, callback)
        engine = converter.FLAC(path)
        engine.convert(path)

    async def _download_mp3(self, path: str, callback):
        """Download all HLS segments concurrently, then concatenate them."""
        async with self.session.get(self.url) as resp:
            content = await resp.text("utf-8")

        parsed_m3u = m3u8.loads(content)
        self._size = len(parsed_m3u.segments)

        async def _fetch(uri: str) -> str:
            segment_path = await self._download_segment(uri)
            callback(1)
            return segment_path

        # Fix: segments must be joined in playlist order; the previous
        # as_completed() loop collected them in completion order, which
        # scrambled the audio. gather() preserves argument order.
        segment_paths = await asyncio.gather(
            *(_fetch(segment.uri) for segment in parsed_m3u.segments)
        )

        concat_audio_files(list(segment_paths), path, "mp3")

    async def _download_segment(self, segment_uri: str) -> str:
        """Download one HLS segment to a temp file and return its path."""
        tmp = generate_temp_path(segment_uri)
        async with self.session.get(segment_uri) as resp:
            resp.raise_for_status()
            async with aiofiles.open(tmp, "wb") as file:
                content = await resp.content.read()
                await file.write(content)
        return tmp


def concat_audio_files(paths: list[str], out: str, ext: str, max_files_open=128):
    """Concatenate audio files using FFmpeg. Batched by max files open.

    Recurses log_{max_files_open}(len(paths)) times.

    :param paths: files to concatenate, in order
    :param out: destination path
    :param ext: extension for intermediate batch files
    :raises ValueError: on an empty *paths*
    :raises Exception: when FFmpeg is missing or exits non-zero
    """
    # Fix: an empty list previously recursed forever (0 batches -> recurse
    # on [] again).
    if not paths:
        raise ValueError("No files to concatenate.")

    # Base case needs no FFmpeg, so check for the binary only afterwards.
    if len(paths) == 1:
        shutil.move(paths[0], out)
        return

    if shutil.which("ffmpeg") is None:
        raise Exception("FFmpeg must be installed.")

    it = iter(paths)
    num_batches = len(paths) // max_files_open + (
        1 if len(paths) % max_files_open != 0 else 0
    )
    tempdir = tempfile.gettempdir()
    outpaths = [
        os.path.join(
            tempdir, f"__streamrip_ffmpeg_{hash(paths[i*max_files_open])}.{ext}"
        )
        for i in range(num_batches)
    ]

    for p in outpaths:
        try:
            os.remove(p)  # in case of failure
        except FileNotFoundError:
            pass

    for i in range(num_batches):
        proc = subprocess.run(
            (
                "ffmpeg",
                "-i",
                f"concat:{'|'.join(itertools.islice(it, max_files_open))}",
                "-acodec",
                "copy",
                "-loglevel",
                "panic",
                outpaths[i],
            ),
            # Fix: capture stderr so the error message below is not `None`.
            stderr=subprocess.PIPE,
        )
        if proc.returncode != 0:
            raise Exception(f"FFMPEG returned with this error: {proc.stderr}")

    # Recurse on remaining batches. Fix: propagate max_files_open.
    concat_audio_files(outpaths, out, ext, max_files_open)
import base64
import time

import aiohttp

from .client import Client
from .config import Config

BASE = "https://api.tidalhifi.com/v1"
AUTH_URL = "https://auth.tidal.com/v1/oauth2"

# Standard Tidal client credentials, lightly obfuscated with base64.
CLIENT_ID = base64.b64decode("elU0WEhWVmtjMnREUG80dA==").decode("iso-8859-1")
CLIENT_SECRET = base64.b64decode(
    "VkpLaERGcUpQcXZzUFZOQlY2dWtYVEptd2x2YnR0UDd3bE1scmM3MnNlND0="
).decode("iso-8859-1")


class TidalClient(Client):
    """Streamrip client for Tidal, using the OAuth device flow."""

    source = "tidal"
    max_quality = 3

    def __init__(self, config: Config):
        self.logged_in = False
        self.global_config = config
        self.config = config.session.tidal
        self.session = self.get_session()
        self.rate_limiter = self.get_rate_limiter(
            config.session.downloads.requests_per_minute
        )

    async def login(self):
        """Log in with a stored access token, refreshing it when it is
        within a day of expiring.

        :raises Exception: when no access token is configured
        """
        c = self.config
        if not c.access_token:
            raise Exception("Access token not found in config.")

        self.token_expiry = float(c.token_expiry)
        self.refresh_token = c.refresh_token

        if self.token_expiry - time.time() < 86400:  # 1 day
            await self._refresh_access_token()
        else:
            await self._login_by_access_token(c.access_token, c.user_id)

        self.logged_in = True

    async def _login_by_access_token(self, token: str, user_id: str):
        """Login using the access token.

        Used after the initial authorization.

        :param token: access token
        :param user_id: To verify that the user is correct
        """
        headers = {"authorization": f"Bearer {token}"}  # temporary
        async with self.session.get(
            "https://api.tidal.com/v1/sessions", headers=headers
        ) as _resp:
            # Read the body inside the context manager, before the
            # connection is released.
            resp = await _resp.json()

        if resp.get("status", 200) != 200:
            raise Exception(f"Login failed {resp}")

        if str(resp.get("userId")) != str(user_id):
            raise Exception(f"User id mismatch {resp['userId']} v {user_id}")

        c = self.config
        c.user_id = resp["userId"]
        c.country_code = resp["countryCode"]
        c.access_token = token
        self._update_authorization_from_config()

    async def _get_login_link(self) -> str:
        """Start the OAuth device flow and return the link the user must visit."""
        data = {
            "client_id": CLIENT_ID,
            "scope": "r_usr+w_usr+w_sub",
        }
        resp = await self._api_post(f"{AUTH_URL}/device_authorization", data)

        if resp.get("status", 200) != 200:
            raise Exception(f"Device authorization failed {resp}")

        # NOTE(review): RFC 8628 directs the user to the verification URI,
        # not the bare device code; Tidal returns "verificationUriComplete"
        # alongside "deviceCode". Fall back to the old behaviour when absent.
        return f"https://{resp.get('verificationUriComplete', resp['deviceCode'])}"

    def _update_authorization_from_config(self):
        """Apply the configured bearer token to the shared session headers."""
        self.session.headers.update(
            {"authorization": f"Bearer {self.config.access_token}"}
        )

    async def _get_auth_status(self, device_code) -> tuple[int, dict[str, int | str]]:
        """Check if the user has logged in inside the browser.

        returns (status, authentication info): 0 = success, 1 = failure,
        2 = authorization still pending.
        """
        data = {
            "client_id": CLIENT_ID,
            "device_code": device_code,
            "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
            "scope": "r_usr+w_usr+w_sub",
        }
        resp = await self._api_post(
            f"{AUTH_URL}/token",
            data,
            (CLIENT_ID, CLIENT_SECRET),
        )

        if resp.get("status", 200) != 200:
            # 400/1002 means the user has not authorized yet; keep polling.
            if resp["status"] == 400 and resp["sub_status"] == 1002:
                return 2, {}
            return 1, {}

        return 0, {
            "user_id": resp["user"]["userId"],
            "country_code": resp["user"]["countryCode"],
            "access_token": resp["access_token"],
            "refresh_token": resp["refresh_token"],
            "token_expiry": resp["expires_in"] + time.time(),
        }

    async def _refresh_access_token(self):
        """Refresh the access token given a refresh token.

        The access token expires in a week, so it must be refreshed.
        Requires a refresh token.
        """
        data = {
            "client_id": CLIENT_ID,
            "refresh_token": self.refresh_token,
            "grant_type": "refresh_token",
            "scope": "r_usr+w_usr+w_sub",
        }
        resp_json = await self._api_post(
            f"{AUTH_URL}/token",
            data,
            (CLIENT_ID, CLIENT_SECRET),
        )

        if resp_json.get("status", 200) != 200:
            raise Exception("Refresh failed")

        c = self.config
        c.access_token = resp_json["access_token"]
        c.token_expiry = resp_json["expires_in"] + time.time()
        self._update_authorization_from_config()

    async def _api_post(self, url, data, auth=None) -> dict:
        """POST form data to the Tidal auth API and return the parsed JSON body.

        Fixes vs. the previous version: the body is read *inside* the context
        manager (returning the raw response released the connection before
        callers could read it), and the `requests`-style kwargs are translated
        to aiohttp: tuple `auth=` becomes BasicAuth, `verify=False` becomes
        `ssl=False`.

        :param url: endpoint URL
        :param data: form payload
        :param auth: optional (username, password) tuple
        """
        basic = aiohttp.BasicAuth(*auth) if auth is not None else None
        # NOTE(review): ssl=False disables certificate verification, matching
        # the old verify=False intent — consider enabling verification.
        async with self.session.post(url, data=data, auth=basic, ssl=False) as resp:
            return await resp.json()