Implement search for qobuz and soundcloud

This commit is contained in:
Nathan Thomas 2023-11-26 15:31:30 -08:00
parent ba05436fec
commit 3b237a0339
13 changed files with 476 additions and 60 deletions

View file

@ -30,7 +30,7 @@ class Client(ABC):
raise NotImplementedError
@abstractmethod
async def search(self, query: str, media_type: str, limit: int = 500):
async def search(self, query: str, media_type: str, limit: int = 500) -> list[dict]:
raise NotImplementedError
@abstractmethod

View file

@ -5,7 +5,7 @@ import logging
import re
import time
from collections import OrderedDict
from typing import AsyncGenerator, List, Optional
from typing import List, Optional
import aiohttp
@ -230,34 +230,36 @@ class QobuzClient(Client):
return resp
async def search(
self, query: str, media_type: str, limit: int = 500
) -> AsyncGenerator:
async def search(self, media_type: str, query: str, limit: int = 500) -> list[dict]:
if media_type not in ("artist", "album", "track", "playlist"):
raise Exception(f"{media_type} not available for search on qobuz")
params = {
"query": query,
# "limit": limit,
}
# TODO: move featured, favorites, and playlists into _api_get later
if media_type == "featured":
assert query in QOBUZ_FEATURED_KEYS, f'query "{query}" is invalid.'
params.update({"type": query})
del params["query"]
epoint = "album/getFeatured"
epoint = f"{media_type}/search"
elif query == "user-favorites":
assert query in ("track", "artist", "album")
params.update({"type": f"{media_type}s"})
epoint = "favorite/getUserFavorites"
return await self._paginate(epoint, params, limit=limit)
elif query == "user-playlists":
epoint = "playlist/getUserPlaylists"
async def get_featured(self, query, limit: int = 500) -> list[dict]:
params = {
"type": query,
}
assert query in QOBUZ_FEATURED_KEYS, f'query "{query}" is invalid.'
epoint = "album/getFeatured"
return await self._paginate(epoint, params, limit=limit)
else:
epoint = f"{media_type}/search"
async def get_user_favorites(self, media_type: str, limit: int = 500) -> list[dict]:
assert media_type in ("track", "artist", "album")
params = {"type": f"{media_type}s"}
epoint = "favorite/getUserFavorites"
async for status, resp in self._paginate(epoint, params, limit=limit):
assert status == 200
yield resp
return await self._paginate(epoint, params, limit=limit)
async def get_user_playlists(self, limit: int = 500) -> list[dict]:
epoint = "playlist/getUserPlaylists"
return await self._paginate(epoint, {}, limit=limit)
async def get_downloadable(self, item_id: str, quality: int) -> Downloadable:
assert self.secret is not None and self.logged_in and 1 <= quality <= 4
@ -281,7 +283,7 @@ class QobuzClient(Client):
async def _paginate(
self, epoint: str, params: dict, limit: Optional[int] = None
) -> AsyncGenerator[tuple[int, dict], None]:
) -> list[dict]:
"""Paginate search results.
params:
@ -293,30 +295,41 @@ class QobuzClient(Client):
"""
params.update({"limit": limit or 500})
status, page = await self._api_request(epoint, params)
assert status == 200, status
logger.debug("paginate: initial request made with status %d", status)
# albums, tracks, etc.
key = epoint.split("/")[0] + "s"
items = page.get(key, {})
total = items.get("total", 0) or items.get("items", 0)
total = items.get("total", 0)
if limit is not None and limit < total:
total = limit
logger.debug("paginate: %d total items requested", total)
if not total:
if total == 0:
logger.debug("Nothing found from %s epoint", epoint)
return
return []
limit = int(page.get(key, {}).get("limit", 500))
offset = int(page.get(key, {}).get("offset", 0))
logger.debug("paginate: from response: limit=%d, offset=%d", limit, offset)
params.update({"limit": limit})
yield status, page
pages = []
requests = []
assert status == 200, status
pages.append(page)
while (offset + limit) < total:
offset += limit
params.update({"offset": offset})
yield await self._api_request(epoint, params)
requests.append(self._api_request(epoint, params.copy()))
for status, resp in await asyncio.gather(*requests):
assert status == 200
pages.append(resp)
return pages
async def _get_app_id_and_secrets(self) -> tuple[str, list[str]]:
async with QobuzSpoofer() as spoofer:

View file

