LLM Integration

Ai Ml core v1.0.0

Overview

Large Language Model integration requires careful handling of API calls, token management, error handling, and response processing. This skill covers patterns for reliable LLM integration in production systems.


Key Concepts

LLM Integration Architecture

┌─────────────────────────────────────────────────────────────┐
│              LLM Integration Architecture                    │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  Application Layer:                                         │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  Business Logic → LLM Client → Provider API         │   │
│  └─────────────────────────────────────────────────────┘   │
│                          │                                   │
│  Abstraction Layer:      ▼                                  │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  ┌─────────┐ ┌──────────┐ ┌──────────┐             │   │
│  │  │ OpenAI  │ │ Anthropic│ │ Azure    │ ...         │   │
│  │  │ Adapter │ │ Adapter  │ │ OpenAI   │             │   │
│  │  └────┬────┘ └────┬─────┘ └────┬─────┘             │   │
│  │       └───────────┴────────────┘                    │   │
│  │                   │                                  │   │
│  │          ┌────────▼────────┐                        │   │
│  │          │ Unified Client  │                        │   │
│  │          │ • Rate limiting │                        │   │
│  │          │ • Retry logic   │                        │   │
│  │          │ • Token count   │                        │   │
│  │          │ • Logging       │                        │   │
│  │          └─────────────────┘                        │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                              │
│  Key Considerations:                                        │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  • Token limits and context windows                 │   │
│  │  • Rate limiting and quotas                         │   │
│  │  • Streaming vs batch responses                     │   │
│  │  • Cost management                                  │   │
│  │  • Latency and timeout handling                     │   │
│  │  • Response parsing and validation                  │   │
│  └─────────────────────────────────────────────────────┘   │
│                                                              │
└─────────────────────────────────────────────────────────────┘
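The "latency and timeout handling" consideration above can be sketched in plain Java: bound every provider call with a deadline, and cancel or fall back when it expires. `TimeoutGuard` is a hypothetical name and the fallback strategy is illustrative; a production client would more likely raise a domain error than return a canned value.

```java
import java.util.concurrent.*;

// Sketch: bounding LLM call latency with a hard deadline.
// The Callable stands in for a blocking provider call.
public class TimeoutGuard {
    private final ExecutorService executor = Executors.newCachedThreadPool();

    public <T> T callWithTimeout(Callable<T> call, long timeoutMs, T fallback) {
        Future<T> future = executor.submit(call);
        try {
            return future.get(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            future.cancel(true);   // interrupt the slow provider call
            return fallback;       // or throw a domain-specific exception
        } catch (Exception e) {
            throw new RuntimeException("LLM call failed", e);
        }
    }

    public void shutdown() {
        executor.shutdownNow();
    }
}
```

The same deadline should also be configured at the HTTP layer (connect/read timeouts), so a hung socket cannot outlive the guard.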

Best Practices

1. Abstract Provider Details

Support multiple LLM providers behind a unified interface.

2. Implement Robust Error Handling

Handle rate limits, timeouts, and API errors gracefully.

3. Track Token Usage

Monitor and limit token consumption for cost control.

4. Use Streaming for Long Responses

Stream tokens as they are generated so users see output immediately instead of waiting for the full completion.

5. Cache Responses When Appropriate

Reduce costs and latency for repeated queries.
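Practice 2 above usually pairs retries with exponential backoff plus jitter between attempts. The stdlib-only sketch below shows just the delay schedule ("full jitter": pick uniformly in [0, capped delay]); `BackoffPolicy` and the figures are illustrative, and Example 2 shows the same idea via Resilience4j.

```java
import java.util.Random;

// Sketch: exponential backoff with full jitter for retry delays.
// Delay window doubles each attempt (base * 2^attempt), capped at a maximum.
public class BackoffPolicy {
    private final long baseDelayMs;
    private final long maxDelayMs;
    private final Random random;

    public BackoffPolicy(long baseDelayMs, long maxDelayMs, long seed) {
        this.baseDelayMs = baseDelayMs;
        this.maxDelayMs = maxDelayMs;
        this.random = new Random(seed);
    }

    /** Upper bound of the delay window for the given attempt (0-based). */
    public long cappedDelayMs(int attempt) {
        long exponential = baseDelayMs << Math.min(attempt, 20);  // bound the shift
        return Math.min(exponential, maxDelayMs);
    }

    /** Actual delay: uniform in [0, cappedDelayMs] ("full jitter"). */
    public long nextDelayMs(int attempt) {
        return (long) (random.nextDouble() * cappedDelayMs(attempt));
    }
}
```

Jitter matters with LLM APIs: after a provider-wide incident, unjittered clients retry in lockstep and re-trigger the rate limiter.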


Code Examples

Example 1: Unified LLM Client

// Unified interface for LLM providers
public interface LlmClient {
    
    CompletionResponse complete(CompletionRequest request);
    
    Flux<CompletionChunk> streamComplete(CompletionRequest request);
    
    EmbeddingResponse embed(EmbeddingRequest request);
}

@Data
@Builder
public class CompletionRequest {
    private String model;
    private List<Message> messages;
    private Double temperature;
    private Integer maxTokens;
    private List<String> stop;
    private ResponseFormat responseFormat;
    private Map<String, Object> metadata;
}

@Data
public class Message {
    private Role role;
    private String content;
    private List<ContentPart> contentParts;  // For multimodal
    
    public enum Role { SYSTEM, USER, ASSISTANT, FUNCTION }
}

@Data
public class CompletionResponse {
    private String id;
    private String content;
    private Usage usage;
    private FinishReason finishReason;
    private Map<String, Object> metadata;
}

@Data
public class Usage {
    private int promptTokens;
    private int completionTokens;
    private int totalTokens;
}

// OpenAI implementation
@Service
@RequiredArgsConstructor
public class OpenAiClient implements LlmClient {
    
    private final OpenAIService openAIService;
    private final TokenCounter tokenCounter;
    private final MeterRegistry meterRegistry;
    
    @Override
    public CompletionResponse complete(CompletionRequest request) {
        Timer.Sample timer = Timer.start(meterRegistry);
        
        try {
            ChatCompletionRequest openAiRequest = ChatCompletionRequest.builder()
                .model(request.getModel())
                .messages(toOpenAiMessages(request.getMessages()))
                .temperature(request.getTemperature())
                .maxTokens(request.getMaxTokens())
                .stop(request.getStop())
                .build();
            
            ChatCompletionResult result = openAIService.createChatCompletion(openAiRequest);
            
            // Record metrics
            meterRegistry.counter("llm.tokens.prompt", 
                "model", request.getModel())
                .increment(result.getUsage().getPromptTokens());
            
            meterRegistry.counter("llm.tokens.completion",
                "model", request.getModel())
                .increment(result.getUsage().getCompletionTokens());
            
            return mapToResponse(result);
            
        } finally {
            timer.stop(meterRegistry.timer("llm.request.duration",
                "model", request.getModel(),
                "provider", "openai"));
        }
    }
    
    @Override
    public Flux<CompletionChunk> streamComplete(CompletionRequest request) {
        ChatCompletionRequest openAiRequest = ChatCompletionRequest.builder()
            .model(request.getModel())
            .messages(toOpenAiMessages(request.getMessages()))
            .temperature(request.getTemperature())
            .maxTokens(request.getMaxTokens())
            .stream(true)
            .build();
        
        return Flux.create(sink -> {
            var subscription = openAIService.streamChatCompletion(openAiRequest)
                .doOnNext(chunk -> sink.next(mapToChunk(chunk)))
                .doOnComplete(sink::complete)
                .doOnError(sink::error)
                .subscribe();
            // Cancel the upstream provider stream if the client disconnects
            sink.onDispose(subscription::dispose);
        });
    }
}

// Anthropic implementation
@Service
@RequiredArgsConstructor
public class AnthropicClient implements LlmClient {
    
    private final WebClient webClient;
    private final String apiKey;
    
    @Override
    public CompletionResponse complete(CompletionRequest request) {
        AnthropicRequest anthropicRequest = AnthropicRequest.builder()
            .model(mapModel(request.getModel()))
            .maxTokens(request.getMaxTokens())
            .messages(toAnthropicMessages(request.getMessages()))
            .system(extractSystemMessage(request.getMessages()))
            .build();
        
        AnthropicResponse response = webClient.post()
            .uri("/v1/messages")
            .header("x-api-key", apiKey)
            .header("anthropic-version", "2023-06-01")
            .bodyValue(anthropicRequest)
            .retrieve()
            .bodyToMono(AnthropicResponse.class)
            .block();
        
        return mapToResponse(response);
    }
}

Example 2: Retry and Error Handling

@Service
@Slf4j
public class ResilientLlmClient implements LlmClient {
    
    private final LlmClient delegate;
    private final RetryRegistry retryRegistry;
    private final CircuitBreakerRegistry circuitBreakerRegistry;
    
    public ResilientLlmClient(LlmClient delegate) {
        this.delegate = delegate;
        
        RetryConfig retryConfig = RetryConfig.<CompletionResponse>custom()
            .maxAttempts(3)
            // Exponential backoff: waits of ~1s, 2s, 4s between attempts
            .intervalFunction(IntervalFunction.ofExponentialBackoff(1000, 2.0))
            .retryOnException(this::isRetryable)
            .retryOnResult(this::shouldRetry)
            .build();
        
        this.retryRegistry = RetryRegistry.of(retryConfig);
        
        CircuitBreakerConfig cbConfig = CircuitBreakerConfig.custom()
            .failureRateThreshold(50)
            .slowCallRateThreshold(80)
            .slowCallDurationThreshold(Duration.ofSeconds(30))
            .waitDurationInOpenState(Duration.ofMinutes(1))
            .permittedNumberOfCallsInHalfOpenState(3)
            .slidingWindowType(SlidingWindowType.COUNT_BASED)
            .slidingWindowSize(10)
            .build();
        
        this.circuitBreakerRegistry = CircuitBreakerRegistry.of(cbConfig);
    }
    
    @Override
    public CompletionResponse complete(CompletionRequest request) {
        String model = request.getModel();
        Retry retry = retryRegistry.retry("llm-" + model);
        CircuitBreaker cb = circuitBreakerRegistry.circuitBreaker("llm-" + model);
        
        Supplier<CompletionResponse> decoratedSupplier = 
            CircuitBreaker.decorateSupplier(cb,
                Retry.decorateSupplier(retry, 
                    () -> delegate.complete(request)));
        
        try {
            return decoratedSupplier.get();
        } catch (RateLimitException e) {
            log.warn("Rate limit hit, waiting {} seconds", e.getRetryAfterSeconds());
            throw new LlmTemporarilyUnavailableException(
                "Rate limit exceeded", e.getRetryAfterSeconds());
        } catch (Exception e) {
            log.error("LLM request failed after retries", e);
            throw new LlmException("Failed to complete LLM request", e);
        }
    }
    
    private boolean isRetryable(Throwable t) {
        if (t instanceof RateLimitException) return true;
        if (t instanceof TimeoutException) return true;
        if (t instanceof ServerErrorException) return true;
        if (t instanceof WebClientRequestException) return true;  // Network errors
        return false;
    }
    
    private boolean shouldRetry(CompletionResponse response) {
        // Retry if response indicates overload
        return response.getFinishReason() == FinishReason.OVERLOADED;
    }
}

// Rate limit handling
@ControllerAdvice
public class LlmExceptionHandler {
    
    @ExceptionHandler(RateLimitException.class)
    public ResponseEntity<ProblemDetail> handleRateLimit(RateLimitException ex) {
        ProblemDetail problem = ProblemDetail.forStatusAndDetail(
            HttpStatus.TOO_MANY_REQUESTS,
            "AI service rate limit exceeded. Please try again later."
        );
        
        return ResponseEntity.status(HttpStatus.TOO_MANY_REQUESTS)
            .header("Retry-After", String.valueOf(ex.getRetryAfterSeconds()))
            .body(problem);
    }
}
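The unified client's responsibilities also include rate limiting, which the examples above do not show. A minimal client-side token bucket (hypothetical `TokenBucket`, stdlib only; Resilience4j's `RateLimiter` is the library equivalent) smooths outgoing requests so the provider's quota is respected before a 429 ever comes back:

```java
// Sketch: a client-side token bucket to stay under a provider's
// requests-per-second quota. Time is passed in explicitly for testability.
public class TokenBucket {
    private final double capacity;
    private final double refillPerMs;
    private double tokens;
    private long lastRefillMs;

    public TokenBucket(double capacity, double refillPerSecond, long nowMs) {
        this.capacity = capacity;
        this.refillPerMs = refillPerSecond / 1000.0;
        this.tokens = capacity;       // start full
        this.lastRefillMs = nowMs;
    }

    /** Try to take one permit at the given clock time; true if allowed. */
    public synchronized boolean tryAcquire(long nowMs) {
        // Refill proportionally to elapsed time, capped at capacity
        tokens = Math.min(capacity, tokens + (nowMs - lastRefillMs) * refillPerMs);
        lastRefillMs = nowMs;
        if (tokens >= 1.0) {
            tokens -= 1.0;
            return true;
        }
        return false;
    }
}
```

Callers that fail `tryAcquire` can queue, shed load, or sleep until the next refill; for LLM APIs the bucket is often sized in tokens per minute rather than requests.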

Example 3: Streaming Responses

@RestController
@RequestMapping("/api/chat")
@RequiredArgsConstructor
public class ChatController {
    
    private final LlmClient llmClient;
    
    @GetMapping(value = "/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> streamChat(
            @RequestParam String message,
            @RequestHeader("X-Session-ID") String sessionId) {
        
        CompletionRequest request = CompletionRequest.builder()
            .model("gpt-4")
            .messages(buildMessages(sessionId, message))
            .maxTokens(1000)
            .build();
        
        AtomicReference<StringBuilder> fullResponse = new AtomicReference<>(new StringBuilder());
        
        return llmClient.streamComplete(request)
            .map(chunk -> {
                fullResponse.get().append(chunk.getContent());
                return ServerSentEvent.<String>builder()
                    .data(chunk.getContent())
                    .build();
            })
            .concatWith(Flux.defer(() -> {
                // Save complete response to history
                saveToHistory(sessionId, fullResponse.get().toString());
                return Flux.just(ServerSentEvent.<String>builder()
                    .event("done")
                    .data("")
                    .build());
            }))
            .onErrorResume(e -> Flux.just(
                ServerSentEvent.<String>builder()
                    .event("error")
                    .data(e.getMessage())
                    .build()
            ));
    }
}

// Client-side streaming handler
@Component
@RequiredArgsConstructor
public class StreamingLlmClient {
    
    private final WebClient webClient;
    private final ObjectMapper objectMapper;
    
    public Flux<String> streamCompletion(String prompt) {
        return webClient.post()
            .uri("/v1/chat/completions")
            .bodyValue(Map.of(
                "model", "gpt-4",
                "messages", List.of(Map.of("role", "user", "content", prompt)),
                "stream", true
            ))
            .retrieve()
            .bodyToFlux(String.class)
            .filter(line -> line.startsWith("data: ") && !line.contains("[DONE]"))
            .map(line -> line.substring(6))
            .mapNotNull(this::extractContent);
    }
    
    private String extractContent(String json) {
        try {
            JsonNode node = objectMapper.readTree(json);
            return node.path("choices").path(0).path("delta").path("content").asText(null);
        } catch (Exception e) {
            return null;
        }
    }
}

Example 4: Token Management

@Service
public class TokenManager {
    
    private final Map<String, Integer> modelContextWindows = Map.of(
        "gpt-4", 8192,
        "gpt-4-32k", 32768,
        "gpt-4-turbo", 128000,
        "claude-3-opus", 200000,
        "claude-3-sonnet", 200000
    );
    
    public int countTokens(String text, String model) {
        // Use a tiktoken-compatible tokenizer (e.g. jtokkit) for accurate counting;
        // fall back to cl100k_base for unknown model names
        EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
        Encoding encoding = registry.getEncodingForModel(model)
            .orElseGet(() -> registry.getEncoding(EncodingType.CL100K_BASE));
        
        return encoding.countTokens(text);
    }
    
    public int countMessagesTokens(List<Message> messages, String model) {
        int tokens = 0;
        
        for (Message message : messages) {
            tokens += 4;  // Message overhead
            tokens += countTokens(message.getContent(), model);
            tokens += 1;  // Role token
        }
        
        tokens += 2;  // Priming
        return tokens;
    }
    
    public List<Message> truncateToFit(List<Message> messages, String model, int reserveForCompletion) {
        int maxContext = modelContextWindows.getOrDefault(model, 4096);
        int available = maxContext - reserveForCompletion;
        
        // Always keep system message
        Message systemMessage = messages.stream()
            .filter(m -> m.getRole() == Role.SYSTEM)
            .findFirst()
            .orElse(null);
        
        int systemTokens = systemMessage != null ? 
            countMessagesTokens(List.of(systemMessage), model) : 0;
        
        List<Message> result = new ArrayList<>();
        if (systemMessage != null) {
            result.add(systemMessage);
        }
        
        // Add messages from most recent, truncating older ones
        int usedTokens = systemTokens;
        List<Message> nonSystemMessages = messages.stream()
            .filter(m -> m.getRole() != Role.SYSTEM)
            .toList();
        
        int insertAt = systemMessage != null ? 1 : 0;
        for (int i = nonSystemMessages.size() - 1; i >= 0; i--) {
            Message msg = nonSystemMessages.get(i);
            int msgTokens = countMessagesTokens(List.of(msg), model);
            
            if (usedTokens + msgTokens <= available) {
                result.add(insertAt, msg);  // Insert after system message (or at head)
                usedTokens += msgTokens;
            } else {
                break;  // Stop adding older messages
            }
        }
        
        return result;
    }
}

// Cost tracking
@Service
@Slf4j
@RequiredArgsConstructor
public class LlmCostTracker {
    
    private final MeterRegistry meterRegistry;
    
    // Prices per 1K tokens (example rates)
    private final Map<String, CostRate> pricing = Map.of(
        "gpt-4", new CostRate(0.03, 0.06),
        "gpt-4-turbo", new CostRate(0.01, 0.03),
        "gpt-3.5-turbo", new CostRate(0.0005, 0.0015),
        "claude-3-opus", new CostRate(0.015, 0.075),
        "claude-3-sonnet", new CostRate(0.003, 0.015)
    );
    
    public void trackUsage(String model, Usage usage, String userId) {
        CostRate rate = pricing.getOrDefault(model, new CostRate(0.01, 0.03));
        
        double promptCost = (usage.getPromptTokens() / 1000.0) * rate.promptRate();
        double completionCost = (usage.getCompletionTokens() / 1000.0) * rate.completionRate();
        double totalCost = promptCost + completionCost;
        
        meterRegistry.counter("llm.cost.usd",
            "model", model,
            "user_id", userId)
            .increment(totalCost);
        
        log.info("LLM cost: model={}, tokens={}, cost=${}", 
            model, usage.getTotalTokens(), String.format("%.4f", totalCost));
    }
    
    record CostRate(double promptRate, double completionRate) {}
}
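The per-request cost arithmetic in trackUsage can be checked standalone. The helper below (hypothetical `CostMath`) uses the same example gpt-4 figures from the pricing map, in dollars per 1K tokens:

```java
// Standalone check of the per-request cost formula used above:
// cost = (promptTokens / 1000) * promptRate + (completionTokens / 1000) * completionRate
public class CostMath {
    public static double cost(int promptTokens, int completionTokens,
                              double promptRate, double completionRate) {
        return (promptTokens / 1000.0) * promptRate
             + (completionTokens / 1000.0) * completionRate;
    }
}
```

For example, 1,500 prompt tokens and 500 completion tokens at gpt-4's example rates (0.03 / 0.06) cost 0.045 + 0.03 = $0.075.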

Example 5: Response Caching

@Service
@Slf4j
public class CachingLlmClient implements LlmClient {
    
    private final LlmClient delegate;
    private final MeterRegistry meterRegistry;
    private final Cache<String, CompletionResponse> cache;
    
    public CachingLlmClient(LlmClient delegate, MeterRegistry meterRegistry) {
        this.delegate = delegate;
        this.meterRegistry = meterRegistry;
        this.cache = Caffeine.newBuilder()
            .maximumSize(1000)
            .expireAfterWrite(Duration.ofHours(24))
            .recordStats()
            .build();
    }
    
    @Override
    public CompletionResponse complete(CompletionRequest request) {
        // Only cache deterministic requests (temperature = 0)
        if (request.getTemperature() != null && request.getTemperature() > 0) {
            return delegate.complete(request);
        }
        
        String cacheKey = generateCacheKey(request);
        
        CompletionResponse cached = cache.getIfPresent(cacheKey);
        if (cached != null) {
            log.debug("Cache hit for LLM request");
            meterRegistry.counter("llm.cache.hits").increment();
            return cached.withMetadata(Map.of("cached", true));
        }
        
        meterRegistry.counter("llm.cache.misses").increment();
        
        CompletionResponse response = delegate.complete(request);
        cache.put(cacheKey, response);
        
        return response;
    }
    
    private String generateCacheKey(CompletionRequest request) {
        String content = request.getModel() + ":" +
            request.getMessages().stream()
                .map(m -> m.getRole() + ":" + m.getContent())
                .collect(Collectors.joining("|"));
        
        try {
            // MessageDigest is not thread-safe, so create one per call
            MessageDigest digest = MessageDigest.getInstance("SHA-256");
            byte[] hash = digest.digest(content.getBytes(StandardCharsets.UTF_8));
            return Base64.getEncoder().encodeToString(hash);
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("SHA-256 not available", e);
        }
    }
}

// Semantic caching with embeddings
@Service
@RequiredArgsConstructor
@Slf4j
public class SemanticCachingLlmClient implements LlmClient {
    
    private final LlmClient delegate;
    private final EmbeddingService embeddingService;
    private final VectorStore vectorStore;
    private final ObjectMapper objectMapper;
    private final double similarityThreshold = 0.95;
    
    @Override
    public CompletionResponse complete(CompletionRequest request) {
        // Get embedding for the request
        String queryText = extractQueryText(request);
        float[] queryEmbedding = embeddingService.embed(queryText);
        
        // Search for similar cached responses
        List<VectorMatch> matches = vectorStore.similaritySearch(
            queryEmbedding, 
            1, 
            similarityThreshold
        );
        
        if (!matches.isEmpty()) {
            VectorMatch match = matches.get(0);
            log.info("Semantic cache hit with similarity: {}", match.getScore());
            try {
                return objectMapper.readValue(match.getMetadata().get("response"),
                    CompletionResponse.class);
            } catch (JsonProcessingException e) {
                log.warn("Failed to deserialize cached response; treating as a miss", e);
            }
        }
        
        // Execute and cache
        CompletionResponse response = delegate.complete(request);
        
        try {
            vectorStore.insert(
                UUID.randomUUID().toString(),
                queryEmbedding,
                Map.of(
                    "query", queryText,
                    "response", objectMapper.writeValueAsString(response),
                    "model", request.getModel()
                )
            );
        } catch (JsonProcessingException e) {
            log.warn("Failed to serialize response for semantic cache", e);
        }
        
        return response;
    }
}

Anti-Patterns

❌ Hardcoding Provider-Specific Code

Use abstractions to support multiple providers.

❌ Ignoring Token Limits

Always validate request size before sending.
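A cheap pre-flight check can enforce this before the request ever leaves the service. The sketch below (hypothetical `RequestValidator`) uses the rough ~4-characters-per-token heuristic as a fast estimate; Example 4's exact tokenizer is the production option.

```java
import java.util.List;

// Sketch: pre-flight request validation against a model's context window.
// Uses the rough ~4 chars-per-token heuristic; swap in an exact tokenizer
// (as in Example 4) for precise budgets.
public class RequestValidator {

    /** Rough token estimate: ceil(chars / 4). */
    public static int estimateTokens(String text) {
        return (int) Math.ceil(text.length() / 4.0);
    }

    /** Reject requests whose prompt would not leave room for the completion. */
    public static void validate(List<String> messageContents,
                                int contextWindow, int reserveForCompletion) {
        int estimated = messageContents.stream()
            .mapToInt(RequestValidator::estimateTokens)
            .sum();
        int budget = contextWindow - reserveForCompletion;
        if (estimated > budget) {
            throw new IllegalArgumentException(
                "Estimated " + estimated + " prompt tokens exceeds budget of " + budget);
        }
    }
}
```

Failing fast here is cheaper than letting the provider reject the request: no network round trip, no billed prompt tokens, and a clearer error for the caller.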


References