Embeddings

AI/ML core v1.0.0

Overview

Embeddings convert text into dense vector representations that capture semantic meaning. This skill covers embedding models, vector storage, similarity search, and optimization techniques.


Key Concepts

Embedding Pipeline

┌─────────────────────────────────────────────────────────────┐
│                   Embedding Pipeline                         │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  Text Input:                                                │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  "Machine learning is a subset of AI"               │   │
│  └─────────────────────────────────────────────────────┘   │
│                          │                                   │
│                          ▼                                   │
│  Embedding Model:                                           │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  ┌─────────┐    ┌─────────┐    ┌─────────┐         │   │
│  │  │ OpenAI  │    │ Cohere  │    │ Local   │         │   │
│  │  │ text-   │    │ embed   │    │ SBERT   │         │   │
│  │  │ embed-3 │    │ v3      │    │         │         │   │
│  │  └─────────┘    └─────────┘    └─────────┘         │   │
│  └─────────────────────────────────────────────────────┘   │
│                          │                                   │
│                          ▼                                   │
│  Vector Output:                                             │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  [0.023, -0.041, 0.087, ..., 0.012]                │   │
│  │  Dimensions: 256 - 3072                             │   │
│  └─────────────────────────────────────────────────────┘   │
│                          │                                   │
│                          ▼                                   │
│  Vector Store:                                              │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  ┌─────────┐  ┌─────────┐  ┌─────────┐             │   │
│  │  │Pinecone │  │ Weaviate│  │pgvector │             │   │
│  │  │         │  │         │  │         │             │   │
│  │  └─────────┘  └─────────┘  └─────────┘             │   │
│  │                                                      │   │
│  │  Index Types: HNSW, IVF, Flat                       │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                              │
│  Similarity Metrics:                                        │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  • Cosine Similarity: cos(θ) = A·B / (|A||B|)      │   │
│  │  • Dot Product: A·B                                 │   │
│  │  • Euclidean Distance: √Σ(ai-bi)²                  │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                              │
└─────────────────────────────────────────────────────────────┘
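The three metrics above can be checked by hand on small vectors. A quick sketch (values are illustrative) showing that for unit-length vectors, cosine similarity and dot product coincide:

```java
// Worked example of the three similarity metrics on 2-D vectors.
public class MetricsDemo {

    static double dot(float[] a, float[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) s += a[i] * b[i];
        return s;
    }

    static double norm(float[] v) {
        return Math.sqrt(dot(v, v));
    }

    static double cosine(float[] a, float[] b) {
        return dot(a, b) / (norm(a) * norm(b));
    }

    static double euclidean(float[] a, float[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            s += d * d;
        }
        return Math.sqrt(s);
    }

    public static void main(String[] args) {
        float[] a = {1.0f, 0.0f};
        float[] b = {0.6f, 0.8f};  // unit length: 0.36 + 0.64 = 1

        // dot and cosine agree (0.600) because both inputs are unit vectors;
        // euclidean distance is sqrt(0.8) ≈ 0.894
        System.out.printf("dot=%.3f cos=%.3f dist=%.3f%n",
            dot(a, b), cosine(a, b), euclidean(a, b));
    }
}
```

This is why vector stores that use dot-product indexes expect normalized embeddings: on unit vectors, ranking by dot product is identical to ranking by cosine.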

Best Practices

1. Choose Appropriate Model

Match the embedding model to the use case: quality, latency, cost, and vector dimensionality all trade off.

2. Normalize Vectors

Ensure consistent similarity calculations.

3. Batch Embedding Requests

Reduce API calls and latency.

4. Cache Embeddings

Avoid recomputing for unchanged content.

5. Monitor Drift

Track embedding distribution changes over time.
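For practice 5, one lightweight drift signal is the cosine distance between the centroid of a baseline embedding sample and the centroid of recent embeddings. A minimal sketch (the alert threshold and sampling strategy are assumptions; it also assumes non-empty samples of equal dimension):

```java
// Sketch of centroid-based embedding drift monitoring.
public class DriftMonitor {

    // Element-wise average of a non-empty sample of vectors.
    static float[] centroid(float[][] vectors) {
        float[] c = new float[vectors[0].length];
        for (float[] v : vectors)
            for (int i = 0; i < c.length; i++) c[i] += v[i];
        for (int i = 0; i < c.length; i++) c[i] /= vectors.length;
        return c;
    }

    static double cosine(float[] a, float[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    // Drift score in [0, 2]: 0 means the centroids coincide.
    // Alert when it exceeds a tuned threshold (e.g. 0.05 is a plausible start).
    static double driftScore(float[][] baseline, float[][] recent) {
        return 1.0 - cosine(centroid(baseline), centroid(recent));
    }
}
```

A score that creeps up over time can indicate a content shift in the corpus or a silent embedding-model change, both of which degrade retrieval quality.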


Code Examples

Example 1: Embedding Service

public interface EmbeddingService {
    float[] embed(String text);
    List<float[]> embedBatch(List<String> texts);
    int getDimensions();
}

@Service
public class OpenAiEmbeddingService implements EmbeddingService {
    
    private final WebClient webClient;
    private final String model = "text-embedding-3-small";
    private final int dimensions = 1536;
    private final int maxBatchSize = 100;
    
    public OpenAiEmbeddingService(WebClient openAiWebClient) {
        this.webClient = openAiWebClient;
    }
    
    @Override
    public float[] embed(String text) {
        return embedBatch(List.of(text)).get(0);
    }
    
    @Override
    public List<float[]> embedBatch(List<String> texts) {
        if (texts.isEmpty()) {
            return List.of();
        }
        
        List<float[]> allEmbeddings = new ArrayList<>();
        
        // Split into batches of at most maxBatchSize texts
        for (List<String> batch : partition(texts, maxBatchSize)) {
            EmbeddingRequest request = new EmbeddingRequest(model, batch);
            
            EmbeddingResponse response = webClient.post()
                .uri("/v1/embeddings")
                .bodyValue(request)
                .retrieve()
                .bodyToMono(EmbeddingResponse.class)
                .block();
            
            // Sort by index: the API does not guarantee response order
            List<float[]> batchEmbeddings = response.getData().stream()
                .sorted(Comparator.comparingInt(EmbeddingData::getIndex))
                .map(EmbeddingData::getEmbedding)
                .toList();
            
            allEmbeddings.addAll(batchEmbeddings);
        }
        
        return allEmbeddings;
    }
    
    // Split a list into consecutive sublists of at most `size` elements
    private static <T> List<List<T>> partition(List<T> list, int size) {
        List<List<T>> batches = new ArrayList<>();
        for (int i = 0; i < list.size(); i += size) {
            batches.add(list.subList(i, Math.min(i + size, list.size())));
        }
        return batches;
    }
    
    @Override
    public int getDimensions() {
        return dimensions;
    }
}
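Embedding APIs rate-limit aggressively, so production callers usually wrap the HTTP call in retries. A minimal exponential-backoff helper (the attempt count and base delay are illustrative; a real client would retry only on 429/5xx responses and add jitter):

```java
import java.util.function.Supplier;

// Generic retry wrapper for transient failures such as rate limits.
public class RetryingEmbedder {

    static <T> T withRetry(Supplier<T> call, int maxAttempts, long baseDelayMs)
            throws InterruptedException {
        RuntimeException last = null;
        for (int attempt = 0; attempt < maxAttempts; attempt++) {
            try {
                return call.get();
            } catch (RuntimeException e) {
                last = e;
                // Exponential backoff: baseDelayMs, 2x, 4x, ...
                Thread.sleep(baseDelayMs * (1L << attempt));
            }
        }
        throw last;
    }
}
```

The per-batch request above could be wrapped as `withRetry(() -> postEmbeddings(batch), 3, 500)` (`postEmbeddings` is a hypothetical helper name), so a single transient 429 does not fail a whole indexing run.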

// Local embedding with Sentence Transformers (ONNX export of all-MiniLM-L6-v2)
@Service
public class LocalEmbeddingService implements EmbeddingService {
    
    private final OrtEnvironment env;
    private final OrtSession session;
    private final Tokenizer tokenizer;
    private final int dimensions = 384;
    
    public LocalEmbeddingService() throws Exception {
        // Load ONNX model
        this.env = OrtEnvironment.getEnvironment();
        this.session = env.createSession("models/all-MiniLM-L6-v2.onnx");
        this.tokenizer = new BertTokenizer("models/tokenizer");
    }
    
    @Override
    public float[] embed(String text) {
        try {
            // Tokenize, truncating to the model's 512-token limit
            TokenizerResult tokens = tokenizer.encode(text, 512);
            
            // Create input tensors
            Map<String, OnnxTensor> inputs = Map.of(
                "input_ids", OnnxTensor.createTensor(env, tokens.getInputIds()),
                "attention_mask", OnnxTensor.createTensor(env, tokens.getAttentionMask())
            );
            
            // Run inference
            try (OrtSession.Result result = session.run(inputs)) {
                // Mean pooling (assumes the export emits [tokens][hidden] for one input)
                float[][] lastHidden = (float[][]) result.get(0).getValue();
                return meanPool(lastHidden, tokens.getAttentionMask());
            }
        } catch (OrtException e) {
            throw new IllegalStateException("Embedding inference failed", e);
        }
    }
    
    @Override
    public List<float[]> embedBatch(List<String> texts) {
        return texts.stream().map(this::embed).toList();
    }
    
    @Override
    public int getDimensions() {
        return dimensions;
    }
    
    private float[] meanPool(float[][] hidden, long[] mask) {
        float[] result = new float[dimensions];
        int count = 0;
        
        // Average only the positions the attention mask marks as real tokens
        for (int i = 0; i < hidden.length; i++) {
            if (mask[i] == 1) {
                for (int j = 0; j < dimensions; j++) {
                    result[j] += hidden[i][j];
                }
                count++;
            }
        }
        
        for (int j = 0; j < dimensions; j++) {
            result[j] /= count;
        }
        
        return normalize(result);
    }
    
    // Scale to unit length so cosine similarity equals dot product
    private float[] normalize(float[] vector) {
        double norm = 0.0;
        for (float v : vector) {
            norm += v * v;
        }
        norm = Math.sqrt(norm);
        for (int i = 0; i < vector.length; i++) {
            vector[i] = (float) (vector[i] / norm);
        }
        return vector;
    }
}

Example 2: Vector Store Integration

public interface VectorStore {
    void insert(String id, float[] embedding, Map<String, Object> metadata);
    void insertBatch(List<VectorRecord> records);
    List<VectorMatch> search(float[] query, int topK, Map<String, Object> filter);
    VectorRecord getById(String id);
    void delete(String id);
    void deleteByMetadata(Map<String, Object> filter);
}

@Service
public class PgVectorStore implements VectorStore {
    
    private final JdbcTemplate jdbcTemplate;
    private final ObjectMapper objectMapper = new ObjectMapper();
    
    public PgVectorStore(JdbcTemplate jdbcTemplate) {
        this.jdbcTemplate = jdbcTemplate;
    }
    
    @Override
    public void insert(String id, float[] embedding, Map<String, Object> metadata) {
        jdbcTemplate.update("""
            INSERT INTO embeddings (id, embedding, metadata, created_at)
            VALUES (?, ?::vector, ?::jsonb, NOW())
            ON CONFLICT (id) DO UPDATE SET
                embedding = EXCLUDED.embedding,
                metadata = EXCLUDED.metadata,
                updated_at = NOW()
            """, id, toVectorString(embedding), toJson(metadata));
    }
    
    @Override
    public void insertBatch(List<VectorRecord> records) {
        jdbcTemplate.batchUpdate("""
            INSERT INTO embeddings (id, embedding, metadata)
            VALUES (?, ?::vector, ?::jsonb)
            ON CONFLICT (id) DO UPDATE SET
                embedding = EXCLUDED.embedding,
                metadata = EXCLUDED.metadata
            """,
            records,
            100,
            (ps, record) -> {
                ps.setString(1, record.getId());
                ps.setString(2, toVectorString(record.getEmbedding()));
                ps.setString(3, toJson(record.getMetadata()));
            }
        );
    }
    
    @Override
    public List<VectorMatch> search(float[] query, int topK, Map<String, Object> filter) {
        String vector = toVectorString(query);
        
        // <=> is pgvector's cosine distance operator, so 1 - distance = similarity
        StringBuilder sql = new StringBuilder("""
            SELECT id, metadata, 1 - (embedding <=> ?::vector) AS similarity
            FROM embeddings
            WHERE 1=1
            """);
        
        List<Object> params = new ArrayList<>();
        params.add(vector);
        
        // Add metadata filters; keys are interpolated into the SQL, so they
        // must come from a trusted whitelist, never from user input
        if (filter != null && !filter.isEmpty()) {
            for (Map.Entry<String, Object> entry : filter.entrySet()) {
                sql.append(" AND metadata->>'").append(entry.getKey()).append("' = ?");
                params.add(entry.getValue().toString());
            }
        }
        
        sql.append(" ORDER BY embedding <=> ?::vector LIMIT ?");
        params.add(vector);
        params.add(topK);
        
        return jdbcTemplate.query(sql.toString(), params.toArray(), (rs, rowNum) ->
            new VectorMatch(
                rs.getString("id"),
                rs.getDouble("similarity"),
                parseMetadata(rs.getString("metadata"))
            )
        );
    }
    
    private String toJson(Map<String, Object> metadata) {
        try {
            return objectMapper.writeValueAsString(metadata);
        } catch (JsonProcessingException e) {
            throw new IllegalArgumentException("Metadata is not serializable", e);
        }
    }
    
    private Map<String, Object> parseMetadata(String json) {
        try {
            return objectMapper.readValue(json, new TypeReference<Map<String, Object>>() {});
        } catch (JsonProcessingException e) {
            return Map.of();
        }
    }
    
    // Format as a pgvector literal: "[0.1,0.2,...]".
    // Note: Arrays.stream has no float[] overload, so index manually.
    private String toVectorString(float[] embedding) {
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < embedding.length; i++) {
            if (i > 0) sb.append(',');
            sb.append(embedding[i]);
        }
        return sb.append(']').toString();
    }
}
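The `?::vector` casts and `<=>` operator above assume a pgvector-enabled table. A plausible schema and index to pair with them (table and column names mirror the queries above; `vector(1536)` matches text-embedding-3-small):

```sql
CREATE EXTENSION IF NOT EXISTS vector;

CREATE TABLE IF NOT EXISTS embeddings (
    id         TEXT PRIMARY KEY,
    embedding  vector(1536),
    metadata   JSONB,
    created_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ
);

-- HNSW index on cosine distance, matching the <=> operator used in search()
CREATE INDEX IF NOT EXISTS embeddings_hnsw_idx
    ON embeddings USING hnsw (embedding vector_cosine_ops);
```

Without the index, `ORDER BY embedding <=> ...` falls back to an exact sequential scan, which is correct but slow at scale.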

// Pinecone implementation
@Service
public class PineconeVectorStore implements VectorStore {
    
    private final PineconeClient pinecone;
    private final String indexName;
    
    @Override
    public void insertBatch(List<VectorRecord> records) {
        List<Vector> vectors = records.stream()
            .map(r -> Vector.newBuilder()
                .setId(r.getId())
                .addAllValues(toFloatList(r.getEmbedding()))
                .setMetadata(Struct.newBuilder()
                    .putAllFields(toProtobufFields(r.getMetadata()))
                    .build())
                .build())
            .toList();
        
        pinecone.index(indexName).upsert(UpsertRequest.newBuilder()
            .addAllVectors(vectors)
            .setNamespace("default")
            .build());
    }
    
    @Override
    public List<VectorMatch> search(float[] query, int topK, Map<String, Object> filter) {
        QueryRequest.Builder requestBuilder = QueryRequest.newBuilder()
            .addAllVector(toFloatList(query))
            .setTopK(topK)
            .setIncludeMetadata(true)
            .setNamespace("default");
        
        if (filter != null && !filter.isEmpty()) {
            requestBuilder.setFilter(buildFilter(filter));
        }
        
        QueryResponse response = pinecone.index(indexName).query(requestBuilder.build());
        
        return response.getMatchesList().stream()
            .map(match -> new VectorMatch(
                match.getId(),
                match.getScore(),
                fromProtobufFields(match.getMetadata().getFieldsMap())
            ))
            .toList();
    }
}

Example 3: Similarity Functions

@Component
public class SimilarityCalculator {
    
    /**
     * Cosine similarity: measures angle between vectors
     * Range: -1 to 1 (1 = identical, 0 = orthogonal, -1 = opposite)
     */
    public double cosineSimilarity(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors must have same dimension");
        }
        
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        
        for (int i = 0; i < a.length; i++) {
            dotProduct += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        
        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
    }
    
    /**
     * Dot product: faster but requires normalized vectors
     */
    public double dotProduct(float[] a, float[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }
    
    /**
     * Euclidean distance: measures absolute distance
     * Range: 0 to infinity (lower = more similar)
     */
    public double euclideanDistance(float[] a, float[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }
    
    /**
     * Normalize vector to unit length
     */
    public float[] normalize(float[] vector) {
        double norm = 0.0;
        for (float v : vector) {
            norm += v * v;
        }
        norm = Math.sqrt(norm);
        
        float[] normalized = new float[vector.length];
        for (int i = 0; i < vector.length; i++) {
            normalized[i] = (float) (vector[i] / norm);
        }
        return normalized;
    }
}

Example 4: Embedding Caching

@Service
public class CachingEmbeddingService implements EmbeddingService {
    
    private final EmbeddingService delegate;
    private final Cache<String, float[]> cache;
    
    public CachingEmbeddingService(EmbeddingService delegate) {
        this.delegate = delegate;
        this.cache = Caffeine.newBuilder()
            .maximumSize(10_000)
            .expireAfterAccess(Duration.ofDays(7))
            .recordStats()
            .build();
    }
    
    @Override
    public float[] embed(String text) {
        String key = hash(text);
        
        float[] cached = cache.getIfPresent(key);
        if (cached != null) {
            return cached;
        }
        
        float[] embedding = delegate.embed(text);
        cache.put(key, embedding);
        return embedding;
    }
    
    @Override
    public List<float[]> embedBatch(List<String> texts) {
        // Check the cache for each text; duplicate texts share one key
        List<String> uncached = new ArrayList<>();
        Map<String, List<Integer>> keyToIndexes = new HashMap<>();
        float[][] results = new float[texts.size()][];
        
        for (int i = 0; i < texts.size(); i++) {
            String key = hash(texts.get(i));
            float[] cached = cache.getIfPresent(key);
            
            if (cached != null) {
                results[i] = cached;
            } else {
                // Only embed the first occurrence of each distinct text
                if (!keyToIndexes.containsKey(key)) {
                    uncached.add(texts.get(i));
                }
                keyToIndexes.computeIfAbsent(key, k -> new ArrayList<>()).add(i);
            }
        }
        
        // Embed uncached texts in one delegate call
        if (!uncached.isEmpty()) {
            List<float[]> newEmbeddings = delegate.embedBatch(uncached);
            
            for (int i = 0; i < uncached.size(); i++) {
                String key = hash(uncached.get(i));
                float[] embedding = newEmbeddings.get(i);
                cache.put(key, embedding);
                // Fill every original position that asked for this text
                for (int originalIndex : keyToIndexes.get(key)) {
                    results[originalIndex] = embedding;
                }
            }
        }
        
        return Arrays.asList(results);
    }
    
    @Override
    public int getDimensions() {
        return delegate.getDimensions();
    }
    
    private String hash(String text) {
        try {
            // MessageDigest is not thread-safe, so create one per call
            MessageDigest digest = MessageDigest.getInstance("SHA-256");
            byte[] hashBytes = digest.digest(text.getBytes(StandardCharsets.UTF_8));
            return Base64.getEncoder().encodeToString(hashBytes);
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("SHA-256 unavailable", e);
        }
    }
    
    public CacheStats getStats() {
        return cache.stats();
    }
}

// Persistent cache with Redis
@Service
public class RedisEmbeddingCache {
    
    private final RedisTemplate<String, byte[]> redis;
    private final String keyPrefix = "emb:";
    
    public void put(String textHash, float[] embedding) {
        byte[] bytes = toBytes(embedding);
        redis.opsForValue().set(keyPrefix + textHash, bytes, Duration.ofDays(30));
    }
    
    public float[] get(String textHash) {
        byte[] bytes = redis.opsForValue().get(keyPrefix + textHash);
        return bytes != null ? fromBytes(bytes) : null;
    }
    
    private byte[] toBytes(float[] embedding) {
        ByteBuffer buffer = ByteBuffer.allocate(embedding.length * 4);
        for (float f : embedding) {
            buffer.putFloat(f);
        }
        return buffer.array();
    }
    
    private float[] fromBytes(byte[] bytes) {
        ByteBuffer buffer = ByteBuffer.wrap(bytes);
        float[] embedding = new float[bytes.length / 4];
        for (int i = 0; i < embedding.length; i++) {
            embedding[i] = buffer.getFloat();
        }
        return embedding;
    }
}

Example 5: Semantic Search Application

@Service
public class SemanticSearchService {
    
    private final EmbeddingService embeddingService;
    private final VectorStore vectorStore;
    private final ChunkingService chunkingService;
    
    /**
     * Index documents for semantic search
     */
    public void indexDocuments(List<Document> documents) {
        // Chunk documents
        List<Chunk> allChunks = documents.stream()
            .flatMap(doc -> chunkingService.chunk(doc).stream())
            .toList();
        
        // Batch embed chunks
        List<String> chunkTexts = allChunks.stream()
            .map(Chunk::getContent)
            .toList();
        
        List<float[]> embeddings = embeddingService.embedBatch(chunkTexts);
        
        // Create vector records
        List<VectorRecord> records = IntStream.range(0, allChunks.size())
            .mapToObj(i -> new VectorRecord(
                allChunks.get(i).getId(),
                embeddings.get(i),
                Map.of(
                    "document_id", allChunks.get(i).getDocumentId(),
                    "content", allChunks.get(i).getContent(),
                    "source", allChunks.get(i).getSource()
                )
            ))
            .toList();
        
        // Batch insert
        vectorStore.insertBatch(records);
    }
    
    /**
     * Semantic search with filters
     */
    public List<SearchResult> search(SearchRequest request) {
        // Embed query
        float[] queryEmbedding = embeddingService.embed(request.getQuery());
        
        // Build filter
        Map<String, Object> filter = new HashMap<>();
        if (request.getDocumentType() != null) {
            filter.put("document_type", request.getDocumentType());
        }
        if (request.getDateRange() != null) {
            filter.put("date_gte", request.getDateRange().getStart());
            filter.put("date_lte", request.getDateRange().getEnd());
        }
        
        // Search
        List<VectorMatch> matches = vectorStore.search(
            queryEmbedding,
            request.getLimit(),
            filter
        );
        
        // Convert to search results
        return matches.stream()
            .map(match -> SearchResult.builder()
                .id(match.getId())
                .content(match.getMetadata().get("content").toString())
                .source(match.getMetadata().get("source").toString())
                .score(match.getScore())
                .build())
            .toList();
    }
    
    /**
     * Find similar items
     */
    public List<SearchResult> findSimilar(String itemId, int limit) {
        // Get the item's embedding
        VectorRecord item = vectorStore.getById(itemId);
        if (item == null) {
            throw new NotFoundException("Item not found: " + itemId);
        }
        
        // Search for similar, excluding self
        List<VectorMatch> matches = vectorStore.search(
            item.getEmbedding(),
            limit + 1,
            Map.of()
        );
        
        return matches.stream()
            .filter(m -> !m.getId().equals(itemId))
            .limit(limit)
            .map(this::toSearchResult)
            .toList();
    }
    
    private SearchResult toSearchResult(VectorMatch match) {
        return SearchResult.builder()
            .id(match.getId())
            .content(String.valueOf(match.getMetadata().get("content")))
            .source(String.valueOf(match.getMetadata().get("source")))
            .score(match.getScore())
            .build();
    }
}

Anti-Patterns

❌ Embedding Without Preprocessing

Clean and normalize text before embedding; markup, boilerplate, and whitespace noise all distort the vector.
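A minimal cleanup pass before embedding might look like this (the exact rules are assumptions; real pipelines often also strip HTML and repeated boilerplate):

```java
public class TextPreprocessor {

    // Collapse whitespace and drop control characters before embedding,
    // so formatting noise does not shift the vector.
    static String preprocess(String text) {
        return text
            .replaceAll("\\p{Cntrl}", " ")  // newlines, tabs, other control chars
            .replaceAll("\\s+", " ")        // collapse runs of whitespace
            .trim();
    }
}
```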

❌ Wrong Similarity Metric

Use cosine similarity for semantic search; dot product only matches cosine when vectors are normalized, and Euclidean distance on unnormalized vectors ranks partly by magnitude rather than meaning.


References