@ -170,8 +170,10 @@ class SoundcloudClient(Client):
)
async def search(
self, query: str, media_type: str, limit: int = 50, offset: int = 0
):
self, media_type: str, query: str, limit: int = 50, offset: int = 0
) -> list[dict]:
# TODO: implement pagination
assert media_type in ("track", "playlist")
params = {
"q": query,
"facet": "genre",
@ -182,7 +184,7 @@ class SoundcloudClient(Client):
}
resp, status = await self._api_request(f"search/{media_type}s", params=params)
assert status == 200
return resp
return [resp]
async def _api_request(self, path, params=None, headers=None):
url = f"{BASE}/{path}"

View file

@ -343,7 +343,7 @@ def update_toml_section_from_config(toml_section, config):
class Config:
def __init__(self, path: str, /):
self._path = path
self.path = path
with open(path) as toml_file:
self.file: ConfigData = ConfigData.from_toml(toml_file.read())
@ -354,7 +354,7 @@ class Config:
if not self.file.modified:
return
with open(self._path, "w") as toml_file:
with open(self.path, "w") as toml_file:
self.file.update_toml()
toml_file.write(dumps(self.file.toml))

View file

@ -178,6 +178,8 @@ fallback_source = "deezer"
text_output = true
# Show resolve, download progress bars
progress_bars = true
# The maximum number of search results to show in the interactive menu
max_search_results = 100
[misc]
# Metadata to identify this config file. Do not change.

View file

@ -6,6 +6,8 @@ INF = 9999
class UnlimitedSemaphore:
"""Can be swapped out for a real semaphore when no semaphore is needed."""
async def __aenter__(self):
return self
@ -20,6 +22,15 @@ _global_semaphore: None | tuple[int, asyncio.Semaphore] = None
def global_download_semaphore(
c: DownloadsConfig,
) -> UnlimitedSemaphore | asyncio.Semaphore:
"""A global semaphore that limit the number of total tracks being downloaded
at once.
If concurrency is disabled in the config, the semaphore is set to 1.
Otherwise it's set to `max_connections`.
A negative `max_connections` value means there is no maximum and no semaphore is used.
Since it is global, only one value of `max_connections` is allowed per session.
"""
global _unlimited, _global_semaphore
if c.concurrency:

View file

@ -9,7 +9,7 @@ from ..config import Config
from ..db import Database
from ..filepath_utils import clean_filename
from ..metadata import AlbumMetadata, Covers, TrackMetadata, tag_file
from ..progress import get_progress_callback
from ..progress import add_title, get_progress_callback, remove_title
from .artwork import download_artwork
from .media import Media, Pending
from .semaphore import global_download_semaphore
@ -28,10 +28,13 @@ class Track(Media):
db: Database
# change?
download_path: str = ""
is_single: bool = False
async def preprocess(self):
self._set_download_path()
os.makedirs(self.folder, exist_ok=True)
if self.is_single:
add_title(self.meta.title)
async def download(self):
# TODO: progress bar description
@ -44,6 +47,9 @@ class Track(Media):
await self.downloadable.download(self.download_path, callback)
async def postprocess(self):
if self.is_single:
remove_title(self.meta.title)
await tag_file(self.download_path, self.meta, self.cover_path)
if self.config.session.conversion.enabled:
await self._convert()
@ -146,7 +152,13 @@ class PendingSingle(Pending):
self.client.get_downloadable(self.id, quality),
)
return Track(
meta, downloadable, self.config, folder, embedded_cover_path, self.db
meta,
downloadable,
self.config,
folder,
embedded_cover_path,
self.db,
is_single=True,
)
def _format_folder(self, meta: AlbumMetadata) -> str:

View file

@ -5,6 +5,15 @@ from .artist_metadata import ArtistMetadata
from .covers import Covers
from .label_metadata import LabelMetadata
from .playlist_metadata import PlaylistMetadata
from .search_results import (
AlbumSummary,
ArtistSummary,
LabelSummary,
PlaylistSummary,
SearchResults,
Summary,
TrackSummary,
)
from .tagger import tag_file
from .track_metadata import TrackMetadata
@ -17,4 +26,11 @@ __all__ = [
"Covers",
"tag_file",
"util",
"AlbumSummary",
"ArtistSummary",
"LabelSummary",
"PlaylistSummary",
"Summary",
"TrackSummary",
"SearchResults",
]

View file

