# shadowtube/shadowtube/preprocess.py
# Last modified: 2025-02-18 17:46:37 +01:00
# 228 lines, 7.4 KiB, Python
#
# Copyright (c) 2025 Fedir Kovalov.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
import json
import math
from typing import List
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import sqlite3
# Lazily-initialized SentenceTransformer; loaded on first get_embedding() call.
model = None
# On-disk sqlite cache of pairwise similarity factors, shared by the
# get_cached_similarity / set_cached_similarity helpers below.
cache = sqlite3.connect("cache_similarity.db")
cursor = cache.cursor()
def parse_database(filename):
    """Read a JSON-lines file and return its records as a list of dicts.

    Blank lines are skipped. A line that fails strict JSON parsing is run
    through fix_unquoted_values() once and parsed again.
    """
    records = []
    with open(filename, "r", encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
            except json.JSONDecodeError:
                # Repair lines with unquoted values, then retry once.
                record = json.loads(fix_unquoted_values(stripped))
            records.append(record)
    return records
def get_embedding(text: str):
    """Return a unit-normalized sentence embedding for *text*.

    The SentenceTransformer encoder is loaded lazily into the module-level
    ``model`` global the first time it is needed, so importing this module
    stays cheap.
    """
    global model
    if model is None:
        # One-time load; every later call reuses the cached instance.
        model = SentenceTransformer("all-MiniLM-L6-v2", backend="openvino")
    encoder = model
    return encoder.encode(text, normalize_embeddings=True)
def compute_similarity(entry_embedding, database_embeddings):
    """Cosine similarity between one embedding and each row of a matrix.

    Parameters:
        entry_embedding: 1-D array-like query vector.
        database_embeddings: 2-D array-like, one candidate embedding per row.

    Returns:
        1-D float64 numpy array with one similarity per row.

    Implemented directly with numpy instead of wrapping the single query
    vector for sklearn's cosine_similarity — same result, less overhead.
    """
    query = np.asarray(entry_embedding, dtype=np.float64)
    matrix = np.asarray(database_embeddings, dtype=np.float64)
    norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query)
    # Zero-length vectors: dot product is 0, so forcing the norm to 1 yields
    # similarity 0, matching sklearn's handling of all-zero inputs.
    norms[norms == 0.0] = 1.0
    return matrix @ query / norms
def get_cached_similarity_reflexive(entry1, entry2, verbose=False):
    """Look up the cached similarity for a pair in both key orders.

    Returns a list of cached factors — normally empty or length one. A pair
    should be stored under only one ordering, so finding more than one hit
    triggers a warning.
    """
    hits = get_cached_similarity(entry1, entry2, verbose=verbose)
    hits += get_cached_similarity(entry2, entry1, verbose=verbose)
    if len(hits) > 1:
        print("WARN: duplicate pairs!")
    return hits
def get_cached_similarity(entry1, entry2, verbose=False):
    """Fetch cached similarity factor(s) for the ordered pair of entries.

    The cache key is "<entry1 videoId>/<entry2 videoId>". Returns a list of
    floats — empty when the pair is not cached.
    """
    key = entry1["videoId"] + "/" + entry2["videoId"]
    if verbose:
        print("INFO: querying cache for id " + key + " (get_cached_similarity)")
    # Parameterized query: videoId values come from external data, so they
    # must never be spliced into the SQL string directly.
    cursor.execute("SELECT factor FROM similarity WHERE id=?", (key,))
    rows = cursor.fetchall()
    if verbose:
        print("INFO: response was: " + str(rows) + "(get_cached_similarity)")
    return [float(row[0]) for row in rows]
def set_cached_similarity(ids, similarities: List[float], verbose=False):
    """Insert one (id, factor) row per pair key into the similarity cache.

    *ids* and *similarities* are parallel sequences; an IndexError is raised
    (as before) if there are fewer similarities than ids.
    """
    rows = [(pair_id, float(similarities[i])) for i, pair_id in enumerate(ids)]
    cursor.executemany(
        "INSERT INTO similarity VALUES(?, ?)",
        rows,
    )
    cache.commit()
def get_similarity(entry1, entry2, use_cache=True):
    """Similarity between two entries' title+description texts.

    With *use_cache* enabled, a cached value is returned when present and a
    freshly computed one is written back to the cache. Note the cached path
    returns a plain float while the computed path returns the length-1 array
    from compute_similarity — preserved from the original implementation.

    The duplicated embed-and-compare code of the original branches is
    collapsed into a single path behind an early cache return.
    """
    if use_cache:
        cached = get_cached_similarity_reflexive(entry1, entry2)
        if cached:
            return cached[0]
    # Cache miss (or cache disabled): embed both texts and compare them.
    embedding1 = get_embedding(entry1["title"] + " " + entry1["description"])
    embedding2 = get_embedding(entry2["title"] + " " + entry2["description"])
    similarity = compute_similarity(embedding1, [embedding2])
    if use_cache:
        set_cached_similarity(
            [entry1["videoId"] + "/" + entry2["videoId"]], similarity
        )
    return similarity
