#
# Copyright (c) 2025 Fedir Kovalov.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

import json
import math
import re
import sqlite3
from typing import List

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# The embedding model is heavyweight, so it is loaded lazily on first use
# (see get_embedding) rather than at import time.
model = None

# Pairwise similarity scores are memoized in a local SQLite database.
cache = sqlite3.connect("cache_similarity.db")
cursor = cache.cursor()
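# Make sure the cache table exists on a fresh database. This is a sketch:
# the (id TEXT, factor REAL) schema is inferred from the INSERT and SELECT
# statements below, not taken from an official schema definition.
cursor.execute("CREATE TABLE IF NOT EXISTS similarity (id TEXT, factor REAL)")
cache.commit()
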
def parse_database(filename):
    """Parses a newline-delimited JSON file into a list of dicts.

    Lines that fail to parse are repaired with fix_unquoted_values() and
    parsed again.
    """
    parsed_data = []

    with open(filename, "r", encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            if not line:
                continue

            try:
                parsed_data.append(json.loads(line))
            except json.JSONDecodeError:
                # Handle unquoted values by fixing the line.
                fixed_line = fix_unquoted_values(line)
                parsed_data.append(json.loads(fixed_line))

    return parsed_data

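# Usage sketch; "watch_history.ndjson" is a hypothetical filename for a
# newline-delimited JSON export with videoId/title/description fields:
#
#   history = parse_database("watch_history.ndjson")
#   print(len(history), "entries")
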
def get_embedding(text: str):
    """Returns a normalized sentence embedding, loading the model on first call."""
    global model
    if model is None:
        model = SentenceTransformer("all-MiniLM-L6-v2", backend="openvino")
    return model.encode(text, normalize_embeddings=True)

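# Expected output sketch: all-MiniLM-L6-v2 produces 384-dimensional vectors,
# and normalize_embeddings=True scales them to unit length, so cosine
# similarity between two of them reduces to a dot product.
#
#   vec = get_embedding("some title and description")
#   assert vec.shape == (384,)
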
def compute_similarity(entry_embedding, database_embeddings):
    """Returns the cosine similarity of one embedding against each database embedding."""
    similarities = cosine_similarity([entry_embedding], database_embeddings)[0]
    return similarities

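# Since get_embedding() returns unit vectors, the call above is equivalent
# (up to floating-point error) to a plain matrix-vector product:
#
#   similarities = np.asarray(database_embeddings) @ entry_embedding
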
def get_cached_similarity_reflexive(entry1, entry2, verbose=False):
    """Looks up the cached similarity for a pair of entries in both key orders.

    A pair is stored under a single "id1/id2" key, so both orderings have to
    be checked; more than one hit means the pair was cached twice.
    """
    response = []
    response.extend(get_cached_similarity(entry1, entry2, verbose=verbose))
    response.extend(get_cached_similarity(entry2, entry1, verbose=verbose))
    if len(response) > 1:
        print("WARN: duplicate pairs!")
    return response

def get_cached_similarity(entry1, entry2, verbose=False):
    """Returns the cached similarity scores stored under the "id1/id2" key."""
    key = entry1["videoId"] + "/" + entry2["videoId"]
    if verbose:
        print("INFO: querying cache for id " + key + " (get_cached_similarity)")
    # Use a parameterized query; building the SQL by string concatenation
    # would be vulnerable to injection through the video ids.
    cursor.execute("SELECT factor FROM similarity WHERE id = ?", (key,))
    response = cursor.fetchall()
    if verbose:
        print("INFO: response was: " + str(response) + " (get_cached_similarity)")
    return [float(tup[0]) for tup in response]

def set_cached_similarity(ids, similarities: List[float], verbose=False):
    """Stores (id, similarity) pairs in the cache in one batch."""
    pair_insert = [(pair_id, float(value)) for pair_id, value in zip(ids, similarities)]
    cursor.executemany(
        "INSERT INTO similarity VALUES(?, ?)",
        pair_insert,
    )
    cache.commit()

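# Cache round-trip sketch (hypothetical ids and score):
#
#   set_cached_similarity(["abc/def"], [0.42])
#   get_cached_similarity({"videoId": "abc"}, {"videoId": "def"})  # -> [0.42]
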
def get_similarity(entry1, entry2, use_cache=True):
    """Returns the similarity of two entries as a float, consulting and
    filling the cache when use_cache is set."""
    if use_cache:
        cached = get_cached_similarity_reflexive(entry1, entry2)
        if len(cached) > 0:
            return cached[0]

    entry1_embedding = get_embedding(entry1["title"] + " " + entry1["description"])
    entry2_embedding = get_embedding(entry2["title"] + " " + entry2["description"])
    similarity = compute_similarity(entry1_embedding, [entry2_embedding])

    if use_cache:
        set_cached_similarity(
            [entry1["videoId"] + "/" + entry2["videoId"]], similarity
        )

    return float(similarity[0])

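# Usage sketch (hypothetical entries): returns a scalar similarity either
# from the cache or from a fresh embedding comparison.
#
#   a = {"videoId": "abc", "title": "cats", "description": "a cat video"}
#   b = {"videoId": "def", "title": "kittens", "description": "more cats"}
#   s = get_similarity(a, b)
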
def get_global_similarity(entry, database, k=10, use_cache=True, verbose=False):
    """Returns the mean of the entry's top similarity scores against the
    other entries in the database."""
    entry_text = entry["title"] + " " + entry["description"]

    similarities: List[float] = []

    # Collect the texts of entries whose similarity is not cached yet.
    database_texts = []
    text_keys = []
    for e in database:
        if entry["videoId"] != e["videoId"]:
            cached = (
                get_cached_similarity_reflexive(entry, e, verbose=verbose)
                if use_cache
                else []
            )
            if len(cached) == 0:
                text_keys.append(entry["videoId"] + "/" + e["videoId"])
                if "description" not in e:
                    print("WARN: entry has no description: " + e["title"])
                    database_texts.append(e["title"])
                else:
                    database_texts.append(e["title"] + " " + e["description"])
            else:
                similarities.append(cached[0])

    # Embed and score the pairs that were missing from the cache in one batch.
    if len(text_keys) > 0:
        entry_embedding = get_embedding(entry_text)
        database_embeddings = np.array([get_embedding(text) for text in database_texts])
        computed_similarities: List[float] = compute_similarity(
            entry_embedding, database_embeddings
        )
        if use_cache:
            set_cached_similarity(text_keys, computed_similarities)
        similarities.extend(computed_similarities)

    # Keep the top k scores minus the single highest one; the entry itself is
    # already excluded by the videoId check above, so dropping the top score
    # also discards the closest (often near-duplicate) match.
    similarities_sorted = np.sort(similarities)[-k:-1]

    return float(np.mean(similarities_sorted))

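# Slice behaviour sketch: with k=3 and scores [0.1, 0.5, 0.7, 0.9],
# np.sort(...)[-3:-1] keeps [0.5, 0.7]; the top score 0.9 is dropped and the
# returned value is their mean, 0.6.
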
def sort_history(history):
    """Returns the same history, sorted by a quality score derived from
    watch time and watch progress (best entries first)."""
    sorted_history = history

    # Find the value ranges of timeWatched and watchProgress so both can be
    # normalized to [0, 1].
    max_time_watched = -math.inf
    min_time_watched = math.inf
    max_watch_progress = -math.inf
    min_watch_progress = math.inf
    for entry in history:
        if "timeWatched" in entry:
            if int(entry["timeWatched"]) > max_time_watched:
                max_time_watched = int(entry["timeWatched"])
            if int(entry["timeWatched"]) < min_time_watched:
                min_time_watched = int(entry["timeWatched"])
        if "watchProgress" in entry:
            if int(entry["watchProgress"]) > max_watch_progress:
                max_watch_progress = int(entry["watchProgress"])
            if int(entry["watchProgress"]) < min_watch_progress:
                min_watch_progress = int(entry["watchProgress"])

    # Guard against a zero range (all values equal), which would otherwise
    # divide by zero in quality() below.
    wp_factor = (max_watch_progress - min_watch_progress) or 1
    wp_offset = min_watch_progress

    tw_factor = (max_time_watched - min_time_watched) or 1
    tw_offset = min_time_watched

    def quality(entry):
        # q accumulates two [0, 1] components; entries missing a field get a
        # neutral 0.5 for it.
        q = 0
        if "timeWatched" in entry:
            q += (int(entry["timeWatched"]) - tw_offset) / tw_factor
        else:
            q += 0.5

        if "watchProgress" in entry:
            q += (int(entry["watchProgress"]) - wp_offset) / wp_factor
        else:
            q += 0.5

        # EXPERIMENTAL!!! WILL MAKE COMPUTER EXPLODE!!!
        # q += get_similarity(entry, history)

        # Map q from [0, 2] to [0, 1], inverted so that lower means better.
        return (2 - q) / 2

    for entry in sorted_history:
        entry["quality"] = quality(entry)

    sorted_history.sort(key=lambda x: x["quality"])

    return sorted_history

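# Worked example of the score (hypothetical numbers): with a timeWatched
# range of [0, 100] and a watchProgress range of [0, 1], an entry with
# timeWatched=100 and watchProgress=1 gets q = 1 + 1 = 2 and a quality of
# (2 - 2) / 2 = 0, which sorts it to the front.
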
def fix_unquoted_values(line):
    """Attempts to fix unquoted JSON values by adding quotes around them."""

    def replacer(match):
        key, value = match.groups()
        if not (value.startswith('"') and value.endswith('"')):
            value = f'"{value}"'  # Add quotes around the value.
        return f'"{key}":{value}'

    # \w+ only matches bare tokens, so already-quoted values are untouched.
    return re.sub(r'"(\w+)":(\w+)', replacer, line)
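
# Repair sketch for fix_unquoted_values (hypothetical line): a bare token
# value gets quoted while already-quoted fields are left alone.
#
#   fix_unquoted_values('{"videoId":dQw4w9WgXcQ,"title":"t"}')
#   # -> '{"videoId":"dQw4w9WgXcQ","title":"t"}'

# Minimal end-to-end sketch, assuming the hypothetical export file named
# above; guarded so the demo only runs when this module is executed directly:
if __name__ == "__main__":
    history = parse_database("watch_history.ndjson")
    ranked = sort_history(history)
    for item in ranked[:5]:
        print(round(item["quality"], 3), item["title"])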