@ -0,0 +1,256 @@
import os
import re
import textwrap
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pprint import pprint
class Summary(ABC):
id: str
@abstractmethod
def summarize(self) -> str:
pass
@abstractmethod
def preview(self) -> str:
pass
@classmethod
@abstractmethod
def from_item(cls, item: dict) -> str:
pass
@abstractmethod
def media_type(self) -> str:
pass
def __str__(self):
return self.summarize()
@dataclass(slots=True)
class ArtistSummary(Summary):
id: str
name: str
num_albums: str
def media_type(self):
return "artist"
def summarize(self) -> str:
return self.name
def preview(self) -> str:
return f"{self.num_albums} Albums\n\nID: {self.id}"
@classmethod
def from_item(cls, item: dict):
id = item["id"]
name = (
item.get("name")
or item.get("performer", {}).get("name")
or item.get("artist")
or item.get("artist", {}).get("name")
or (
item.get("publisher_metadata")
and item["publisher_metadata"].get("artist")
)
or "Unknown"
)
num_albums = item.get("albums_count") or "Unknown"
return cls(id, name, num_albums)
@dataclass(slots=True)
class TrackSummary(Summary):
id: str
name: str
artist: str
date_released: str | None
def media_type(self):
return "track"
def summarize(self) -> str:
return f"{self.name} by {self.artist}"
def preview(self) -> str:
return f"Released on:\n{self.date_released}\n\nID: {self.id}"
@classmethod
def from_item(cls, item: dict):
id = item["id"]
name = item.get("title") or item.get("name") or "Unknown"
artist = (
item.get("performer", {}).get("name")
or item.get("artist")
or item.get("artist", {}).get("name")
or (
item.get("publisher_metadata")
and item["publisher_metadata"].get("artist")
)
or "Unknown"
)
date_released = (
item.get("release_date")
or item.get("album", {}).get("release_date_original")
or item.get("display_date")
or item.get("date")
or item.get("year")
or "Unknown"
)
return cls(id, name.strip(), artist, date_released)
@dataclass(slots=True)
class AlbumSummary(Summary):
id: str
name: str
artist: str
num_tracks: str
date_released: str | None
def media_type(self):
return "album"
def summarize(self) -> str:
return f"{self.name} by {self.artist}"
def preview(self) -> str:
return f"Date released:\n{self.date_released}\n\n{self.num_tracks} Tracks\n\nID: {self.id}"
@classmethod
def from_item(cls, item: dict):
id = item["id"]
name = item.get("title") or "Unknown Title"
artist = (
item.get("performer", {}).get("name")
or item.get("artist", {}).get("name")
or item.get("artist")
or (
item.get("publisher_metadata")
and item["publisher_metadata"].get("artist")
)
or "Unknown"
)
num_tracks = item.get("tracks_count", 0) or len(
item.get("tracks", []) or item.get("items", [])
)
date_released = (
item.get("release_date_original")
or item.get("release_date")
or item.get("display_date")
or item.get("date")
or item.get("year")
or "Unknown"
)
# raise Exception(item)
return cls(id, name, artist, str(num_tracks), date_released)
@dataclass(slots=True)
class LabelSummary(Summary):
id: str
name: str
def media_type(self):
return "label"
def summarize(self) -> str:
return str(self)
def preview(self) -> str:
return str(self)
@classmethod
def from_item(cls, item: dict):
id = item["id"]
name = item["name"]
return cls(id, name)
@dataclass(slots=True)
class PlaylistSummary(Summary):
id: str
name: str
creator: str
num_tracks: int
description: str
def summarize(self) -> str:
return f"{self.name} by {self.creator}"
def preview(self) -> str:
wrapped = "\n".join(
textwrap.wrap(self.description, os.get_terminal_size().columns - 4 or 70)
)
return f"{self.num_tracks} tracks\n\nDescription:\n{wrapped}\n\nid:{self.id}"
def media_type(self):
return "playlist"
@classmethod
def from_item(cls, item: dict):
id = item["id"]
name = item.get("name") or item.get("title") or "Unknown"
creator = (
(item.get("publisher_metadata") and item["publisher_metadata"]["artist"])
or item.get("owner", {}).get("name")
or item.get("user", {}).get("username")
or "Unknown"
)
num_tracks = item.get("tracks_count") or -1
description = item.get("description") or "No description"
return cls(id, name, creator, num_tracks, description)
@dataclass(slots=True)
class SearchResults:
results: list[Summary]
@classmethod
def from_pages(cls, source: str, media_type: str, pages: list[dict]):
if media_type == "track":
summary_type = TrackSummary
elif media_type == "album":
summary_type = AlbumSummary
elif media_type == "label":
summary_type = LabelSummary
elif media_type == "artist":
summary_type = ArtistSummary
elif media_type == "playlist":
summary_type = PlaylistSummary
else:
raise Exception(f"invalid media type {media_type}")
results = []
for page in pages:
if source == "soundcloud":
items = page["collection"]
for item in items:
results.append(summary_type.from_item(item))
elif source == "qobuz":
key = media_type + "s"
for item in page[key]["items"]:
results.append(summary_type.from_item(item))
else:
raise NotImplementedError
return cls(results)
def summaries(self) -> list[str]:
return [f"{i+1}. {r.summarize()}" for i, r in enumerate(self.results)]
def get_choices(self, inds: tuple[int, ...] | int):
if isinstance(inds, int):
inds = (inds,)
return [self.results[i] for i in inds]
def preview(self, s: str) -> str:
ind = re.match(r"^\d+", s)
assert ind is not None
i = int(ind.group(0))
return self.results[i - 1].preview()