def get_global_similarity(entry, database, k=10, use_cache=True, verbose=False):
    """Mean similarity of *entry* to its closest neighbours in *database*.

    Compares the entry's title+description text against every other entry
    (self excluded by videoId), reading and writing the sqlite similarity
    cache, then averages a slice of the top-k sorted similarities.

    NOTE(review): the *use_cache* parameter is accepted but never read —
    cache lookups and writes always happen; confirm whether it should gate
    them like it does in get_similarity().
    """
    entry_text = entry["title"] + " " + entry["description"]
    similarities: List[float] = []
    # Get all embeddings in the database
    database_texts = []
    text_keys = []  # cache keys for pairs whose similarity still needs computing
    for e in database:
        if entry["videoId"] != e["videoId"]:
            cached = get_cached_similarity_reflexive(entry, e, verbose=verbose)
            if len(cached) == 0:
                text_keys.append(entry["videoId"] + "/" + e["videoId"])
                # Some records lack a description; fall back to the title alone.
                if "description" not in e:
                    print(e["title"])
                    database_texts.append(e["title"])
                else:
                    database_texts.append(e["title"] + " " + e["description"])
            else:
                similarities.append(cached[0])
    # Compute similarity
    # print(len(text_keys))
    if len(text_keys) > 0:
        entry_embedding = get_embedding(entry_text)
        database_embeddings = np.array([get_embedding(text) for text in database_texts])
        computed_similarities: List[float] = compute_similarity(
            entry_embedding, database_embeddings
        )
        set_cached_similarity(text_keys, computed_similarities)
        similarities.extend(computed_similarities)
    # print(similarities)
    # Exclude self-similarity
    # NOTE(review): self-pairs were already skipped by the videoId check
    # above, and [-k:-1] also drops the single highest similarity — confirm
    # this is intentional and not an off-by-one ([-k:] would keep it).
    similarities_sorted = np.sort(similarities)[-k:-1]
    # Normalize score to [0, 1]
    return float(np.mean(similarities_sorted))
def sort_history(history):
    """Sort *history* in place by a quality score and return it.

    Quality combines min-max-normalized watch time and watch progress; an
    entry missing a field contributes a neutral 0.5 for it. Lower scores
    (more-watched, further-progressed entries) sort first.

    Fix over the original: when every entry shares a single timeWatched or
    watchProgress value the min-max range is 0, which previously raised
    ZeroDivisionError — such a degenerate field now also contributes the
    neutral 0.5.
    """
    max_time_watched = -math.inf
    min_time_watched = math.inf
    max_watch_progress = -math.inf
    min_watch_progress = math.inf
    # First pass: find the min/max of each field for normalization.
    for entry in history:
        if "timeWatched" in entry:
            tw = int(entry["timeWatched"])
            max_time_watched = max(max_time_watched, tw)
            min_time_watched = min(min_time_watched, tw)
        if "watchProgress" in entry:
            wp = int(entry["watchProgress"])
            max_watch_progress = max(max_watch_progress, wp)
            min_watch_progress = min(min_watch_progress, wp)
    wp_factor = max_watch_progress - min_watch_progress
    wp_offset = min_watch_progress
    tw_factor = max_time_watched - min_time_watched
    tw_offset = min_time_watched

    def quality(entry):
        # 0 (best) .. 1 (worst); each field contributes 0..1 before folding.
        q = 0.0
        if "timeWatched" in entry and tw_factor > 0:
            q += (entry["timeWatched"] - tw_offset) / tw_factor
        else:
            # Missing field, or zero range across the history: neutral weight.
            q += 0.5
        if "watchProgress" in entry and wp_factor > 0:
            q += (entry["watchProgress"] - wp_offset) / wp_factor
        else:
            q += 0.5
        # EXPERIMENTAL!!! WILL MAKE COMPUTER EXPLODE!!!
        # q += get_similarity(entry, history)
        return (2 - q) / 2

    for entry in history:
        entry["quality"] = quality(entry)
    history.sort(key=lambda x: x["quality"])
    return history
def fix_unquoted_values(line):
    """Repair a JSON line whose bareword string values are missing quotes.

    Fix over the original: only values that are NOT already valid JSON
    scalars get quoted. The previous version quoted every bareword match,
    silently turning numbers (123 -> "123") and literals (true -> "true")
    into strings on the retry parse.
    """
    import re

    def _quote(match):
        key, value = match.groups()
        try:
            # Valid scalar (number, true/false/null, ...): leave untouched
            # so its type survives re-parsing.
            json.loads(value)
        except ValueError:
            value = f'"{value}"'  # Bareword string: add the missing quotes.
        return f'"{key}":{value}'

    return re.sub(r'"(\w+)":(\w+)', _quote, line)