This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Error on save PyTorch model defined in another module #332

Closed
ankxyz opened this issue Jul 5, 2022 · 17 comments
Labels
bug Something isn't working ml-framework ML Framework support plugins Plugins and extensions for MLEM!

@ankxyz

ankxyz commented Jul 5, 2022

I defined a custom PyTorch network. When I try to save it with MLEM, it fails.

Environment
OS: Ubuntu Linux 20.04
Python: 3.8.10
Virtual environment: venv
Python dependencies:

  • numpy==1.22.4
  • mlem==0.2.3
  • pandas==1.4.2
  • scikit-learn==1.1.1
  • torch==1.11.0
  • torchvision==0.12.0
  • tqdm==4.64.0

Code

Full code: https://github.com/ankxyz/mlem-pytorch-demo

src/utils/train.py

import torch
from torch import nn

class Network(nn.Module):
    def __init__(self, out_features=2):
        super().__init__()
        # Spatial sizes in the comments assume a 28x28 single-channel input
        self.network = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),        # 28 -> 14
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),        # 14 -> 14 (stride 1 keeps the size)
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),        # 14 -> 7
            nn.AdaptiveAvgPool2d((1, 1))                             # 7 -> 1 (global average pooling)
        )

        self.dnnModel = nn.Sequential(
            nn.Linear(128, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.LeakyReLU(),
            nn.Linear(32, out_features)
        )

    def forward(self, x):
        output = self.network(x)
        # Flatten (N, 128, 1, 1) to (N, 128); squeeze() would also drop the batch dim when N == 1
        output = torch.flatten(output, start_dim=1)
        output = self.dnnModel(output)
        return output
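The output size of each Conv2d and MaxPool2d layer in self.network (dilation 1, ceil_mode=False) follows floor((H + 2*padding - kernel_size) / stride) + 1. The sketch below traces that formula through the three conv/pool pairs, assuming a 28x28 single-channel input, which is what the inline size comments appear to assume; out_size, LAYERS and trace_sizes are illustrative names, not part of the project.

```python
def out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output height/width of Conv2d or MaxPool2d (dilation=1, ceil_mode=False)."""
    return (size + 2 * padding - kernel) // stride + 1


# (name, kernel_size, stride, padding) of each Conv2d / MaxPool2d in self.network, in order.
LAYERS = [
    ("conv1", 3, 1, 1),
    ("pool1", 3, 2, 1),
    ("conv2", 3, 1, 1),
    ("pool2", 3, 1, 1),
    ("conv3", 3, 1, 1),
    ("pool3", 3, 2, 1),
]


def trace_sizes(size: int = 28) -> list:
    """Return the spatial size after each layer, starting from a size x size input."""
    sizes = []
    for _name, kernel, stride, padding in LAYERS:
        size = out_size(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


print(trace_sizes(28))  # [28, 14, 14, 14, 14, 7]
```

Because the final AdaptiveAvgPool2d((1, 1)) reduces whatever spatial size remains to 1x1, the first Linear(128, 64) in dnnModel receives 128 features per image regardless of the exact input resolution.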

src/stages/train.py

"""
Trains classification model
"""

import argparse
from mlem.api import save
from pathlib import Path
import time
import torch
from tqdm import tqdm
from typing import Text

from src.utils.config import load_config
from src.utils.datasets import ImageFolderWithPaths
from src.utils.loggers import get_logger
from src.utils.train import get_loss_fn, Network
from src.utils.transforms import get_transforms


def train(config_path: Text) -> None:
    """Trains gesture classification model.
    Args:
        config_path(Text): path to config
    """

    config = load_config(config_path)
    logger = get_logger('TRAIN', log_level=config.base.log_level)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f'Device: {device}')

    logger.info('Load datasets')
    data_transform = get_transforms()
    
    # Create dataset objects from folder/images
    raw_data_dir = Path(config.base.raw_data_dir)
    train_dataset = ImageFolderWithPaths(
        root=raw_data_dir / 'train',
        transform=data_transform
    )
    
    # Create data loader
    logger.info('Create Data Loaders')
    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.base.batch_size,
        shuffle=True,
        num_workers=config.base.num_workers
    )

    logger.info('Setup model')

    classes_number = len(train_dataset.classes)
    logger.info(f'Classes number = {classes_number}')
    
    model = Network(out_features=classes_number)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.train.learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.train.step_size,
        gamma=config.train.gamma
    )

    logger.info('Fit model')
    epochs = config.train.epochs
    loss_fn = get_loss_fn(device)

    start_time = time.time()
    sample_data = torch.Tensor(0)

    for epoch in tqdm(range(epochs)):

        logger.info(f'Epoch {epoch}/{epochs}')  
        logger.info('-' * 10)
        logger.info('Train')
        
        epoch_loss = 0
        epoch_acc = 0
        for i, (img, label, _) in tqdm(enumerate(trainloader)):
            img = img.to(device)
            sample_data = img
            label = label.to(device)
            
            predict = model(img)
            loss = loss_fn(predict, label)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
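            # .item() returns Python floats, so the running sums do not keep every batch's autograd graph alive
            epoch_loss += loss.item()
            correct_prediction = torch.argmax(predict, 1) == label
            correct_prediction = correct_prediction.sum().item()
            epoch_acc += correct_prediction / img.shape[0]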
            
        epoch_loss = epoch_loss / len(trainloader)
        epoch_acc = epoch_acc / len(trainloader)

        print('Epoch : {}/{},   loss : {:.5f},    acc : {:.5f}'.format(epoch+1, epochs, epoch_loss, epoch_acc))
        
        if epoch_acc > 0.99 and epoch_loss < 0.1 :
            print('early stop')
            break

        scheduler.step()

    time_elapsed = time.time() - start_time
    minutes = time_elapsed // 60
    seconds = time_elapsed % 60
    logger.info(f'Training complete in {minutes:.0f}m {seconds:.0f}s')

    save(
        obj=model,
        path='clf_model',
        sample_data=sample_data,
        description='PyTorch classification model'
    )


if __name__ == '__main__':

    args_parser = argparse.ArgumentParser()
    args_parser.add_argument('--config', dest='config', required=True)
    args = args_parser.parse_args()

    train(args.config)

Command

PYTHONPATH=. python src/stages/train.py --config=params.yaml

Error
Traceback (most recent call last):
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/core/metadata.py", line 34, in get_object_metadata
    return MlemData.from_data(
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/core/objects.py", line 675, in from_data
    data_type = DataType.create(
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/core/data_type.py", line 74, in create
    return DataAnalyzer.analyze(obj, **kwargs).bind(obj)
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/core/hooks.py", line 107, in analyze
    return cls._find_hook(obj).process(obj, **kwargs)
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/core/hooks.py", line 135, in _find_hook
    raise HookNotFound(
mlem.core.errors.HookNotFound: No suitable DataHook for object of type "Network". Registered hooks: [<class 'mlem.core.data_type.PrimitiveType'>, <class 'mlem.core.data_type.OrderedCollectionHook'>, <class 'mlem.core.data_type.DictType'>, <class 'mlem.contrib.torch.TorchTensorDataType'>, <class 'mlem.contrib.numpy.NumpyNumberType'>, <class 'mlem.contrib.numpy.NumpyNdarrayType'>]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "src/stages/train.py", line 126, in <module>
    train(args.config)
  File "src/stages/train.py", line 112, in train
    save(
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/core/metadata.py", line 96, in save
    meta = get_object_metadata(
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/core/metadata.py", line 38, in get_object_metadata
    return MlemModel.from_obj(
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/core/objects.py", line 615, in from_obj
    requirements=mt.get_requirements().expanded,
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/contrib/torch.py", line 169, in get_requirements
    return super().get_requirements() + InstallableRequirement.from_module(
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/core/model.py", line 279, in get_requirements
    ) + get_object_requirements(self.model)
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/utils/module.py", line 585, in get_object_requirements
    a.dump(obj)
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/dill/_dill.py", line 620, in dump
    StockPickler.dump(self, obj)
  File "/usr/lib/python3.8/pickle.py", line 487, in dump
    self.save(obj)
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/utils/module.py", line 554, in save
    self.add_requirement(obj)
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/utils/module.py", line 548, in add_requirement
    self.add_requirement(parent_package)
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/utils/module.py", line 537, in add_requirement
    for local_req in get_local_module_reqs(module):
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/mlem/utils/module.py", line 310, in get_local_module_reqs
    tree = ast.parse(inspect.getsource(mod))
  File "/usr/lib/python3.8/inspect.py", line 997, in getsource
    lines, lnum = getsourcelines(object)
  File "/usr/lib/python3.8/inspect.py", line 979, in getsourcelines
    lines, lnum = findsource(object)
  File "/usr/lib/python3.8/inspect.py", line 780, in findsource
    file = getsourcefile(object)
  File "/usr/lib/python3.8/inspect.py", line 696, in getsourcefile
    filename = getfile(object)
  File "/media/alex/hdd/Dev/progexp/mlem/mlem-pytorch-demo/.venv/lib/python3.8/site-packages/torch/package/package_importer.py", line 625, in patched_getfile
    return _orig_getfile(object)
  File "/usr/lib/python3.8/inspect.py", line 659, in getfile
    raise TypeError('{!r} is a built-in module'.format(object))
TypeError: <module 'src.utils' (namespace)> is a built-in module

At first I assumed the error occurs because the network is defined in a separate module, so I moved the network into the same module where it is used. But I got the same error.
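The last frame of the traceback shows inspect failing on <module 'src.utils' (namespace)>: src/utils is being imported as a namespace package, i.e. a directory without an __init__.py. Such a module has no __file__, so inspect.getsource, which MLEM calls in get_local_module_reqs while collecting local requirements, raises this TypeError. That would also explain why moving Network made no difference, since the script still imports other helpers from src.utils. The standard-library sketch below reproduces the error with a throwaway package (demo_src and getsource_outcome are made-up names) and shows the lookup succeeding once __init__.py files exist; whether adding them is enough for this MLEM version to save the model is an assumption, not something confirmed in this thread.

```python
import importlib
import inspect
import sys
import tempfile
from pathlib import Path


def getsource_outcome(files: dict, modname: str) -> str:
    """Write `files` (relative path -> text) to a temp dir, import `modname`
    from there and call inspect.getsource on it.

    Returns "ok" on success, otherwise "<ExceptionName>: <message>".
    """
    top = modname.split(".")[0]
    with tempfile.TemporaryDirectory() as root:
        for rel_path, text in files.items():
            path = Path(root, rel_path)
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(text)
        sys.path.insert(0, root)
        importlib.invalidate_caches()
        try:
            inspect.getsource(importlib.import_module(modname))
            return "ok"
        except (TypeError, OSError) as exc:
            return f"{type(exc).__name__}: {exc}"
        finally:
            sys.path.remove(root)
            # Forget the temporary modules so a later call re-imports from disk.
            for name in list(sys.modules):
                if name == top or name.startswith(top + "."):
                    del sys.modules[name]


# Same shape as the issue's layout, but without any __init__.py:
# "demo_src.utils" becomes a namespace package.
print(getsource_outcome({"demo_src/utils/train.py": "X = 1\n"}, "demo_src.utils"))
# prints a TypeError whose message ends in "is a built-in module"

# With (non-empty) __init__.py files, the same lookup succeeds.
print(getsource_outcome(
    {
        "demo_src/__init__.py": '"""Project sources."""\n',
        "demo_src/utils/__init__.py": '"""Helpers."""\n',
        "demo_src/utils/train.py": "X = 1\n",
    },
    "demo_src.utils",
))  # ok
```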


Could you please provide an exhaustive example of using MLEM with PyTorch? It would be great to understand:

  • how to save any PyTorch model, including custom ones
  • how to correctly specify sample data in save() when my data are images
  • how to call apply on a saved model from the CLI when the dataset consists of images

Thanks in advance!

@aguschin aguschin added bug Something isn't working plugins Plugins and extensions for MLEM! ml-framework ML Framework support labels Jul 5, 2022
@mike0sv
Contributor

mike0sv commented Jul 5, 2022

For now, images are not supported, but we are working on this in #310.
As for the error, you may try updating to 0.2.4, since we resolved some bugs there. I don't think they were the same as yours, but you should at least try.
In the meantime, I will try to run your code and see what is wrong. Thanks for the code to reproduce it!

@Ayagoz

Ayagoz commented Jul 26, 2022

Hey, what about saving an optimizer state_dict?
I ran into a YAML error:

'cannot represent an object', <FlatterDict> 

It might be better to use ruamel.yaml, or to flatten the recursive structures in the optimizer state somehow.
The issue comes from an empty dict.

And what about the order of keys in the model state dict? It currently seems to be saved in alphabetical order.
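PyYAML's safe_dump looks up a representer by the exact type of each value (dict, list, tuple, str, int, float, bool, None and a few others); for anything else, including a FlatterDict or any other non-dict mapping, it falls back to represent_undefined, which raises "cannot represent an object" even when the mapping is empty. A minimal sketch of the "flatten the structure" option, assuming the offending values behave like collections.abc.Mapping: convert everything to plain dicts and lists before serializing. to_plain is a hypothetical helper, not part of MLEM. On the ordering question, PyYAML's dump and safe_dump sort mapping keys by default (sort_keys=True), which would plausibly explain the alphabetical order; safe_dump(data, stream, sort_keys=False) keeps insertion order.

```python
from collections.abc import Mapping


def to_plain(obj):
    """Recursively turn any Mapping into a plain dict and any list/tuple into
    a plain list, preserving key and item order; other values pass through."""
    if isinstance(obj, Mapping):
        return {key: to_plain(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_plain(item) for item in obj]
    return obj


class CustomMap(dict):
    """Stand-in for a dict-like class that safe_dump has no representer for."""


state = CustomMap(b=1, a={"z": (1, 2), "y": CustomMap()})
plain = to_plain(state)
print(plain)  # {'b': 1, 'a': {'z': [1, 2], 'y': {}}}
```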

@mike0sv
Contributor

mike0sv commented Jul 29, 2022

Is there a reproducible example in your repo?

@Ayagoz

Ayagoz commented Aug 9, 2022

Yes, I prepared an example:

Traceback (most recent call last):
  File "src/example_for_mlem.py", line 35, in <module>
    main()
  File "src/example_for_mlem.py", line 28, in main
    save(
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/mlem/core/metadata.py", line 99, in save
    meta.dump(path, fs=fs, project=project, index=index, external=external)
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/mlem/core/objects.py", line 508, in dump
    self._write_meta(location, index)
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/mlem/core/objects.py", line 243, in _write_meta
    safe_dump(self.dict(), f)
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/yaml/__init__.py", line 269, in safe_dump
    return dump_all([data], stream, Dumper=SafeDumper, **kwds)
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/yaml/__init__.py", line 241, in dump_all
    dumper.represent(data)
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/yaml/representer.py", line 27, in represent
    node = self.represent_data(data)
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/yaml/representer.py", line 207, in represent_dict
    return self.represent_mapping('tag:yaml.org,2002:map', data)
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/yaml/representer.py", line 207, in represent_dict
    return self.represent_mapping('tag:yaml.org,2002:map', data)
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/yaml/representer.py", line 58, in represent_data
    node = self.yaml_representers[None](self, data)
  File "/home/ayagoz/miniconda3/envs/dvc/lib/python3.8/site-packages/yaml/representer.py", line 231, in represent_undefined
    raise RepresenterError("cannot represent an object", data)
yaml.representer.RepresenterError: ('cannot represent an object', <FlatterDict id=139695677648848 {}>")

Are there any examples with PyTorch models? Or could you describe whether the module must have predict and predict_proba methods to be served with FastAPI (as in https://example-mlem-get-started-app.herokuapp.com/docs#)?

@mike0sv
Copy link
Contributor

mike0sv commented Aug 10, 2022

The endpoints are generated automatically from the appropriate model methods.
Some examples can be found in the tests, but they might not be very clear :(

def test_torch_custom_net(second_tensor, tmpdir):

I will take a look at your example and try to fix it though

@mike0sv
Copy link
Contributor

mike0sv commented Aug 10, 2022

OK, I see now what is going on. You are trying to save the state with MLEM, and the state is a dict, which MLEM treats as data rather than as a model. So I guess we don't support saving PyTorch models as state dicts :(
Can you explain why this approach may be preferable to just saving the optimized model directly?

@mike0sv
Copy link
Contributor

mike0sv commented Aug 10, 2022

Btw, with the fix above, saving and loading as a dict does work.

@mike0sv
Copy link
Contributor

mike0sv commented Aug 10, 2022

But you still won't be able to deploy it, since only MlemModel can be deployed, not MlemData.

@Ayagoz

Ayagoz commented Aug 26, 2022

Ok I see now what is going on. You try to save state with mlem, and state is a dict which mlem assumes as data type, not model type. So I guess we dont support saving pytorch models as state dict :( Can you explain why this approach may be more preferable than just save optimized model directly?

Yes, sure. You need it when training stops unexpectedly and you want to resume with the same optimizer state. Most optimizers keep internal state that changes during training (for example, Adam's running moment estimates), so to continue from exactly the same point you need the optimizer's state as well as the model's.
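For reference, the plain PyTorch way to make training resumable, independent of MLEM, is the "general checkpoint" pattern from PyTorch's saving and loading tutorial: put the epoch, the model state_dict and the optimizer state_dict into one dict, write it with torch.save, and restore both with load_state_dict. A minimal sketch; the tiny nn.Linear model, the helper names and the in-memory buffer are illustrative assumptions, and whether MLEM can store such a checkpoint is the separate question discussed above.

```python
import io

import torch
from torch import nn


def make_model_and_optimizer():
    """A tiny stand-in for the real network and its Adam optimizer."""
    torch.manual_seed(0)
    model = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    return model, optimizer


def train_steps(model, optimizer, steps):
    """Run a few deterministic optimization steps on a fixed batch."""
    x = torch.ones(8, 4)
    for _ in range(steps):
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def save_checkpoint(model, optimizer, epoch, f):
    """Store model weights, optimizer state (e.g. Adam moments) and the epoch together."""
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        f,
    )


def load_checkpoint(model, optimizer, f):
    """Restore both state dicts in place and return the saved epoch."""
    checkpoint = torch.load(f)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]


# Train 3 steps, checkpoint, then keep training for 2 more steps.
buffer = io.BytesIO()
model_a, opt_a = make_model_and_optimizer()
train_steps(model_a, opt_a, 3)
save_checkpoint(model_a, opt_a, epoch=3, f=buffer)
train_steps(model_a, opt_a, 2)

# "Restart": fresh objects, restore the checkpoint, run the same 2 steps.
buffer.seek(0)
model_b, opt_b = make_model_and_optimizer()
resumed_epoch = load_checkpoint(model_b, opt_b, buffer)
train_steps(model_b, opt_b, 2)
```

Because the optimizer state comes back along with the weights, the resumed model ends up with the same parameters as the uninterrupted one.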

@vitalwarley

I think my use case is similar, so I will post it here instead of opening a new issue.

I have the following MRE:

import mlem
import torch
# pip install yolov5
from yolov5 import train, load 

train.run(imgsz=640, data='coco128.yaml', epochs=1)
model = load('runs/train/exp/weights/best.pt')
data = torch.randn(1, 3, 640, 640)
mlem.api.save(model, 'best', sample_data=data)

and this gives me

Traceback (most recent call last):
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/metadata.py", line 34, in get_object_metadata
    return MlemData.from_data(
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/objects.py", line 677, in from_data
    data_type = DataType.create(
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/data_type.py", line 75, in create
    return DataAnalyzer.analyze(obj, is_dynamic=is_dynamic, **kwargs).bind(
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/hooks.py", line 107, in analyze
    return cls._find_hook(obj).process(obj, **kwargs)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/hooks.py", line 135, in _find_hook
    raise HookNotFound(
mlem.core.errors.HookNotFound: No suitable DataHook for object of type "AutoShape". Registered hooks: [<class 'mlem.core.data_type.PrimitiveType'>, <class 'mlem.core.data_type.OrderedCollectionHook'>, <class 'mlem.core.data_type.DictTypeHook'>, <class 'mlem.contrib.numpy.NumpyNumberType'>, <class 'mlem.contrib.numpy.NumpyNdarrayType'>, <class 'mlem.contrib.torch.TorchTensorDataType'>, <class 'mlem.contrib.pandas.SeriesType'>, <class 'mlem.contrib.pandas.DataFrameType'>]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/warley/dev/tcc-ocr-demo/mlem_yolov5_example.py", line 9, in <module>
    mlem.api.save(model, 'best', sample_data=data)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/metadata.py", line 96, in save
    meta = get_object_metadata(
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/metadata.py", line 38, in get_object_metadata
    return MlemModel.from_obj(
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/objects.py", line 617, in from_obj
    requirements=mt.get_requirements().expanded,
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/contrib/torch.py", line 188, in get_requirements
    return super().get_requirements() + InstallableRequirement.from_module(
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/model.py", line 310, in get_requirements
    ) + get_object_requirements(self.model)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 587, in get_object_requirements
    a.dump(obj)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/dill/_dill.py", line 620, in dump
    StockPickler.dump(self, obj)
  File "/usr/lib/python3.10/pickle.py", line 487, in dump
    self.save(obj)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 556, in save
    self.add_requirement(obj)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 542, in add_requirement
    self.add_requirement(local_req)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 542, in add_requirement
    self.add_requirement(local_req)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 542, in add_requirement
    self.add_requirement(local_req)
  [Previous line repeated 1 more time]
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 550, in add_requirement
    self.add_requirement(parent_package)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 542, in add_requirement
    self.add_requirement(local_req)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 550, in add_requirement
    self.add_requirement(parent_package)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 542, in add_requirement
    self.add_requirement(local_req)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 550, in add_requirement
    self.add_requirement(parent_package)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 542, in add_requirement
    self.add_requirement(local_req)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 550, in add_requirement
    self.add_requirement(parent_package)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 542, in add_requirement
    self.add_requirement(local_req)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 550, in add_requirement
    self.add_requirement(parent_package)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 542, in add_requirement
    self.add_requirement(local_req)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 550, in add_requirement
    self.add_requirement(parent_package)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 542, in add_requirement
    self.add_requirement(local_req)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 550, in add_requirement
    self.add_requirement(parent_package)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 542, in add_requirement
    self.add_requirement(local_req)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 550, in add_requirement
    self.add_requirement(parent_package)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 542, in add_requirement
    self.add_requirement(local_req)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 539, in add_requirement
    for local_req in get_local_module_reqs(module):
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 310, in get_local_module_reqs
    tree = ast.parse(inspect.getsource(mod))
  File "/usr/lib/python3.10/inspect.py", line 1147, in getsource
    lines, lnum = getsourcelines(object)
  File "/usr/lib/python3.10/inspect.py", line 1129, in getsourcelines
    lines, lnum = findsource(object)
  File "/usr/lib/python3.10/inspect.py", line 958, in findsource
    raise OSError('could not get source code')
OSError: could not get source code

Any tips on how to proceed? Please tell me if you need more info. Thanks!


🐶 MLEM Version: 0.2.8

@mike0sv
Contributor

mike0sv commented Oct 16, 2022

Can you check using MLEM from main branch?

@vitalwarley
Copy link

@mike0sv, sure.

🐶 MLEM Version: 0.2.9.dev9+g86105ee

Unfortunately, it didn't change much

Traceback (most recent call last):
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/metadata.py", line 33, in get_object_metadata
    return MlemData.from_data(
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/objects.py", line 668, in from_data
    data_type = DataType.create(
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/data_type.py", line 75, in create
    return DataAnalyzer.analyze(obj, is_dynamic=is_dynamic, **kwargs).bind(
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/hooks.py", line 107, in analyze
    return cls._find_hook(obj).process(obj, **kwargs)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/hooks.py", line 135, in _find_hook
    raise HookNotFound(
mlem.core.errors.HookNotFound: No suitable DataHook for object of type "AutoShape". Registered hooks: [<class 'mlem.core.data_type.PrimitiveType'>, <class 'mlem.core.data_type.OrderedCollectionHook'>, <class 'mlem.core.data_type.DictTypeHook'>, <class 'mlem.contrib.numpy.NumpyNumberType'>, <class 'mlem.contrib.numpy.NumpyNdarrayType'>, <class 'mlem.contrib.torch.TorchTensorDataType'>, <class 'mlem.contrib.pandas.SeriesType'>, <class 'mlem.contrib.pandas.DataFrameType'>]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/warley/dev/tcc-ocr-demo/mlem_yolov5_example.py", line 8, in <module>
    mlem.api.save(model, 'best', sample_data=data)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/metadata.py", line 73, in save
    meta = get_object_metadata(
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/metadata.py", line 38, in get_object_metadata
    return MlemModel.from_obj(
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/objects.py", line 612, in from_obj
    requirements=mt.get_requirements().expanded,
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/contrib/torch.py", line 194, in get_requirements
    return super().get_requirements() + InstallableRequirement.from_module(
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/core/model.py", line 310, in get_requirements
    ) + get_object_requirements(self.model)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 598, in get_object_requirements
    a.dump(obj)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/dill/_dill.py", line 620, in dump
    StockPickler.dump(self, obj)
  File "/usr/lib/python3.10/pickle.py", line 487, in dump
    self.save(obj)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 567, in save
    self.add_requirement(obj)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 551, in add_requirement
    self.add_requirement(local_req)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 551, in add_requirement
    self.add_requirement(local_req)
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 551, in add_requirement
    self.add_requirement(local_req)
  [Previous line repeated 2 more times]
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 548, in add_requirement
    for local_req in get_local_module_reqs(module):
  File "/home/warley/.virtualenvs/tcc-ocr-demo/lib/python3.10/site-packages/mlem/utils/module.py", line 310, in get_local_module_reqs
    tree = ast.parse(inspect.getsource(mod))
  File "/usr/lib/python3.10/inspect.py", line 1147, in getsource
    lines, lnum = getsourcelines(object)
  File "/usr/lib/python3.10/inspect.py", line 1129, in getsourcelines
    lines, lnum = findsource(object)
  File "/usr/lib/python3.10/inspect.py", line 958, in findsource
    raise OSError('could not get source code')
OSError: could not get source code

@mike0sv
Contributor

mike0sv commented Oct 27, 2022

Hmm, "works on my machine" TM. Can you please also try #397 branch?

Btw my env:

torch==1.12.1
yolov5==6.2.2

Can you share yours?

@mike0sv
Copy link
Contributor

mike0sv commented Oct 27, 2022

Actually, try #453 with the MLEM_DEBUG=true environment variable set. It should print out the module that causes the trouble.

@mike0sv
Copy link
Contributor

mike0sv commented Oct 27, 2022

It all got merged into main, so please re-install from it :)
Sorry for the noise and expect a new comment with "now it's in 0.3.0 release, try with pip install -U mlem now" soon :)

@vitalwarley

@mike0sv, I upgraded to v0.3.0 and now it works! Thank you for your help!

@aguschin
Contributor

aguschin commented Nov 8, 2022

Closing since it was fixed!

@aguschin aguschin closed this as completed Nov 8, 2022
5 participants