View file

@ -16,9 +16,8 @@ class ProgressManager:
self.progress = Progress(console=console)
self.task_titles = []
self.prefix = Text.assemble(("Downloading ", "bold cyan"), overflow="ellipsis")
self.live = Live(
Group(self.get_title_text(), self.progress), refresh_per_second=10
)
self._text_cache = self.gen_title_text()
self.live = Live(Group(self._text_cache, self.progress), refresh_per_second=10)
def get_callback(self, total: int, desc: str):
if not self.started:
@ -42,17 +41,22 @@ class ProgressManager:
def add_title(self, title: str):
self.task_titles.append(title.strip())
self._text_cache = self.gen_title_text()
def remove_title(self, title: str):
self.task_titles.remove(title.strip())
self._text_cache = self.gen_title_text()
def get_title_text(self) -> Rule:
def gen_title_text(self) -> Rule:
titles = ", ".join(self.task_titles[:3])
if len(self.task_titles) > 3:
titles += "..."
t = self.prefix + Text(titles)
return Rule(t)
def get_title_text(self) -> Rule:
return self._text_cache
@dataclass(slots=True)
class Handle:

View file

@ -116,10 +116,10 @@ def rip(ctx, config_path, folder, no_db, quality, convert, no_progress, verbose)
async def url(ctx, urls):
"""Download content from URLs."""
with ctx.obj["config"] as cfg:
main = Main(cfg)
await main.add_all(urls)
await main.resolve()
await main.rip()
async with Main(cfg) as main:
await main.add_all(urls)
await main.resolve()
await main.rip()
@rip.command()
@ -134,11 +134,11 @@ async def file(ctx, path):
rip file urls.txt
"""
with ctx.obj["config"] as cfg:
main = Main(cfg)
with open(path) as f:
await main.add_all([line for line in f])
await main.resolve()
await main.rip()
async with Main(cfg) as main:
with open(path) as f:
await main.add_all([line for line in f])
await main.resolve()
await main.rip()
@rip.group()
@ -152,7 +152,7 @@ def config():
@click.pass_context
def config_open(ctx, vim):
"""Open the config file in a text editor."""
config_path = ctx.obj["config_path"]
config_path = ctx.obj["config"].path
console.log(f"Opening file at [bold cyan]{config_path}")
if vim:
if shutil.which("nvim") is not None:
@ -168,7 +168,7 @@ def config_open(ctx, vim):
@click.pass_context
def config_reset(ctx, yes):
"""Reset the config file."""
config_path = ctx.obj["config_path"]
config_path = ctx.obj["config"].path
if not yes:
if not Confirm.ask(
f"Are you sure you want to reset the config file at {config_path}?"
@ -181,15 +181,33 @@ def config_reset(ctx, yes):
@rip.command()
@click.argument("query", required=True)
@click.option(
"-f",
"--first",
help="Automatically download the first search result without showing the menu.",
is_flag=True,
)
@click.argument("source", required=True)
@click.argument("media-type", required=True)
@click.argument("query", required=True)
@click.pass_context
@coro
async def search(query, source):
async def search(ctx, first, source, media_type, query):
"""
Search for content using a specific source.
Example:
rip search qobuz album 'rumours'
"""
raise NotImplementedError
with ctx.obj["config"] as cfg:
async with Main(cfg) as main:
if first:
await main.search_take_first(source, media_type, query)
else:
await main.search_interactive(source, media_type, query)
await main.resolve()
await main.rip()
@rip.command()

View file

@ -1,11 +1,13 @@
import asyncio
import logging
import os
from .. import db
from ..client import Client, QobuzClient, SoundcloudClient
from ..config import Config
from ..console import console
from ..media import Media, Pending, remove_artwork_tempdirs
from ..metadata import SearchResults
from ..progress import clear_progress
from .parse_url import parse_url
from .prompter import get_prompter
@ -67,12 +69,17 @@ class Main:
logger.debug("Added url=%s", url)
async def add_all(self, urls: list[str]):
"""Add multiple urls concurrently as pending items."""
parsed = [parse_url(url) for url in urls]
url_w_client = [
(p, await self.get_logged_in_client(p.source))
for p in parsed
if p is not None
]
url_w_client = []
for i, p in enumerate(parsed):
if p is None:
console.print(
f"[red]Found invalid url [cyan]{urls[i]}[/cyan], skipping."
)
continue
url_w_client.append((p, await self.get_logged_in_client(p.source)))
pendings = await asyncio.gather(
*[
url.into_pending(client, self.config, self.database)
@ -100,6 +107,7 @@ class Main:
return client
async def resolve(self):
"""Resolve all currently pending items."""
with console.status("Resolving URLs...", spinner="dots"):
coros = [p.resolve() for p in self.pending]
new_media: list[Media] = await asyncio.gather(*coros)
@ -108,7 +116,81 @@ class Main:
self.pending.clear()
async def rip(self):
"""Download all resolved items."""
await asyncio.gather(*[item.rip() for item in self.media])
async def search_interactive(self, source: str, media_type: str, query: str):
client = await self.get_logged_in_client(source)
with console.status(f"[bold]Searching {source}", spinner="dots"):
pages = await client.search(media_type, query, limit=100)
if len(pages) == 0:
console.print(f"[red]No search results found for query {query}")
return
search_results = SearchResults.from_pages(source, media_type, pages)
if os.name == "nt" or True:
from pick import pick
choices = pick(
search_results.results,
title=(
f"{source.capitalize()} {media_type} search.\n"
"Press SPACE to select, RETURN to download, CTRL-C to exit."
),
multiselect=True,
min_selection_count=1,
)
assert isinstance(choices, list)
await self.add_all(
[f"http://{source}.com/{media_type}/{item.id}" for item, i in choices]
)
else:
from simple_term_menu import TerminalMenu
menu = TerminalMenu(
search_results.summaries(),
preview_command=search_results.preview,
preview_size=0.5,
title=(
f"Results for {media_type} '{query}' from {source.capitalize()}\n"
"SPACE - select, ENTER - download, ESC - exit"
),
cycle_cursor=True,
clear_screen=True,
multi_select=True,
)
chosen_ind = menu.show()
if chosen_ind is None:
console.print("[yellow]No items chosen. Exiting.")
else:
choices = search_results.get_choices(chosen_ind)
await self.add_all(
[
f"http://{source}.com/{item.media_type()}/{item.id}"
for item in choices
]
)
async def search_take_first(self, source: str, media_type: str, query: str):
client = await self.get_logged_in_client(source)
pages = await client.search(media_type, query, limit=1)
if len(pages) == 0:
console.print(f"[red]No search results found for query {query}")
return
search_results = SearchResults.from_pages(source, media_type, pages)
assert len(search_results.results) > 0
first = search_results.results[0]
await self.add(f"http://{source}.com/{first.media_type()}/{first.id}")
async def __aenter__(self):
return self
async def __aexit__(self, *_):
# Ensure all client sessions are closed
for client in self.clients.values():
if hasattr(client, "session"):
await client.session.close()

View file

@ -64,7 +64,7 @@ class QobuzPrompter(CredentialPrompter):
secho("Enter Qobuz password (will not show on screen): ", fg="green", nl=False)
pwd = hashlib.md5(getpass(prompt="").encode("utf-8")).hexdigest()
secho(
f'Credentials saved to config file at "{self.config._path}"',
f'Credentials saved to config file at "{self.config.path}"',
fg="green",
)
c = self.config.session.qobuz
@ -183,7 +183,7 @@ class DeezerPrompter(CredentialPrompter):
cf.arl = c.arl
self.config.file.set_modified()
secho(
f'Credentials saved to config file at "{self.config._path}"',
f'Credentials saved to config file at "{self.config.path}"',
fg="green",
)